pranamanam
commited on
Upload 15 files
Browse files- README.md +172 -3
- configs/config.py +19 -0
- data/test.csv +0 -0
- data/train.csv +0 -0
- data/val.csv +0 -0
- models/diffusion.py +88 -0
- requirements.txt +13 -0
- scripts/generate.py +131 -0
- scripts/test.py +17 -0
- scripts/train.py +24 -0
- test.csv +0 -0
- train.csv +0 -0
- utils/data_loader.py +30 -0
- utils/esm_utils.py +13 -0
- val.csv +0 -0
README.md
CHANGED
@@ -1,3 +1,172 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Latent Diffusion Model for Protein Sequence Generation using MDLM and ESM-2-650M
|
3 |
+
|
4 |
+
Here, we implement a masked discrete latent diffusion model for generating protein sequences. The model leverages the MDLM framework and ESM-2-650M for latent space representation and diffusion.
|
5 |
+
|
6 |
+
## Directory Structure
|
7 |
+
|
8 |
+
```
|
9 |
+
project/
|
10 |
+
│
|
11 |
+
├── configs/
|
12 |
+
│ ├── config.py
|
13 |
+
│
|
14 |
+
├── data/
|
15 |
+
│ ├── train.csv
|
16 |
+
│ ├── val.csv
|
17 |
+
│ ├── test.csv
|
18 |
+
│
|
19 |
+
├── models/
|
20 |
+
│ ├── diffusion.py
|
21 |
+
│
|
22 |
+
├── scripts/
|
23 |
+
│ ├── train.py
|
24 |
+
│ ├── test.py
|
25 |
+
│ ├── generate.py
|
26 |
+
│
|
27 |
+
├── utils/
|
28 |
+
│ ├── data_loader.py
|
29 |
+
│ ├── esm_utils.py
|
30 |
+
│
|
31 |
+
├── checkpoints/
|
32 |
+
│ ├── example.ckpt # Placeholder for checkpoints
|
33 |
+
│
|
34 |
+
├── requirements.txt
|
35 |
+
│
|
36 |
+
└── README.md
|
37 |
+
```
|
38 |
+
|
39 |
+
## Setup and Requirements
|
40 |
+
|
41 |
+
### Prerequisites
|
42 |
+
|
43 |
+
- Python 3.8+
|
44 |
+
- CUDA (for GPU support)
|
45 |
+
|
46 |
+
### Install Dependencies
|
47 |
+
|
48 |
+
1. Create and activate a virtual environment:
|
49 |
+
```bash
|
50 |
+
python -m venv venv
|
51 |
+
source venv/bin/activate # On Windows use `venv\Scripts\activate`
|
52 |
+
```
|
53 |
+
|
54 |
+
2. Install the required packages:
|
55 |
+
```bash
|
56 |
+
pip install -r requirements.txt
|
57 |
+
```
|
58 |
+
|
59 |
+
### Prepare Data
|
60 |
+
|
61 |
+
Place your data files (`train.csv`, `val.csv`, `test.csv`) in the `data/` directory. Ensure that these CSV files contain a column named `sequence` with the protein sequences.
|
62 |
+
|
63 |
+
## Configuration
|
64 |
+
|
65 |
+
Modify the `configs/config.py` file to set your hyperparameters, model configurations, and data paths. Here is an example configuration:
|
66 |
+
|
67 |
+
```python
|
68 |
+
class Config:
|
69 |
+
model_name = "facebook/esm2_t33_650M_UR50D"
|
70 |
+
latent_dim = 1280 # Adjust based on ESM-2 latent dimension
|
71 |
+
optim = {"lr": 1e-4}
|
72 |
+
training = {
|
73 |
+
"ema": 0.999,
|
74 |
+
"epochs": 10,
|
75 |
+
"batch_size": 32,
|
76 |
+
"gpus": 8,
|
77 |
+
"precision": 16, # Mixed precision training
|
78 |
+
"accumulate_grad_batches": 2, # Gradient accumulation
|
79 |
+
"save_dir": "./checkpoints/",
|
80 |
+
}
|
81 |
+
data_path = "./data/"
|
82 |
+
T = 1000 # Number of diffusion steps
|
83 |
+
subs_masking = False
|
84 |
+
```
|
85 |
+
|
86 |
+
## Mathematical Formulations
|
87 |
+
|
88 |
+
### Forward Diffusion
|
89 |
+
|
90 |
+
The forward diffusion process adds noise to the latent representations of the protein sequences:
|
91 |
+
\[ ext{noisy\_latents} = ext{latents} + \sigma \cdot \epsilon \]
|
92 |
+
where:
|
93 |
+
- \(\sigma\) is the noise level.
|
94 |
+
- \(\epsilon \sim \mathcal{N}(0, 1)\) is Gaussian noise.
|
95 |
+
|
96 |
+
### Reverse Diffusion
|
97 |
+
|
98 |
+
The reverse diffusion process denoises the latent representations:
|
99 |
+
\[ ext{denoised\_latents} = ext{backbone}( ext{noisy\_latents}, \sigma) \]
|
100 |
+
where the backbone model predicts the denoised latent representations.
|
101 |
+
|
102 |
+
### Loss Function
|
103 |
+
|
104 |
+
The loss function used to train the model is the Mean Squared Error (MSE) between the denoised latents and the original latents:
|
105 |
+
\[ \mathcal{L} = ext{MSE}( ext{denoised\_latents}, ext{latents}) \]
|
106 |
+
|
107 |
+
## Training
|
108 |
+
|
109 |
+
To train the model, run the `train.py` script:
|
110 |
+
|
111 |
+
```bash
|
112 |
+
python scripts/train.py
|
113 |
+
```
|
114 |
+
|
115 |
+
This script will:
|
116 |
+
- Load the ESM-2-650M model and tokenizer from Hugging Face.
|
117 |
+
- Prepare the data loaders for training and validation datasets.
|
118 |
+
- Initialize the latent diffusion model.
|
119 |
+
- Train the model using the specified configurations.
|
120 |
+
|
121 |
+
## Testing
|
122 |
+
|
123 |
+
To test the model, run the `test.py` script:
|
124 |
+
|
125 |
+
```bash
|
126 |
+
python scripts/test.py
|
127 |
+
```
|
128 |
+
|
129 |
+
This script will:
|
130 |
+
- Load the trained model from the checkpoint.
|
131 |
+
- Prepare the data loader for the test dataset.
|
132 |
+
- Evaluate the model on the test dataset.
|
133 |
+
|
134 |
+
## Generating Protein Sequences
|
135 |
+
|
136 |
+
To generate protein sequences, use the `generate.py` script. This script supports three strategies:
|
137 |
+
|
138 |
+
1. **Generating a Scaffold to Connect Multiple Peptides**:
|
139 |
+
```bash
|
140 |
+
python scripts/generate.py scaffold <peptide1> <peptide2> ... <final_length>
|
141 |
+
```
|
142 |
+
Example:
|
143 |
+
```bash
|
144 |
+
python scripts/generate.py scaffold MKTAYIAKQRQ GLIEVQ 30
|
145 |
+
```
|
146 |
+
|
147 |
+
2. **Filling in Specified Regions in a Given Protein Sequence**:
|
148 |
+
```bash
|
149 |
+
python scripts/generate.py fill <sequence_with_X>
|
150 |
+
```
|
151 |
+
Example:
|
152 |
+
```bash
|
153 |
+
python scripts/generate.py fill MKTAYIAKXXXXXXXLEERLGLIEVQ
|
154 |
+
```
|
155 |
+
|
156 |
+
3. **Purely De Novo Generation of a Protein Sequence**:
|
157 |
+
```bash
|
158 |
+
python scripts/generate.py de_novo <sequence_length>
|
159 |
+
```
|
160 |
+
Example:
|
161 |
+
```bash
|
162 |
+
python scripts/generate.py de_novo 50
|
163 |
+
```
|
164 |
+
|
165 |
+
## Notes
|
166 |
+
|
167 |
+
- Ensure you have a compatible CUDA environment if you are training on GPUs.
|
168 |
+
- Modify the paths and configurations in `configs/config.py` as needed to match your setup.
|
169 |
+
|
170 |
+
## Acknowledgements
|
171 |
+
|
172 |
+
This implementation is based on the MDLM framework and uses the ESM-2-650M model.
|
configs/config.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### configs/config.py
|
2 |
+
|
3 |
+
```python
|
4 |
+
class Config:
|
5 |
+
model_name = "facebook/esm2_t33_650M_UR50D"
|
6 |
+
latent_dim = 1280 # Adjust based on ESM-2 latent dimension
|
7 |
+
optim = {"lr": 1e-4}
|
8 |
+
training = {
|
9 |
+
"ema": 0.999,
|
10 |
+
"epochs": 10,
|
11 |
+
"batch_size": 32,
|
12 |
+
"gpus": 8,
|
13 |
+
"precision": 16, # Mixed precision training
|
14 |
+
"accumulate_grad_batches": 2, # Gradient accumulation
|
15 |
+
"save_dir": "./checkpoints/",
|
16 |
+
}
|
17 |
+
data_path = "./data/"
|
18 |
+
T = 1000 # Number of diffusion steps
|
19 |
+
subs_masking = False
|
data/test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/val.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/diffusion.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import pytorch_lightning as L
|
6 |
+
import torchmetrics
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from models import dit, ema
|
9 |
+
import noise_schedule # Assuming this is part of the MDLM repository
|
10 |
+
|
11 |
+
LOG2 = math.log(2)
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class Loss:
|
15 |
+
loss: torch.FloatTensor
|
16 |
+
nlls: torch.FloatTensor
|
17 |
+
token_mask: torch.FloatTensor
|
18 |
+
|
19 |
+
class NLL(torchmetrics.MeanMetric):
|
20 |
+
pass
|
21 |
+
|
22 |
+
class BPD(NLL):
|
23 |
+
def compute(self) -> torch.Tensor:
|
24 |
+
"""Computes the bits per dimension.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
bpd
|
28 |
+
"""
|
29 |
+
return self.mean_value / self.weight / LOG2
|
30 |
+
|
31 |
+
class Perplexity(NLL):
|
32 |
+
def compute(self) -> torch.Tensor:
|
33 |
+
"""Computes the Perplexity.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Perplexity
|
37 |
+
"""
|
38 |
+
return torch.exp(self.mean_value / self.weight)
|
39 |
+
|
40 |
+
class Diffusion(L.LightningModule):
|
41 |
+
def __init__(self, config, latent_dim):
|
42 |
+
super().__init__()
|
43 |
+
self.config = config
|
44 |
+
self.latent_dim = latent_dim
|
45 |
+
|
46 |
+
self.backbone = dit.DIT(config, vocab_size=self.latent_dim)
|
47 |
+
self.T = self.config.T
|
48 |
+
self.subs_masking = self.config.subs_masking
|
49 |
+
|
50 |
+
self.softplus = torch.nn.Softplus()
|
51 |
+
metrics = torchmetrics.MetricCollection({
|
52 |
+
'nll': NLL(),
|
53 |
+
'bpd': BPD(),
|
54 |
+
'ppl': Perplexity(),
|
55 |
+
})
|
56 |
+
metrics.set_dtype(torch.float64)
|
57 |
+
self.train_metrics = metrics.clone(prefix='train/')
|
58 |
+
self.valid_metrics = metrics.clone(prefix='val/')
|
59 |
+
self.test_metrics = metrics.clone(prefix='test/')
|
60 |
+
|
61 |
+
self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
|
62 |
+
self.lr = self.config.optim["lr"]
|
63 |
+
self.sampling_eps = self.config.training.get("sampling_eps", 1e-5)
|
64 |
+
self.time_conditioning = self.config.get("time_conditioning", True)
|
65 |
+
self.neg_infinity = -1000000.0
|
66 |
+
|
67 |
+
def forward(self, latents, sigma):
|
68 |
+
"""Forward diffusion process, adds noise to the latents."""
|
69 |
+
noise = sigma * torch.randn_like(latents)
|
70 |
+
noisy_latents = latents + noise
|
71 |
+
return noisy_latents
|
72 |
+
|
73 |
+
def reverse_diffusion(self, noisy_latents, sigma):
|
74 |
+
"""Reverse diffusion process, denoises the latents."""
|
75 |
+
denoised_latents = self.backbone(noisy_latents, sigma)
|
76 |
+
return denoised_latents
|
77 |
+
|
78 |
+
def training_step(self, batch, batch_idx):
|
79 |
+
sigma = torch.rand(batch.size(0), device=self.device)
|
80 |
+
noisy_latents = self.forward(batch, sigma)
|
81 |
+
denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
|
82 |
+
loss = F.mse_loss(denoised_latents, batch)
|
83 |
+
self.log("train_loss", loss)
|
84 |
+
return loss
|
85 |
+
|
86 |
+
def configure_optimizers(self):
|
87 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
88 |
+
return optimizer
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.10.0
|
2 |
+
torchvision==0.11.1
|
3 |
+
torchaudio==0.10.0
|
4 |
+
pytorch-lightning==1.5.10
|
5 |
+
transformers==4.12.3
|
6 |
+
pandas==1.3.4
|
7 |
+
numpy==1.21.4
|
8 |
+
scipy==1.7.3
|
9 |
+
scikit-learn==1.0.1
|
10 |
+
tqdm==4.62.3
|
11 |
+
omegaconf==2.1.1
|
12 |
+
hydra-core==1.1.1
|
13 |
+
torchmetrics==0.6.2
|
scripts/generate.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
+
from models.diffusion import Diffusion
|
5 |
+
from configs.config import Config
|
6 |
+
from utils.esm_utils import load_esm2_model, get_latents
|
7 |
+
|
8 |
+
def mask_sequence(sequence, mask_char='X'):
|
9 |
+
"""Masks parts of the sequence based on the mask_char."""
|
10 |
+
mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
|
11 |
+
masked_sequence = sequence.replace(mask_char, '[MASK]')
|
12 |
+
return masked_sequence, mask_indices
|
13 |
+
|
14 |
+
def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
|
15 |
+
"""Generates the filled sequence for the masked regions."""
|
16 |
+
inputs = tokenizer(masked_sequence, return_tensors="pt")
|
17 |
+
with torch.no_grad():
|
18 |
+
outputs = esm_model(**inputs)
|
19 |
+
latents = outputs.last_hidden_state.squeeze(0)
|
20 |
+
|
21 |
+
sigma = torch.rand(1, device=latents.device)
|
22 |
+
noisy_latents = model.forward(latents, sigma)
|
23 |
+
denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
|
24 |
+
|
25 |
+
filled_sequence = list(masked_sequence)
|
26 |
+
for idx in mask_indices:
|
27 |
+
token_id = torch.argmax(denoised_latents[idx]).item()
|
28 |
+
filled_sequence[idx] = tokenizer.decode([token_id])
|
29 |
+
|
30 |
+
return ''.join(filled_sequence)
|
31 |
+
|
32 |
+
def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
|
33 |
+
"""Generates a scaffold sequence to connect multiple peptides."""
|
34 |
+
total_peptide_length = sum(len(peptide) for peptide in peptides)
|
35 |
+
scaffold_length = final_length - total_peptide_length
|
36 |
+
if scaffold_length <= 0:
|
37 |
+
raise ValueError("Final length must be greater than the combined length of the peptides.")
|
38 |
+
|
39 |
+
scaffold = "[MASK]" * scaffold_length
|
40 |
+
masked_sequence = "".join(peptides[:1] + [scaffold] + peptides[1:])
|
41 |
+
|
42 |
+
inputs = tokenizer(masked_sequence, return_tensors="pt")
|
43 |
+
with torch.no_grad():
|
44 |
+
outputs = esm_model(**inputs)
|
45 |
+
latents = outputs.last_hidden_state.squeeze(0)
|
46 |
+
|
47 |
+
sigma = torch.rand(1, device=latents.device)
|
48 |
+
noisy_latents = model.forward(latents, sigma)
|
49 |
+
denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
|
50 |
+
|
51 |
+
filled_sequence = list(masked_sequence)
|
52 |
+
scaffold_start = len(peptides[0])
|
53 |
+
scaffold_end = scaffold_start + scaffold_length
|
54 |
+
for idx in range(scaffold_start, scaffold_end):
|
55 |
+
token_id = torch.argmax(denoised_latents[idx]).item()
|
56 |
+
filled_sequence[idx] = tokenizer.decode([token_id])
|
57 |
+
|
58 |
+
return ''.join(filled_sequence)
|
59 |
+
|
60 |
+
def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
|
61 |
+
"""Generates a de novo protein sequence of the specified length."""
|
62 |
+
scaffold = "[MASK]" * sequence_length
|
63 |
+
masked_sequence = scaffold
|
64 |
+
|
65 |
+
inputs = tokenizer(masked_sequence, return_tensors="pt")
|
66 |
+
with torch.no_grad():
|
67 |
+
outputs = esm_model(**inputs)
|
68 |
+
latents = outputs.last_hidden_state.squeeze(0)
|
69 |
+
|
70 |
+
sigma = torch.rand(1, device=latents.device)
|
71 |
+
noisy_latents = model.forward(latents, sigma)
|
72 |
+
denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
|
73 |
+
|
74 |
+
filled_sequence = list(masked_sequence)
|
75 |
+
for idx in range(sequence_length):
|
76 |
+
token_id = torch.argmax(denoised_latents[idx]).item()
|
77 |
+
filled_sequence[idx] = tokenizer.decode([token_id])
|
78 |
+
|
79 |
+
return ''.join(filled_sequence)
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
import argparse
|
83 |
+
|
84 |
+
# Argument parsing
|
85 |
+
parser = argparse.ArgumentParser(description="Generate protein sequences using latent diffusion model.")
|
86 |
+
subparsers = parser.add_subparsers(dest="mode")
|
87 |
+
|
88 |
+
# Subparser for the first strategy (multiple peptides to scaffold)
|
89 |
+
parser_scaffold = subparsers.add_parser("scaffold", help="Generate scaffold to connect multiple peptides.")
|
90 |
+
parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
|
91 |
+
parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")
|
92 |
+
|
93 |
+
# Subparser for the second strategy (fill in regions)
|
94 |
+
parser_fill = subparsers.add_parser("fill", help="Fill in specified regions in a given protein sequence.")
|
95 |
+
parser_fill.add_argument("sequence", help="Protein sequence with regions to fill specified by 'X'.")
|
96 |
+
|
97 |
+
# Subparser for the third strategy (de novo generation)
|
98 |
+
parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
|
99 |
+
parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")
|
100 |
+
|
101 |
+
args = parser.parse_args()
|
102 |
+
|
103 |
+
# Load configurations
|
104 |
+
config = Config()
|
105 |
+
|
106 |
+
# Load models
|
107 |
+
tokenizer, esm_model = load_esm2_model(config.model_name)
|
108 |
+
diffusion_model = Diffusion.load_from_checkpoint(config.training["save_dir"] + "example.ckpt", config=config, latent_dim=config.latent_dim)
|
109 |
+
diffusion_model.eval()
|
110 |
+
|
111 |
+
if args.mode == "scaffold":
|
112 |
+
peptides = args.peptides
|
113 |
+
final_length = args.final_length
|
114 |
+
filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
|
115 |
+
print(f"Peptides: {' '.join(peptides)}")
|
116 |
+
print(f"Final Length: {final_length}")
|
117 |
+
print(f"Generated Protein: {filled_sequence}")
|
118 |
+
|
119 |
+
elif args.mode == "fill":
|
120 |
+
sequence = args.sequence
|
121 |
+
masked_sequence, mask_indices = mask_sequence(sequence)
|
122 |
+
filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
|
123 |
+
print(f"Original Sequence: {sequence}")
|
124 |
+
print(f"Masked Sequence: {masked_sequence}")
|
125 |
+
print(f"Filled Sequence: {filled_sequence}")
|
126 |
+
|
127 |
+
elif args.mode == "de_novo":
|
128 |
+
sequence_length = args.sequence_length
|
129 |
+
filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
|
130 |
+
print(f"De Novo Sequence Length: {sequence_length}")
|
131 |
+
print(f"Generated Protein: {filled_sequence}")
|
scripts/test.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as L
|
2 |
+
from configs.config import Config
|
3 |
+
from utils.data_loader import get_dataloaders
|
4 |
+
from models.diffusion import Diffusion
|
5 |
+
|
6 |
+
# Get dataloaders
|
7 |
+
_, _, test_loader = get_dataloaders(Config)
|
8 |
+
|
9 |
+
# Initialize model
|
10 |
+
checkpoint_path = Config.training["save_dir"] + "example.ckpt"
|
11 |
+
latent_diffusion_model = Diffusion.load_from_checkpoint(checkpoint_path, config=Config, latent_dim=Config.latent_dim)
|
12 |
+
|
13 |
+
# Initialize trainer
|
14 |
+
trainer = L.Trainer(gpus=Config.training["gpus"], precision=Config.training["precision"])
|
15 |
+
|
16 |
+
# Test the model
|
17 |
+
trainer.test(latent_diffusion_model, test_loader)
|
scripts/train.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as L
|
2 |
+
from pytorch_lightning.strategies import DDPStrategy
|
3 |
+
from configs.config import Config
|
4 |
+
from utils.data_loader import get_dataloaders
|
5 |
+
from models.diffusion import Diffusion
|
6 |
+
|
7 |
+
# Get dataloaders
|
8 |
+
train_loader, val_loader, _ = get_dataloaders(Config)
|
9 |
+
|
10 |
+
# Initialize model
|
11 |
+
latent_diffusion_model = Diffusion(Config, latent_dim=Config.latent_dim)
|
12 |
+
|
13 |
+
# Initialize trainer
|
14 |
+
trainer = L.Trainer(
|
15 |
+
max_epochs=Config.training["epochs"],
|
16 |
+
gpus=Config.training["gpus"],
|
17 |
+
precision=Config.training["precision"],
|
18 |
+
strategy=DDPStrategy(find_unused_parameters=False),
|
19 |
+
accumulate_grad_batches=Config.training["accumulate_grad_batches"],
|
20 |
+
default_root_dir=Config.training["save_dir"]
|
21 |
+
)
|
22 |
+
|
23 |
+
# Train the model
|
24 |
+
trainer.fit(latent_diffusion_model, train_loader, val_loader)
|
test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils/data_loader.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
from utils.esm_utils import get_latents, load_esm2_model
|
4 |
+
|
5 |
+
class ProteinDataset(Dataset):
|
6 |
+
def __init__(self, csv_file, tokenizer, model):
|
7 |
+
self.data = pd.read_csv(csv_file)
|
8 |
+
self.tokenizer = tokenizer
|
9 |
+
self.model = model
|
10 |
+
|
11 |
+
def __len__(self):
|
12 |
+
return len(self.data)
|
13 |
+
|
14 |
+
def __getitem__(self, idx):
|
15 |
+
sequence = self.data.iloc[idx]['sequence']
|
16 |
+
latents = get_latents(self.model, self.tokenizer, sequence)
|
17 |
+
return latents
|
18 |
+
|
19 |
+
def get_dataloaders(config):
|
20 |
+
tokenizer, model = load_esm2_model(config.model_name)
|
21 |
+
|
22 |
+
train_dataset = ProteinDataset(config.data_path + "train.csv", tokenizer, model)
|
23 |
+
val_dataset = ProteinDataset(config.data_path + "val.csv", tokenizer, model)
|
24 |
+
test_dataset = ProteinDataset(config.data_path + "test.csv", tokenizer, model)
|
25 |
+
|
26 |
+
train_loader = DataLoader(train_dataset, batch_size=config.training["batch_size"], shuffle=True)
|
27 |
+
val_loader = DataLoader(val_dataset, batch_size=config.training["batch_size"], shuffle=False)
|
28 |
+
test_loader = DataLoader(test_dataset, batch_size=config.training["batch_size"], shuffle=False)
|
29 |
+
|
30 |
+
return train_loader, val_loader, test_loader
|
utils/esm_utils.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModel
|
3 |
+
|
4 |
+
def load_esm2_model(model_name):
|
5 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
6 |
+
model = AutoModel.from_pretrained(model_name)
|
7 |
+
return tokenizer, model
|
8 |
+
|
9 |
+
def get_latents(model, tokenizer, sequence):
|
10 |
+
inputs = tokenizer(sequence, return_tensors="pt")
|
11 |
+
with torch.no_grad():
|
12 |
+
outputs = model(**inputs)
|
13 |
+
return outputs.last_hidden_state.squeeze(0)
|
val.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|