import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from diffusion import Diffusion
import config
from esm_utils import load_esm2_model, get_latents
def mask_sequence(sequence, mask_char='X'):
    """Replace every occurrence of ``mask_char`` with the literal '[MASK]'.

    Returns a tuple ``(masked_sequence, mask_indices)`` where
    ``mask_indices`` are the positions of ``mask_char`` in the ORIGINAL
    (unexpanded) sequence.
    """
    positions = []
    for pos, residue in enumerate(sequence):
        if residue == mask_char:
            positions.append(pos)
    return sequence.replace(mask_char, '[MASK]'), positions
def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
    """Generate residues for the masked regions of a sequence.

    Args:
        model: Diffusion model exposing ``forward()`` and ``reverse_diffusion()``.
        tokenizer: Tokenizer paired with the ESM encoder; also used to decode.
        esm_model: Encoder producing per-position latents (``last_hidden_state``).
        masked_sequence: Sequence containing '[MASK]' markers, as produced by
            ``mask_sequence``.
        mask_indices: Positions of the masks in the ORIGINAL (unexpanded)
            sequence, i.e. before each mask char became the 6-char '[MASK]'.

    Returns:
        The sequence with each masked position replaced by a decoded residue.
    """
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
    # Bug fix: mask_indices refer to the original sequence, but the old code
    # indexed into masked_sequence, where every mask has been expanded to the
    # six characters of '[MASK]'.  Collapse each '[MASK]' back to a single
    # placeholder so the indices line up again; every placeholder is
    # overwritten below, so its character is irrelevant.
    filled_sequence = list(masked_sequence.replace('[MASK]', 'X'))
    for idx in mask_indices:
        # NOTE(review): assumes latent row idx corresponds to residue idx;
        # tokenizers that prepend special tokens (e.g. a BOS/CLS token)
        # would shift positions by one — TODO confirm against the tokenizer.
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])
    return ''.join(filled_sequence)
def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
    """Generate a scaffold connecting peptides into one sequence of ``final_length``.

    Args:
        model: Diffusion model exposing ``forward()`` and ``reverse_diffusion()``.
        tokenizer: Tokenizer paired with the ESM encoder; also used to decode.
        esm_model: Encoder producing per-position latents (``last_hidden_state``).
        peptides: Peptide strings to connect; the scaffold is inserted after
            the first peptide, before the remaining ones.
        final_length: Target length of the full sequence (peptides + scaffold).

    Returns:
        ``peptides[0] + generated scaffold + remaining peptides``.

    Raises:
        ValueError: if ``final_length`` does not exceed the combined peptide length.
    """
    total_peptide_length = sum(len(peptide) for peptide in peptides)
    scaffold_length = final_length - total_peptide_length
    if scaffold_length <= 0:
        raise ValueError("Final length must be greater than the combined length of the peptides.")
    # First peptide, then the masked scaffold, then the remaining peptides.
    scaffold = "[MASK]" * scaffold_length
    masked_sequence = "".join(peptides[:1] + [scaffold] + peptides[1:])
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
    scaffold_start = len(peptides[0])
    # Bug fix: the old code spliced decoded residues into masked_sequence,
    # where every mask occupies the six characters of '[MASK]', so the
    # output still contained 'MASK]' fragments.  Build the result from the
    # peptides plus one decoded residue per scaffold position instead.
    scaffold_residues = []
    for offset in range(scaffold_length):
        # NOTE(review): assumes latent row (scaffold_start + offset) maps to
        # that residue position; special tokens such as a BOS/CLS token
        # would shift this by one — TODO confirm against the tokenizer.
        token_id = torch.argmax(denoised_latents[scaffold_start + offset]).item()
        scaffold_residues.append(tokenizer.decode([token_id]))
    return peptides[0] + ''.join(scaffold_residues) + ''.join(peptides[1:])
def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
    """Generate a de novo protein sequence of ``sequence_length`` residues.

    Args:
        model: Diffusion model exposing ``forward()`` and ``reverse_diffusion()``.
        tokenizer: Tokenizer paired with the ESM encoder; also used to decode.
        esm_model: Encoder producing per-position latents (``last_hidden_state``).
        sequence_length: Number of residues to generate.

    Returns:
        A string of ``sequence_length`` decoded residues.
    """
    masked_sequence = "[MASK]" * sequence_length
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
    # Bug fix: the old code spliced decoded residues into the literal
    # '[MASK]...' string (six characters per mask), only overwriting the
    # first sequence_length characters and leaving 'MASK]' fragments in the
    # output.  Decode one residue per position instead.
    residues = []
    for idx in range(sequence_length):
        # NOTE(review): assumes latent row idx maps to residue idx; a
        # prepended BOS/CLS token would shift this by one — TODO confirm.
        token_id = torch.argmax(denoised_latents[idx]).item()
        residues.append(tokenizer.decode([token_id]))
    return ''.join(residues)
if __name__ == "__main__":
    import argparse
    import os

    # Command-line interface: one sub-command per generation strategy.
    parser = argparse.ArgumentParser(description="Generate protein sequences using latent diffusion model.")
    # required=True: without it, a missing sub-command fell through every
    # branch below and the script exited silently.
    subparsers = parser.add_subparsers(dest="mode", required=True)
    # Strategy 1: connect multiple peptides with a generated scaffold.
    parser_scaffold = subparsers.add_parser("scaffold", help="Generate scaffold to connect multiple peptides.")
    parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
    parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")
    # Strategy 2: fill in 'X'-marked regions of a given sequence.
    parser_fill = subparsers.add_parser("fill", help="Fill in specified regions in a given protein sequence.")
    parser_fill.add_argument("sequence", help="Protein sequence with regions to fill specified by 'X'.")
    # Strategy 3: generate a sequence entirely from scratch.
    parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
    parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")
    args = parser.parse_args()
    # Load the ESM-2 encoder/tokenizer and the trained diffusion model.
    tokenizer, esm_model = load_esm2_model(config.MODEL_NAME)
    # os.path.join: the old string concatenation silently produced a broken
    # path whenever SAVE_DIR lacked a trailing separator.
    checkpoint_path = os.path.join(config.Training.SAVE_DIR, "best_model.ckpt")
    diffusion_model = Diffusion.load_from_checkpoint(checkpoint_path, config=config, latent_dim=config.LATENT_DIM)
    diffusion_model.eval()
    if args.mode == "scaffold":
        peptides = args.peptides
        final_length = args.final_length
        filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
        print(f"Peptides: {' '.join(peptides)}")
        print(f"Final Length: {final_length}")
        print(f"Generated Protein: {filled_sequence}")
    elif args.mode == "fill":
        sequence = args.sequence
        masked_sequence, mask_indices = mask_sequence(sequence)
        filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
        print(f"Original Sequence: {sequence}")
        print(f"Masked Sequence: {masked_sequence}")
        print(f"Filled Sequence: {filled_sequence}")
    elif args.mode == "de_novo":
        sequence_length = args.sequence_length
        filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
        print(f"De Novo Sequence Length: {sequence_length}")
        print(f"Generated Protein: {filled_sequence}")