|
import torch |
|
import numpy as np |
|
from transformers import AutoTokenizer, AutoModel |
|
from diffusion import Diffusion |
|
import config |
|
from esm_utils import load_esm2_model, get_latents |
|
|
|
def mask_sequence(sequence, mask_char='X', mask_token='[MASK]'):
    """Replace every ``mask_char`` in ``sequence`` with ``mask_token``.

    Args:
        sequence: Protein sequence in which positions to fill are marked
            with ``mask_char``.
        mask_char: Single character marking a position to be filled.
        mask_token: Text substituted for each marked position. Defaults to
            ``'[MASK]'``; NOTE(review): this may differ from the tokenizer's
            own mask token — confirm against the tokenizer in use.

    Returns:
        A tuple ``(masked_sequence, mask_indices)``. ``mask_indices`` are the
        positions of ``mask_char`` in the ORIGINAL ``sequence`` (residue
        positions), not character offsets into ``masked_sequence``.

    Raises:
        ValueError: If ``mask_char`` is not exactly one character. An empty
            marker would make ``str.replace`` insert the token between every
            character. A longer marker would never be matched per-character,
            so the indices and the replaced string would disagree.
    """
    if len(mask_char) != 1:
        raise ValueError("mask_char must be exactly one character.")

    mask_indices = [i for i, char in enumerate(sequence) if char == mask_char]
    masked_sequence = sequence.replace(mask_char, mask_token)
    return masked_sequence, mask_indices
|
|
|
def generate_filled_sequence(model, tokenizer, esm_model, masked_sequence, mask_indices, mask_token='[MASK]'):
    """Generate residues for the masked positions of a sequence.

    The masked sequence is encoded with ``esm_model``, passed through
    ``model.forward`` and ``model.reverse_diffusion`` with a random
    ``sigma``, and each masked slot is filled with the decoded token.

    Args:
        model: Diffusion model exposing ``forward(latents, sigma)`` and
            ``reverse_diffusion(latents, sigma)``.
        tokenizer: Callable tokenizer with a ``decode`` method.
        esm_model: Encoder whose output has ``last_hidden_state``.
        masked_sequence: Sequence where each position to fill is written as
            ``mask_token``.
        mask_indices: Residue positions (in the unmasked sequence) to fill.
        mask_token: Text used for each masked position in
            ``masked_sequence``.

    Returns:
        The sequence with every ``mask_indices`` slot replaced by a decoded
        token; unmasked residues are kept as-is.
    """
    inputs = tokenizer(masked_sequence, return_tensors="pt")

    # Inference only: keep the diffusion calls under no_grad too, so no
    # autograd graph is built for them.
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Rebuild residue-level slots: each mask_token occupies exactly ONE slot,
    # so slot i lines up with residue i of the original sequence. Indexing
    # list(masked_sequence) directly would be off by len(mask_token) - 1 per
    # preceding mask and leave mask-token fragments in the output.
    segments = masked_sequence.split(mask_token)
    residues = []
    for seg_idx, segment in enumerate(segments):
        residues.extend(segment)
        if seg_idx < len(segments) - 1:
            residues.append(mask_token)

    for idx in mask_indices:
        # NOTE(review): latents are indexed by residue position directly; if
        # the tokenizer prepends a special token the rows are shifted — confirm.
        # NOTE(review): argmax over the latent vector is treated as a token id,
        # which assumes the latent dimension matches the vocabulary — confirm.
        token_id = torch.argmax(denoised_latents[idx]).item()
        residues[idx] = tokenizer.decode([token_id])

    return ''.join(residues)
|
|
|
def generate_scaffold_sequence(model, tokenizer, esm_model, peptides, final_length):
    """Generate a scaffold inserted after the first peptide.

    The scaffold length is ``final_length`` minus the total peptide length.
    The masked sequence ``peptides[0] + scaffold + peptides[1:]`` is encoded
    with ``esm_model`` and passed through the diffusion model. The decoded
    tokens fill the scaffold slots.

    Args:
        model: Diffusion model exposing ``forward`` and ``reverse_diffusion``.
        tokenizer: Callable tokenizer with a ``decode`` method.
        esm_model: Encoder whose output has ``last_hidden_state``.
        peptides: Non-empty sequence of peptide strings (list or tuple).
        final_length: Desired total length of the returned sequence.

    Returns:
        ``peptides[0]`` + generated scaffold + the remaining peptides
        concatenated. The result has exactly ``final_length`` residue slots,
        assuming each decoded token is one character.

    Raises:
        ValueError: If ``peptides`` is empty, or if ``final_length`` does not
            exceed the combined peptide length.
    """
    peptides = list(peptides)  # accept tuples; the concatenation below needs a list
    if not peptides:
        raise ValueError("At least one peptide is required.")

    total_peptide_length = sum(len(peptide) for peptide in peptides)
    scaffold_length = final_length - total_peptide_length
    if scaffold_length <= 0:
        raise ValueError("Final length must be greater than the combined length of the peptides.")

    scaffold = "[MASK]" * scaffold_length
    masked_sequence = "".join(peptides[:1] + [scaffold] + peptides[1:])

    inputs = tokenizer(masked_sequence, return_tensors="pt")

    # Inference only: run the diffusion steps without building autograd graphs.
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    scaffold_start = len(peptides[0])
    scaffold_end = scaffold_start + scaffold_length

    # NOTE(review): latent rows are indexed by residue position and argmax over
    # the latent vector is treated as a token id — both assumptions about the
    # tokenizer/model layout should be confirmed.
    scaffold_residues = [
        tokenizer.decode([torch.argmax(denoised_latents[idx]).item()])
        for idx in range(scaffold_start, scaffold_end)
    ]

    # Assemble from residue-level pieces. Overwriting characters of the
    # expanded "[MASK]" string would only replace part of the 6-char tokens
    # and leave fragments behind.
    return peptides[0] + ''.join(scaffold_residues) + ''.join(peptides[1:])
|
|
|
def generate_de_novo_sequence(model, tokenizer, esm_model, sequence_length):
    """Generate a de novo sequence with ``sequence_length`` residue slots.

    An all-mask input is encoded with ``esm_model`` and passed through the
    diffusion model. Each slot is then decoded from the denoised latents.

    Args:
        model: Diffusion model exposing ``forward`` and ``reverse_diffusion``.
        tokenizer: Callable tokenizer with a ``decode`` method.
        esm_model: Encoder whose output has ``last_hidden_state``.
        sequence_length: Number of residues to generate; must be positive.

    Returns:
        The concatenation of ``sequence_length`` decoded tokens.

    Raises:
        ValueError: If ``sequence_length`` is not positive.
    """
    if sequence_length <= 0:
        raise ValueError("Sequence length must be positive.")

    masked_sequence = "[MASK]" * sequence_length

    inputs = tokenizer(masked_sequence, return_tensors="pt")

    # Inference only: run the diffusion steps without building autograd graphs.
    with torch.no_grad():
        outputs = esm_model(**inputs)
        latents = outputs.last_hidden_state.squeeze(0)
        sigma = torch.rand(1, device=latents.device)
        noisy_latents = model.forward(latents, sigma)
        denoised_latents = model.reverse_diffusion(noisy_latents, sigma)

    # Build the output from one decoded token per slot. Overwriting characters
    # of the 6-char-per-slot "[MASK]" string would leave "SK][MASK]..." behind.
    # NOTE(review): latent rows are indexed by residue position and argmax over
    # the latent vector is treated as a token id — confirm both assumptions.
    return ''.join(
        tokenizer.decode([torch.argmax(denoised_latents[idx]).item()])
        for idx in range(sequence_length)
    )
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse a subcommand, load the ESM encoder and the
    # diffusion checkpoint, then run the requested generation mode.
    import argparse
    import os

    parser = argparse.ArgumentParser(description="Generate protein sequences using latent diffusion model.")
    # required=True: without it, omitting the subcommand would still load every
    # model and then silently do nothing.
    subparsers = parser.add_subparsers(dest="mode", required=True)

    parser_scaffold = subparsers.add_parser("scaffold", help="Generate scaffold to connect multiple peptides.")
    parser_scaffold.add_argument("peptides", nargs='+', help="Peptides to connect.")
    parser_scaffold.add_argument("final_length", type=int, help="Final length of the protein sequence.")

    parser_fill = subparsers.add_parser("fill", help="Fill in specified regions in a given protein sequence.")
    parser_fill.add_argument("sequence", help="Protein sequence with regions to fill specified by 'X'.")

    parser_de_novo = subparsers.add_parser("de_novo", help="Generate a de novo protein sequence.")
    parser_de_novo.add_argument("sequence_length", type=int, help="Length of the de novo generated protein sequence.")

    args = parser.parse_args()

    tokenizer, esm_model = load_esm2_model(config.MODEL_NAME)

    # os.path.join works whether or not SAVE_DIR ends with a path separator;
    # plain "+" concatenation breaks when it does not.
    checkpoint_path = os.path.join(config.Training.SAVE_DIR, "best_model.ckpt")
    diffusion_model = Diffusion.load_from_checkpoint(checkpoint_path, config=config, latent_dim=config.LATENT_DIM)
    diffusion_model.eval()

    if args.mode == "scaffold":
        peptides = args.peptides
        final_length = args.final_length
        filled_sequence = generate_scaffold_sequence(diffusion_model, tokenizer, esm_model, peptides, final_length)
        print(f"Peptides: {' '.join(peptides)}")
        print(f"Final Length: {final_length}")
        print(f"Generated Protein: {filled_sequence}")

    elif args.mode == "fill":
        sequence = args.sequence
        masked_sequence, mask_indices = mask_sequence(sequence)
        filled_sequence = generate_filled_sequence(diffusion_model, tokenizer, esm_model, masked_sequence, mask_indices)
        print(f"Original Sequence: {sequence}")
        print(f"Masked Sequence: {masked_sequence}")
        print(f"Filled Sequence: {filled_sequence}")

    elif args.mode == "de_novo":
        sequence_length = args.sequence_length
        filled_sequence = generate_de_novo_sequence(diffusion_model, tokenizer, esm_model, sequence_length)
        print(f"De Novo Sequence Length: {sequence_length}")
        print(f"Generated Protein: {filled_sequence}")
|
|
|
|