|
import os |
|
import hydra |
|
import lightning as L |
|
import numpy as np |
|
import omegaconf |
|
import pandas as pd |
|
import rdkit |
|
import rich.syntax |
|
import rich.tree |
|
import torch |
|
from tqdm.auto import tqdm |
|
import esm |
|
import pdb |
|
|
|
import dataloader |
|
import diffusion |
|
from models.classifier import muPPIt |
|
|
|
|
|
# Silence RDKit's error-level log channel (noisy during molecule parsing).
rdkit.rdBase.DisableLog('rdApp.error')

# Custom OmegaConf resolvers referenced by the Hydra config files.
# NOTE(review): the 'eval' resolver executes arbitrary strings from the
# config — acceptable only for trusted, locally-authored configs.
_CONFIG_RESOLVERS = {
    'cwd': os.getcwd,
    'device_count': torch.cuda.device_count,
    'eval': eval,
    'div_up': lambda x, y: (x + y - 1) // y,
    'if_then_else': lambda condition, x, y: x if condition else y,
}
for _name, _resolver in _CONFIG_RESOLVERS.items():
    omegaconf.OmegaConf.register_new_resolver(_name, _resolver)
|
|
|
# VHSE8 descriptors: eight principal-component scores per amino acid
# summarizing hydrophobic, steric, and electronic properties.
vhse8_values = {
    'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
    'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
    'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
    'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
    'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
    'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
    'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
    'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
    'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
    'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
    'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
    'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
    'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
    'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
    'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
    'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
    'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
    'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
    'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
    'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

# One-letter amino-acid code -> token index.
# NOTE(review): presumably mirrors the ESM-2 alphabet ordering — confirm
# against `alphabet` before changing.
aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}

# Lookup table indexed by token id: 24 rows x 8 VHSE8 descriptors.
# Rows that do not correspond to an amino acid (special tokens) stay zero.
vhse8_tensor = torch.zeros(24, 8)
for residue, descriptors in vhse8_values.items():
    vhse8_tensor[aa_to_idx[residue]] = torch.tensor(descriptors)

# Static feature table — keep it out of autograd.
vhse8_tensor.requires_grad = False
|
|
|
# Load the pretrained ESM-2 650M protein language model (33 transformer
# layers) and its tokenization alphabet.
# NOTE(review): this runs at import time — it loads (and on first use
# downloads) a very large checkpoint as a module side effect; consider
# lazy loading. Device placement is left to the caller (stays on CPU here).
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
|
|
|
def precompute_embedding(sequence, tokenizer):
    """Embed a protein sequence with ESM-2 features plus VHSE8 descriptors.

    Tokenizes ``sequence``, extracts the final-layer (layer 33) ESM-2
    representation per token under ``no_grad``, looks up each token's 8
    VHSE8 physicochemical descriptors, and concatenates the two along the
    feature axis.

    NOTE(review): assumes the tokenizer's ``input_ids`` share the same
    index space as the ESM alphabet / ``vhse8_tensor`` rows — confirm
    against ``dataloader.get_tokenizer``.
    """
    token_ids = tokenizer(sequence, return_tensors='pt')['input_ids']
    with torch.no_grad():
        esm_out = esm_model(token_ids, repr_layers=[33], return_contacts=False)
        esm_repr = esm_out["representations"][33]
    physchem = vhse8_tensor[token_ids]
    return torch.concat([esm_repr, physchem], dim=-1)
|
|
|
|
|
@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    """Generate guided samples from the pretrained diffusion model.

    Loads the diffusion checkpoint and the muPPIt guidance classifier,
    precomputes ESM-2 + VHSE8 embeddings for the configured wild-type and
    mutant sequences, samples ``config.sampling.num_sample_batches``
    batches, and prints the decoded, cleaned sequences.
    """
    # Reproducibility: fixed seed plus deterministic kernels.
    # CUBLAS_WORKSPACE_CONFIG is required by use_deterministic_algorithms
    # for cuBLAS ops on CUDA.
    L.seed_everything(config.seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)

    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config, logger=False)
    pretrained.eval()

    # Guidance classifier. d_node=1288 matches the features produced by
    # precompute_embedding (1280-dim ESM-2 layer-33 output + 8 VHSE8).
    muppit = muPPIt(d_node=1288, d_k=32, d_v=32, n_heads=4, lr=None)
    # NOTE(review): torch.load without map_location assumes the device the
    # checkpoint was saved on is available here — consider
    # map_location='cpu' for portability.
    muppit.load_state_dict(
        torch.load(config.guidance.classifier_checkpoint_path))
    muppit.eval()

    mut_embed = precompute_embedding(config.eval.mutant, tokenizer)
    wt_embed = precompute_embedding(config.eval.wildtype, tokenizer)

    samples = []
    for _ in tqdm(
            range(config.sampling.num_sample_batches),
            desc='Gen. batches', leave=False):
        sample = pretrained.sample(
            wt_embed=wt_embed,
            mut_embed=mut_embed,
            classifier_model=muppit
        )
        samples.extend(
            pretrained.tokenizer.batch_decode(sample))

    # Remove tokenizer spacing and trim 5 characters from each end —
    # presumably fixed-width special tokens like '[CLS]'/'[SEP]'; confirm
    # against the tokenizer. (The original computed this list twice and
    # printed it twice; do it once.)
    samples = [sample.replace(' ', '')[5:-5] for sample in samples]
    print('\n')
    print(samples)
|
|
|
# Script entry point — Hydra parses CLI overrides and injects the config.
if __name__ == '__main__':
    main()