# MeMDLM/scripts/generate.py
import os

import torch

import config
from diffusion import Diffusion
from esm_utils import load_esm2_model
def mask_sequence(sequence, tokenizer, mask_char='X'):
    """Replaces every mask_char in the sequence with the tokenizer's mask token."""
    mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
    masked_sequence = sequence.replace(mask_char, tokenizer.mask_token)
    return masked_sequence, mask_indices
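# Example of the masking step above (the "<mask>" string is an assumption about
# the tokenizer; ESM-2 style tokenizers typically use "<mask>" as their mask token):
#   mask_sequence("ACDXXEFG", tokenizer)  ->  ("ACD<mask><mask>EFG", [3, 4])
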
def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
    """Generates the filled sequence for the masked regions."""
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Collapse each multi-character mask token back to a single placeholder so
    # that mask_indices (positions in the original sequence) line up again.
    filled_sequence = list(masked_sequence.replace(tokenizer.mask_token, 'X'))
    for idx in mask_indices:
        # Assumes denoised_latents[idx] scores the vocabulary for the residue at
        # original position idx (no offset for special tokens is applied here).
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])
    return ''.join(filled_sequence)
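# Minimal usage sketch for the fill-in path, assuming `diffusion_model`,
# `tokenizer`, and `esm_model` have been loaded as in the __main__ block below
# (the input sequence is illustrative only):
#   masked, idxs = mask_sequence("ACDXXEFG", tokenizer)
#   filled = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked, idxs)
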
def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
    """Generates a scaffold sequence to connect multiple peptides."""
    total_peptide_length = sum(len(peptide) for peptide in peptides)
    scaffold_length = final_length - total_peptide_length
    if scaffold_length <= 0:
        raise ValueError("Final length must be greater than the combined length of the peptides.")

    # The masked scaffold is inserted between the first peptide and the
    # remaining peptides, which are concatenated directly after it.
    scaffold = tokenizer.mask_token * scaffold_length
    masked_sequence = "".join(peptides[:1] + [scaffold] + peptides[1:])
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Collapse the multi-character mask tokens so positional indices refer to
    # residues rather than to characters of the mask-token string.
    filled_sequence = list(masked_sequence.replace(tokenizer.mask_token, 'X'))
    scaffold_start = len(peptides[0])
    scaffold_end = scaffold_start + scaffold_length
    for idx in range(scaffold_start, scaffold_end):
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])
    return ''.join(filled_sequence)
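# Usage sketch: scaffold two peptides into a 30-residue protein. The peptide
# strings here are illustrative placeholders, not MeMDLM data:
#   generate_scaffold_sequence(diffusion_model, tokenizer, esm_model,
#                              ["ACDEFG", "KLMNPQ"], final_length=30)
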
def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
    """Generates a de novo protein sequence of the specified length."""
    masked_sequence = tokenizer.mask_token * sequence_length
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Every position is masked, so decode one residue per position.
    filled_sequence = [''] * sequence_length
    for idx in range(sequence_length):
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])
    return ''.join(filled_sequence)
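# The three generators above repeat the same embed -> noise -> denoise step.
# The helper below is a sketch of how that step could be factored out; it is
# not called anywhere in this script, and it only reuses the Diffusion calls
# (model.forward, model.reverse_diffusion) exactly as they appear above.
def denoise_latents(model, tokenizer, esm_model, masked_sequence):
    """Embed a masked sequence with ESM-2, add noise at a random sigma, and
    run the reverse diffusion step of the latent diffusion model."""
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        denoised_latents = model.reverse_diffusion(noisy_latents, sigma)
    return denoised_latents
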
if __name__ == "__main__":
import argparse
# Argument parsing
parser = argparse.ArgumentParser(description="Generate protein sequences using latent diffusion model.")
    subparsers = parser.add_subparsers(dest="mode", required=True)
# Subparser for the first strategy (multiple peptides to scaffold)
parser_scaffold = subparsers.add_parser("scaffold", help="Generate scaffold to connect multiple peptides.")
parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")
# Subparser for the second strategy (fill in regions)
parser_fill = subparsers.add_parser("fill", help="Fill in specified regions in a given protein sequence.")
parser_fill.add_argument("sequence", help="Protein sequence with regions to fill specified by 'X'.")
# Subparser for the third strategy (de novo generation)
parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")
args = parser.parse_args()
# Load models
tokenizer, esm_model = load_esm2_model(config.MODEL_NAME)
    diffusion_model = Diffusion.load_from_checkpoint(
        os.path.join(config.Training.SAVE_DIR, "best_model.ckpt"),
        config=config,
        latent_dim=config.LATENT_DIM,
    )
diffusion_model.eval()
if args.mode == "scaffold":
peptides = args.peptides
final_length = args.final_length
filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
print(f"Peptides: {' '.join(peptides)}")
print(f"Final Length: {final_length}")
print(f"Generated Protein: {filled_sequence}")
elif args.mode == "fill":
sequence = args.sequence
        masked_sequence, mask_indices = mask_sequence(sequence, tokenizer)
filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
print(f"Original Sequence: {sequence}")
print(f"Masked Sequence: {masked_sequence}")
print(f"Filled Sequence: {filled_sequence}")
elif args.mode == "de_novo":
sequence_length = args.sequence_length
filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
print(f"De Novo Sequence Length: {sequence_length}")
print(f"Generated Protein: {filled_sequence}")