# MeMDLM/scripts/generate.py
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from models.diffusion import Diffusion
from configs.config import Config
from utils.esm_utils import load_esm2_model, get_latents

def mask_sequence(sequence, mask_char='X'):
    """Replaces each mask_char with '[MASK]' and records the masked positions
    (indices into the original, unexpanded sequence)."""
    mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
    masked_sequence = sequence.replace(mask_char, '[MASK]')
    return masked_sequence, mask_indices

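# Illustrative example (toy sequence, doctest-style):
# >>> mask_sequence("MKTXYRX")
# ('MKT[MASK]YR[MASK]', [3, 6])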

def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
    """Generates the filled sequence for the masked regions."""
    # Embed the masked sequence with the frozen ESM-2 encoder.
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)

    # Noise the latents at a random diffusion level, then denoise them.
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Work on the original-length sequence so mask_indices line up with characters.
    filled_sequence = list(masked_sequence.replace('[MASK]', 'X'))
    for idx in mask_indices:
        # Treat the denoised vector at each masked position as token logits.
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])
    return ''.join(filled_sequence)


def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
    """Generates a scaffold sequence to connect multiple peptides."""
    total_peptide_length = sum(len(peptide) for peptide in peptides)
    scaffold_length = final_length - total_peptide_length
    if scaffold_length <= 0:
        raise ValueError("Final length must be greater than the combined length of the peptides.")

    # Place the masked scaffold between the first peptide and the remaining ones.
    scaffold = "[MASK]" * scaffold_length
    masked_sequence = "".join(peptides[:1] + [scaffold] + peptides[1:])

    # Embed the masked sequence with the frozen ESM-2 encoder.
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)

    # Noise the latents at a random diffusion level, then denoise them.
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Work on the final-length sequence so scaffold positions line up with characters.
    filled_sequence = list(masked_sequence.replace('[MASK]', 'X'))
    scaffold_start = len(peptides[0])
    scaffold_end = scaffold_start + scaffold_length
    for idx in range(scaffold_start, scaffold_end):
        # Treat the denoised vector at each scaffold position as token logits.
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])
    return ''.join(filled_sequence)


def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
    """Generates a de novo protein sequence of the specified length."""
    # Start from a fully masked sequence of the requested length.
    masked_sequence = "[MASK]" * sequence_length

    # Embed the masked sequence with the frozen ESM-2 encoder.
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)

    # Noise the latents at a random diffusion level, then denoise them.
    sigma = torch.rand(1, device=latents.device)
    noisy_latents = model.forward(latents, sigma)
    denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Decode every position of the denoised latents into a residue.
    filled_sequence = ['X'] * sequence_length
    for idx in range(sequence_length):
        token_id = torch.argmax(denoised_latents[idx]).item()
        filled_sequence[idx] = tokenizer.decode([token_id])
    return ''.join(filled_sequence)

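# Illustrative programmatic usage (a sketch; it mirrors the model loading in
# __main__ below, and the checkpoint path and length are placeholders):
#
#   config = Config()
#   tokenizer, esm_model = load_esm2_model(config.model_name)
#   diffusion_model = Diffusion.load_from_checkpoint(
#       config.training["save_dir"] + "example.ckpt",
#       config=config, latent_dim=config.latent_dim)
#   diffusion_model.eval()
#   print(generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, 50))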

if __name__ == "__main__":
    import argparse

    # Argument parsing
    parser = argparse.ArgumentParser(description="Generate protein sequences using the latent diffusion model.")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Subparser for the first strategy (connect multiple peptides with a scaffold)
    parser_scaffold = subparsers.add_parser("scaffold", help="Generate a scaffold to connect multiple peptides.")
    parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
    parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")

    # Subparser for the second strategy (fill in masked regions)
    parser_fill = subparsers.add_parser("fill", help="Fill in specified regions of a given protein sequence.")
    parser_fill.add_argument("sequence", help="Protein sequence with the regions to fill marked by 'X'.")

    # Subparser for the third strategy (de novo generation)
    parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
    parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")

    args = parser.parse_args()

    # Load configurations
    config = Config()

    # Load the frozen ESM-2 encoder and the trained diffusion checkpoint
    tokenizer, esm_model = load_esm2_model(config.model_name)
    diffusion_model = Diffusion.load_from_checkpoint(
        config.training["save_dir"] + "example.ckpt",
        config=config,
        latent_dim=config.latent_dim,
    )
    diffusion_model.eval()

    if args.mode == "scaffold":
        peptides = args.peptides
        final_length = args.final_length
        filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
        print(f"Peptides: {' '.join(peptides)}")
        print(f"Final Length: {final_length}")
        print(f"Generated Protein: {filled_sequence}")
    elif args.mode == "fill":
        sequence = args.sequence
        masked_sequence, mask_indices = mask_sequence(sequence)
        filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
        print(f"Original Sequence: {sequence}")
        print(f"Masked Sequence: {masked_sequence}")
        print(f"Filled Sequence: {filled_sequence}")
    elif args.mode == "de_novo":
        sequence_length = args.sequence_length
        filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
        print(f"De Novo Sequence Length: {sequence_length}")
        print(f"Generated Protein: {filled_sequence}")