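"""Sequence generation with the latent diffusion model over ESM-2 embeddings.

Three modes are exposed as sub-commands (see __main__ below): "scaffold"
connects several peptides with a generated linker, "fill" completes the
positions marked with 'X' in a given sequence, and "de_novo" generates a
full sequence of a requested length. The script expects `config` to define
MODEL_NAME, LATENT_DIM and Training.SAVE_DIR, and `esm_utils.load_esm2_model`
to return a (tokenizer, model) pair.
"""
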
import os

import torch
from diffusion import Diffusion
import config
from esm_utils import load_esm2_model

def mask_sequence(sequence, mask_char='X'):
    """Finds the residues to be generated, i.e. the positions marked with mask_char.

    The sequence is returned unchanged: the placeholders are swapped for the
    tokenizer's mask token at tokenization time, so the recorded indices stay
    aligned with the per-residue ESM tokens.
    """
    mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
    return sequence, mask_indices

def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
    """Generates the filled sequence for the masked regions."""
    # Swap each masked residue for the tokenizer's mask token; every residue
    # (and every mask token) maps to exactly one ESM token.
    mask_positions = set(mask_indices)
    tokens = [tokenizer.mask_token if i in mask_positions else aa
              for i, aa in enumerate(masked_sequence)]
    inputs = tokenizer(''.join(tokens), return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
    latents = outputs.last_hidden_state.squeeze(0)

    # Add noise at a random level, then denoise with the diffusion model.
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    filled_sequence = list(masked_sequence)
    for idx in mask_indices:
        # Offset by one to skip the <cls> token the ESM tokenizer prepends;
        # the argmax assumes the denoised output can be read as per-position
        # token scores.
        token_id = torch.argmax(denoised_latents[idx + 1]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])

    return ''.join(filled_sequence)
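
# Programmatic usage (illustrative sketch; diffusion_model, tokenizer and
# esm_model are loaded as in the __main__ block below):
#   masked, indices = mask_sequence("MKTXXAYIAK")
#   filled = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked, indices)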

def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
    """Generates a scaffold sequence to connect multiple peptides.

    The masked scaffold is placed between the first peptide and the remaining
    ones, and the masked positions are filled by the diffusion model.
    """
    total_peptide_length = sum(len(peptide) for peptide in peptides)
    scaffold_length = final_length - total_peptide_length
    if scaffold_length <= 0:
        raise ValueError("Final length must be greater than the combined length of the peptides.")

    scaffold_start = len(peptides[0])
    masked_sequence = peptides[0] + 'X' * scaffold_length + ''.join(peptides[1:])
    mask_indices = list(range(scaffold_start, scaffold_start + scaffold_length))
    return generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices)

def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
    """Generates a de novo protein sequence of the specified length."""
    # Every position is masked, so the diffusion model fills in the whole sequence.
    masked_sequence = 'X' * sequence_length
    mask_indices = list(range(sequence_length))
    return generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices)

if __name__ == "__main__":
    import argparse

    # Argument parsing
    parser = argparse.ArgumentParser(description="Generate protein sequences using latent diffusion model.")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Subparser for the first strategy (multiple peptides to scaffold)
    parser_scaffold = subparsers.add_parser("scaffold", help="Generate scaffold to connect multiple peptides.")
    parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
    parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")

    # Subparser for the second strategy (fill in regions)
    parser_fill = subparsers.add_parser("fill", help="Fill in specified regions in a given protein sequence.")
    parser_fill.add_argument("sequence", help="Protein sequence with regions to fill specified by 'X'.")

    # Subparser for the third strategy (de novo generation)
    parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
    parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")
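
    # Example invocations (assuming this file is saved as generate.py):
    #   python generate.py scaffold MKTAYIAK QRQISFVK 60
    #   python generate.py fill MKTXXXAYIAKQRQIS
    #   python generate.py de_novo 100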

    args = parser.parse_args()

    # Load models
    tokenizer, esm_model = load_esm2_model(config.MODEL_NAME)
    diffusion_model = Diffusion.load_from_checkpoint(
        os.path.join(config.Training.SAVE_DIR, "best_model.ckpt"),
        config=config,
        latent_dim=config.LATENT_DIM,
    )
    diffusion_model.eval()

    if args.mode == "scaffold":
        peptides = args.peptides
        final_length = args.final_length
        filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
        print(f"Peptides:          {' '.join(peptides)}")
        print(f"Final Length:      {final_length}")
        print(f"Generated Protein: {filled_sequence}")

    elif args.mode == "fill":
        sequence = args.sequence
        masked_sequence, mask_indices = mask_sequence(sequence)
        filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
        print(f"Original Sequence: {sequence}")
        print(f"Masked Sequence:   {masked_sequence}")
        print(f"Filled Sequence:   {filled_sequence}")

    elif args.mode == "de_novo":
        sequence_length = args.sequence_length
        filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
        print(f"De Novo Sequence Length: {sequence_length}")
        print(f"Generated Protein: {filled_sequence}")