# muPPIt/ppl.py
import pandas as pd
from Bio import SeqIO
import io
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch.distributions.categorical import Categorical
import numpy as np
import os
from argparse import ArgumentParser
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Load the model and tokenizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("ChatterjeeLab/PepMLM-650M")
pepmlm = AutoModelForMaskedLM.from_pretrained("ChatterjeeLab/PepMLM-650M").to(device)

def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """
    For alternative computation of PPL (in batch/matrix format), please check our GitHub repo:
    https://github.com/programmablebio/pepmlm/blob/main/scripts/generation.py
    """
    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
    total_loss = 0

    # Loop through each token in the binder sequence
    for i in range(-len(binder_seq) - 1, -1):
        # Create a copy of the original tensor
        masked_input = tensor_input.clone()
        # Mask one token at a time
        masked_input[0, i] = tokenizer.mask_token_id
        # Create labels
        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        # Get model prediction and loss
        with torch.no_grad():
            outputs = model(masked_input, labels=labels)
            total_loss += outputs.loss.item()

    # Calculate the average loss
    avg_loss = total_loss / len(binder_seq)
    # Calculate pseudo perplexity
    pseudo_perplexity = np.exp(avg_loss)
    return pseudo_perplexity
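
# Pseudo-perplexity above is exp of the mean masked-token cross-entropy over the binder:
#     PPL = exp( (1/L) * sum_i NLL_i ),  with L = len(binder_seq)
# Quick arithmetic check with made-up per-token losses (illustrative values, not model output):
#     losses [2.0, 2.5, 3.0] -> mean 2.5 -> PPL = np.exp(2.5) ~= 12.18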

def generate_peptide_for_single_sequence(protein_seq, peptide_length=15, top_k=3, num_binders=4):
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    binders_with_ppl = []
    for _ in range(num_binders):
        # Generate binder: append peptide_length mask tokens to the target sequence
        masked_peptide = '<mask>' * peptide_length
        input_sequence = protein_seq + masked_peptide
        inputs = tokenizer(input_sequence, return_tensors="pt").to(pepmlm.device)

        with torch.no_grad():
            logits = pepmlm(**inputs).logits
        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]

        # Apply top-k sampling
        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Compute PPL for the generated binder
        ppl_value = compute_pseudo_perplexity(pepmlm, tokenizer, protein_seq, generated_binder)

        # Add the generated binder and its PPL to the results list
        binders_with_ppl.append([generated_binder, ppl_value])

    return binders_with_ppl
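
# Sampling note: at each masked position only the top_k highest-logit residues are kept,
# re-normalised with softmax, and one residue is drawn per position independently
# (Categorical over the truncated distribution), so repeated calls produce different
# candidate binders for the same target.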

def generate_peptide(input_seqs, peptide_length=15, top_k=3, num_binders=4):
    if isinstance(input_seqs, str):  # Single sequence
        binders = generate_peptide_for_single_sequence(input_seqs, peptide_length, top_k, num_binders)
        return pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])

    elif isinstance(input_seqs, list):  # List of sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(seq, peptide_length, top_k, num_binders)
            for binder, ppl in binders:
                results.append([seq, binder, ppl])
        return pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])
# binders = ['LKVECMATRVQLECNLCMNV', 'ATKKDERELKSPAEIFQFLF', 'RTIYVQSKIKLSKSQKKSKS', 'AMKQKROLVSAVNKNPAMTK']
# wildtype = 'IVNGEEAVPGSWPWQVSLQDKTGFHFCGGSLINENWVVTAAHCGVTTSDVVVAGEFDQGSSSEKIQKLKIAKVFKNSKYNSLTINNDITLLKLSTAASFSQTVSAVCLPSASDDFAAGTTCVTTGWGLTRY'
# mutant = 'IVNGEEAVPGSWAWQVSLQDKTGFHFCGGSLINENWVVTAAHCGVTTSDVVVAGEFDQGSSSEKIQKLKIAKVFKNSKYNSLTINNDITLLKLSTAASFSQTVSAVCLPSASDDFAAGTTCVTTGWGLTRY'
# for binder in binders:
#     wt_ppl = compute_pseudo_perplexity(pepmlm, tokenizer, wildtype, binder)
#     mut_ppl = compute_pseudo_perplexity(pepmlm, tokenizer, mutant, binder)
#     print(f"{binder}:\n{wt_ppl}\n{mut_ppl}\n")
#     # print(wt_ppl)
#     # print(mut_ppl)
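
# Minimal usage sketch: generate binders for a single target sequence. The target below
# reuses the wildtype sequence from the commented example above as a placeholder; swap in
# any protein sequence of interest.
if __name__ == "__main__":
    target = 'IVNGEEAVPGSWPWQVSLQDKTGFHFCGGSLINENWVVTAAHCGVTTSDVVVAGEFDQGSSSEKIQKLKIAKVFKNSKYNSLTINNDITLLKLSTAASFSQTVSAVCLPSASDDFAAGTTCVTTGWGLTRY'
    df = generate_peptide(target, peptide_length=15, top_k=3, num_binders=4)
    print(df)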