import argparse
import os

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

from models.diffusion import Diffusion
from configs.config import Config
from utils.esm_utils import load_esm2_model, get_latents

# Placeholder that ``mask_sequence`` inserts by default. It may differ from the
# tokenizer's own mask token (Hugging Face ESM tokenizers use "<mask>"), so the
# generators below always translate it to ``tokenizer.mask_token`` before
# tokenizing.
LEGACY_MASK_TOKEN = '[MASK]'


def mask_sequence(sequence, mask_char='X', mask_token=LEGACY_MASK_TOKEN):
    """Replace every ``mask_char`` in ``sequence`` with ``mask_token``.

    Args:
        sequence: Input protein sequence.
        mask_char: Character marking positions to be filled (default ``'X'``).
        mask_token: String substituted for each ``mask_char``. Pass
            ``tokenizer.mask_token`` so the result can be tokenized directly.

    Returns:
        ``(masked_sequence, mask_indices)``. ``mask_indices`` are positions in
        the ORIGINAL ``sequence``; they are not offsets into ``masked_sequence``
        once ``mask_token`` is longer than one character.
    """
    mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
    masked_sequence = sequence.replace(mask_char, mask_token)
    return masked_sequence, mask_indices


def _require_mask_token(tokenizer):
    """Return ``tokenizer.mask_token``, raising ``ValueError`` if it is missing."""
    mask_token = getattr(tokenizer, "mask_token", None)
    if mask_token is None:
        raise ValueError("Tokenizer does not define a mask token; cannot build masked inputs.")
    return mask_token


def _denoise_latents(model, esm_model, inputs):
    """Embed ``inputs`` with ``esm_model`` and run one noise/denoise pass of ``model``.

    Returns whatever ``model.reverse_diffusion`` produces for the ESM
    ``last_hidden_state`` with its leading (batch) dimension squeezed.
    """
    # Inference only: the diffusion calls are inside no_grad as well, so no
    # autograd graph is built for them.
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        # Noise level drawn uniformly from [0, 1) for this single pass.
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        return model.reverse_diffusion(noisy_latents, sigma)


def _fill_masks(model, tokenizer, esm_model, masked_sequence, mask_token):
    """Replace each ``mask_token`` occurrence in ``masked_sequence`` with a prediction.

    Mask positions are located in the tokenized ``input_ids``, so special
    tokens the tokenizer adds (for example a leading CLS/BOS token) are
    accounted for. Text between masks is copied through verbatim.

    Raises:
        ValueError: if the tokenizer does not produce exactly one mask id per
            ``mask_token`` occurrence in the text.
    """
    inputs = tokenizer(masked_sequence, return_tensors="pt")
    denoised_latents = _denoise_latents(model, esm_model, inputs)

    input_ids = inputs["input_ids"][0]
    mask_positions = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].tolist()
    segments = masked_sequence.split(mask_token)
    if len(mask_positions) != len(segments) - 1:
        raise ValueError(
            f"Expected {len(segments) - 1} mask tokens after tokenization but found "
            f"{len(mask_positions)}; the tokenizer did not keep {mask_token!r} as a single token."
        )

    # NOTE(review): as in the original code, the argmax over a denoised latent
    # vector is treated as a vocabulary id. That is only meaningful if
    # reverse_diffusion returns per-token scores over the tokenizer vocabulary
    # (rather than raw latent_dim-sized embeddings) — confirm.
    predictions = [
        tokenizer.decode([torch.argmax(denoised_latents[position]).item()])
        for position in mask_positions
    ]

    pieces = [segments[0]]
    for prediction, segment in zip(predictions, segments[1:]):
        pieces.append(prediction)
        pieces.append(segment)
    return ''.join(pieces)


def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices):
    """Generate the filled sequence for the masked regions.

    Args:
        model: Diffusion model exposing ``forward`` and ``reverse_diffusion``.
        tokenizer: Tokenizer paired with ``esm_model``.
        esm_model: Model whose output has ``last_hidden_state``.
        masked_sequence: Sequence containing ``tokenizer.mask_token`` or
            ``LEGACY_MASK_TOKEN`` placeholders (e.g. from ``mask_sequence``).
        mask_indices: Original-sequence mask positions from ``mask_sequence``.
            Kept for API compatibility. The positions actually filled are
            located from the tokenized input.

    Returns:
        ``masked_sequence`` with every mask replaced by a decoded prediction.
    """
    mask_token = _require_mask_token(tokenizer)
    # Callers that used mask_sequence's default embed LEGACY_MASK_TOKEN;
    # translate it so the tokenizer sees its own mask token.
    masked_sequence = masked_sequence.replace(LEGACY_MASK_TOKEN, mask_token)
    return _fill_masks(model, tokenizer, esm_model, masked_sequence, mask_token)


def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
    """Generate a scaffold sequence to connect multiple peptides.

    The scaffold (``final_length`` minus the total peptide length, all masked)
    is inserted after the first peptide. The remaining peptides follow it
    back-to-back.
    NOTE(review): with three or more peptides, peptides 2..n are joined with
    no linker between them — confirm this layout is intended.

    Raises:
        ValueError: if ``final_length`` does not exceed the combined peptide length.
    """
    mask_token = _require_mask_token(tokenizer)
    total_peptide_length = sum(len(peptide) for peptide in peptides)
    scaffold_length = final_length - total_peptide_length
    if scaffold_length <= 0:
        raise ValueError("Final length must be greater than the combined length of the peptides.")

    scaffold = mask_token * scaffold_length
    # Unpacking accepts any sequence type (lists and tuples alike).
    masked_sequence = "".join([*peptides[:1], scaffold, *peptides[1:]])
    return _fill_masks(model, tokenizer, esm_model, masked_sequence, mask_token)


def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
    """Generate a de novo protein sequence of the specified length.

    Raises:
        ValueError: if ``sequence_length`` is not positive.
    """
    if sequence_length <= 0:
        raise ValueError("Sequence length must be a positive integer.")
    mask_token = _require_mask_token(tokenizer)
    masked_sequence = mask_token * sequence_length
    return _fill_masks(model, tokenizer, esm_model, masked_sequence, mask_token)


def main(argv=None):
    """Parse command-line arguments and run the selected generation mode."""
    parser = argparse.ArgumentParser(description="Generate protein sequences using latent diffusion model.")
    # A mode is required so a bare invocation reports usage instead of loading
    # the models and then doing nothing.
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Subparser for the first strategy (multiple peptides to scaffold)
    parser_scaffold = subparsers.add_parser("scaffold", help="Generate scaffold to connect multiple peptides.")
    parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
    parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")

    # Subparser for the second strategy (fill in regions)
    parser_fill = subparsers.add_parser("fill", help="Fill in specified regions in a given protein sequence.")
    parser_fill.add_argument("sequence", help="Protein sequence with regions to fill specified by 'X'.")

    # Subparser for the third strategy (de novo generation)
    parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
    parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")

    args = parser.parse_args(argv)

    # Load configurations
    config = Config()

    # Load models
    tokenizer, esm_model = load_esm2_model(config.model_name)
    # os.path.join works whether or not save_dir ends with a separator.
    checkpoint_path = os.path.join(config.training["save_dir"], "example.ckpt")
    diffusion_model = Diffusion.load_from_checkpoint(checkpoint_path, config=config, latent_dim=config.latent_dim)
    diffusion_model.eval()

    if args.mode == "scaffold":
        peptides = args.peptides
        final_length = args.final_length
        filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
        print(f"Peptides: {' '.join(peptides)}")
        print(f"Final Length: {final_length}")
        print(f"Generated Protein: {filled_sequence}")
    elif args.mode == "fill":
        sequence = args.sequence
        masked_sequence, mask_indices = mask_sequence(sequence, mask_token=_require_mask_token(tokenizer))
        filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
        print(f"Original Sequence: {sequence}")
        print(f"Masked Sequence: {masked_sequence}")
        print(f"Filled Sequence: {filled_sequence}")
    elif args.mode == "de_novo":
        sequence_length = args.sequence_length
        filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
        print(f"De Novo Sequence Length: {sequence_length}")
        print(f"Generated Protein: {filled_sequence}")


if __name__ == "__main__":
    main()