Spaces:
Running
Running
import copy | |
import re | |
import os.path | |
import torch | |
import sys | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import matplotlib.pyplot as plt | |
import plotly.graph_objects as go | |
import tempfile | |
import requests | |
from moleculekit.molecule import Molecule | |
# Make the vendored ProteinMPNN helpers (protein_mpnn_utils) importable:
# first the hosted-Space location, then a checkout under the current
# working directory for local runs.
sys.path.extend(
    [
        "/home/user/app/ProteinMPNN/vanilla_proteinmpnn",
        os.path.join(os.getcwd(), "ProteinMPNN/vanilla_proteinmpnn"),
    ]
)
def make_tied_positions_for_homomers(pdb_dict_list):
    """Tie each residue position across all chains of every parsed structure.

    Args:
        pdb_dict_list: entries as returned by ``parse_PDB``; each is a dict
            with a ``"name"`` key and one ``"seq_chain_<ID>"`` key per chain.

    Returns:
        ``{name: [{chain_id: [i], ...} for i in 1..chain_length]}``. The
        length is taken from the alphabetically first chain, so all chains
        are assumed to be equally long (homomer). Entries without any
        ``seq_chain_`` key map to an empty list instead of raising.
    """
    prefix = "seq_chain_"
    tied = {}
    for entry in pdb_dict_list:
        # Strip the whole prefix so multi-character chain IDs stay intact
        # (the old ``item[-1:]`` kept only the last character).
        chain_ids = sorted(key[len(prefix):] for key in entry if key.startswith(prefix))
        if not chain_ids:
            tied[entry["name"]] = []
            continue
        chain_length = len(entry[prefix + chain_ids[0]])
        # ProteinMPNN expects each tied entry as {chain: [position]} (a list).
        tied[entry["name"]] = [
            {chain: [position] for chain in chain_ids}
            for position in range(1, chain_length + 1)
        ]
    return tied
def align_structures(pdb1, pdb2, index):
    """Superimpose the structure in *pdb2* onto the reference *pdb1* with CEAlign.

    Both structures are written under ``outputs/`` (created if missing) so
    the viewer can load them.

    Args:
        pdb1: path to the reference (input) structure.
        pdb2: path to the structure that is moved; its basename without
            extension names the output files.
        index: sequence index, appended to the output file names.

    Returns:
        ``(rms, ref_path, aligned_path, plddts)`` where ``rms`` is the
        CEAligner RMSD and ``plddts`` are the B-factors of the CA atoms of
        *pdb2* (read before alignment).
    """
    import Bio.PDB

    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    ref_structure = pdb_parser.get_structure("ref", pdb1)
    sample_structure = pdb_parser.get_structure("sample", pdb2)
    plddts = [
        atom.get_bfactor() for atom in sample_structure.get_atoms() if atom.name == "CA"
    ]

    aligner = Bio.PDB.CEAligner()
    aligner.set_reference(ref_structure)
    aligner.align(sample_structure)

    os.makedirs("outputs", exist_ok=True)
    stem = os.path.splitext(os.path.basename(pdb2))[0]
    ref_path = f"outputs/{stem}_ref_{index}.pdb"
    aligned_path = f"outputs/{stem}_align_{index}.pdb"
    io = Bio.PDB.PDBIO()
    io.set_structure(ref_structure)
    io.save(ref_path)
    io.set_structure(sample_structure)
    io.save(aligned_path)
    return aligner.rms, ref_path, aligned_path, plddts
# Choose where the ProteinMPNN weights live: the hosted-Space layout when
# /home/user/app/ProteinMPNN/ exists, otherwise a checkout under the CWD.
is_local = not os.path.exists("/home/user/app/ProteinMPNN/")
path_to_model_weights = (
    os.path.join(
        os.getcwd(), "ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights"
    )
    if is_local
    else "/home/user/app/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights"
)
if is_local:
    print("Running locally")
    # ESMFold classes are only needed when folding runs in-process.
    from transformers import AutoTokenizer, EsmForProteinFolding
def setup_proteinmpnn(model_name="v_48_020", backbone_noise=0.00):
    """Load a vanilla ProteinMPNN checkpoint and return it ready for inference.

    Args:
        model_name: checkpoint stem inside ``path_to_model_weights``, e.g.
            v_48_002, v_48_010, v_48_020, v_48_030, v_32_002, v_32_010,
            v_32_020, v_32_030 (v_48_010 = 48 edges, 0.10 Å noise).
        backbone_noise: standard deviation of the Gaussian noise added to
            backbone atoms (passed to the model as ``augment_eps``).

    Returns:
        ``(model, device)`` with the model loaded and in eval mode.
    """
    from protein_mpnn_utils import ProteinMPNN

    # Pinned to CPU on purpose (original note: "fix for memory issues").
    device = torch.device("cpu")
    hidden_dim = 128
    num_layers = 3
    checkpoint_path = os.path.join(path_to_model_weights, f"{model_name}.pt")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = ProteinMPNN(
        num_letters=21,
        node_features=hidden_dim,
        edge_features=hidden_dim,
        hidden_dim=hidden_dim,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        augment_eps=backbone_noise,
        k_neighbors=checkpoint["num_edges"],
    )
    model.to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, device
def get_pdb(pdb_code="", filepath=""):
    """Return a local path to the input structure.

    If *pdb_code* is empty or None, returns ``filepath.name`` (the path of an
    uploaded file object), or None when *filepath* has no ``.name``.
    Otherwise downloads ``{pdb_code}.pdb`` from RCSB into the working
    directory with ``wget -qnc`` (quiet, no clobber) and returns that file
    name. Codes containing anything other than letters, digits or "_" are
    rejected with None.
    """
    if pdb_code is None or pdb_code == "":
        try:
            return filepath.name
        except AttributeError:
            return None
    # pdb_code comes straight from the UI. The former os.system(f"wget ...")
    # let it inject shell commands, so validate it and run wget without a
    # shell.
    if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_code):
        return None
    import subprocess

    try:
        subprocess.run(
            ["wget", "-qnc", f"https://files.rcsb.org/view/{pdb_code}.pdb"],
            check=False,
        )
    except FileNotFoundError:
        # Download stays best-effort, as before (wget missing).
        pass
    return f"{pdb_code}.pdb"
def preprocess_mol(pdb_code="", filepath=""):
    """Load the input structure, save raw and protein-only copies, and map numbering.

    Writes ``original.pdb`` (as loaded) and ``cleaned.pdb`` (protein atoms
    only, residues renumbered by moleculekit).

    Args:
        pdb_code: code passed to ``Molecule``; if empty/None *filepath* is used.
        filepath: uploaded file object whose ``.name`` is the path.

    Returns:
        ``("cleaned.pdb", df)`` where ``df`` is moleculekit's renumbering
        mapping with an extra ``proteinMPNN_index`` column (1..len(chain),
        restarting for each chain). ``(None, None)`` when no code is given
        and *filepath* has no ``.name``.
    """
    if pdb_code is None or pdb_code == "":
        try:
            source = filepath.name
        except AttributeError:
            # Callers unpack two values, so signal failure with a pair.
            return None, None
    else:
        source = pdb_code
    mol = Molecule(source)
    mol.write("original.pdb")
    # Keep only the protein itself (drops waters, ligands, ions, ...).
    mol.filter("protein")
    # Renumber residues 0..len(protein) and keep the old->new mapping.
    df = mol.renumberResidues(returnMapping=True)
    # ProteinMPNN numbers residues 1..len(chain) per chain. cumcount aligns by
    # row index, so this stays correct when chains are not contiguous or not
    # in alphabetical order in the file.
    df["proteinMPNN_index"] = df.groupby("chain").cumcount() + 1
    mol.write("cleaned.pdb")
    return "cleaned.pdb", df
def assign_sasa(mol):
    """Return one value per atom of *mol*: the SASA of the residue it belongs to.

    SASA is computed per residue for protein atoms with moleculekit's
    ``MetricSasa``. Atoms are walked in runs of consecutive identical
    ``resid``. Each protein run takes the next per-residue value; every
    non-protein run gets 0.
    """
    from moleculekit.projections.metricsasa import MetricSasa

    residue_sasa = MetricSasa(mode="residue", filtersel="protein").project(mol)[0]
    atoms = pd.DataFrame.from_dict(
        {"resid": mol.resid, "is_prot": mol.atomselect("protein")}
    )
    # A new run starts wherever resid differs from the previous atom's.
    run_id = (atoms["resid"].shift() != atoms["resid"]).cumsum()

    per_atom = []
    next_protein_residue = 0
    for _, run in atoms.groupby(run_id):
        if run["is_prot"].iloc[0]:
            value = residue_sasa[next_protein_residue]
            next_protein_residue += 1
        else:
            value = 0
        per_atom.extend(run.assign(sasa=value)["sasa"])
    return np.array(per_atom)
def process_atomsel(atomsel):
    """Translate the app's pseudo-keywords into moleculekit selection fields.

    ``sasa`` becomes ``mass`` (SASA is stashed in the atom masses) and
    ``plddt`` becomes ``beta`` (the B-factor column). Both are matched
    case-insensitively. The rest of the selection is left untouched; it is
    NOT lowercased, so chain and residue names keep their case.
    """
    for keyword, field in (("sasa", "mass"), ("plddt", "beta")):
        atomsel = re.sub(keyword, field, atomsel, flags=re.I)
    return atomsel
def make_fixed_positions_dict(atomsel, residue_index_df):
    """Build ProteinMPNN's fixed-positions dict from a moleculekit selection.

    The selection is evaluated on the uploaded structure (``original.pdb``).
    Per-atom SASA is stored in ``mol.masses`` so that ``sasa`` (rewritten to
    ``mass`` by ``process_atomsel``) can be used in selections.

    Args:
        atomsel: user selection string.
        residue_index_df: mapping from ``preprocess_mol`` with columns
            ``new_resid``, ``chain`` and ``proteinMPNN_index``.

    Returns:
        ``(fixed_position_dict, selected_residues)`` where the dict has the
        form ``{"cleaned": {chain: [proteinMPNN_index, ...]}}`` and
        ``selected_residues`` lists the matching ``new_resid`` values (plain
        ints, used to highlight cleaned.pdb). ``(None, [])`` if the
        selection matches no atom.
    """
    mol = Molecule("original.pdb")
    # Non-protein atoms get SASA 0; they are removed below anyway.
    mol.masses = assign_sasa(mol)
    atomsel = process_atomsel(atomsel)
    selected_mask = mol.atomselect(atomsel)
    if not selected_mask.any():
        return None, []
    # filter() removes atoms, so atom indices shift. Carry the selection
    # across as a boolean mask restricted to the atoms that survive.
    protein_mask = mol.atomselect("protein")
    mol.filter("protein")
    mol.renumberResidues()
    selected_resids = set(mol.resid[selected_mask[protein_mask]])

    # Translate to ProteinMPNN's per-chain 1..len(chain) numbering.
    fixed_positions_df = residue_index_df[
        residue_index_df["new_resid"].isin(selected_resids)
    ]
    fixed_position_dict = {
        "cleaned": {chain: [] for chain in set(mol.get("chain", sel="all"))}
    }
    for chain, mpnn_index in zip(
        fixed_positions_df["chain"], fixed_positions_df["proteinMPNN_index"]
    ):
        # int(): keep numpy scalar reprs out of the printed message.
        fixed_position_dict["cleaned"][chain].append(int(mpnn_index))
    # .tolist() yields plain ints; the list is rendered into the viewer's JS.
    return fixed_position_dict, fixed_positions_df["new_resid"].tolist()
def _parse_chain_letters(text):
    """Split a chain spec such as "A,B" or "A B" into chain IDs (empty pieces dropped)."""
    if not text:
        return []
    return [c for c in re.sub("[^A-Za-z]+", ",", text).split(",") if c]


def _join_chain_segments(seq, chain_lengths, chain_letters):
    """Cut *seq* into consecutive pieces of *chain_lengths*; join them "/"-separated in chain-letter order."""
    pieces = []
    start = 0
    for length in chain_lengths:
        pieces.append(seq[start:start + length])
        start += length
    order = sorted(range(len(chain_letters)), key=lambda k: chain_letters[k])
    return "/".join(pieces[k] for k in order)


def update(
    inp,
    file,
    designed_chain,
    fixed_chain,
    homomer,
    num_seqs,
    sampling_temp,
    model_name,
    backbone_noise,
    atomsel,
):
    """Design sequences for the input structure with ProteinMPNN and build the UI outputs.

    Args:
        inp: PDB code; when empty, *file* is used instead.
        file: uploaded file object (its ``.name`` is the local path).
        designed_chain: chains to redesign, e.g. "A,B" (any non-letter
            separates IDs; "" for none).
        fixed_chain: chains kept as context, same format.
        homomer: if truthy, tie equal positions across all chains.
        num_seqs: number of sequences to sample (converted with ``int``).
        sampling_temp: ProteinMPNN sampling temperature.
        model_name: checkpoint name passed to ``setup_proteinmpnn``.
        backbone_noise: backbone noise passed to ``setup_proteinmpnn``.
        atomsel: moleculekit selection of positions to keep fixed ("" = none).

    Returns:
        A 10-tuple for the Gradio outputs:
        - FASTA-style message;
        - heatmap of mean exp(log-probs);
        - heatmap of mean sampling probabilities;
        - File updates for the two CSVs;
        - cleaned PDB path;
        - Dropdown update with the sequences;
        - selected residues;
        - ``seq_dict``;
        - viewer HTML from ``structure_pred``.

    Raises:
        gr.Error: if the structure cannot be loaded or no sequence is produced.
    """
    from protein_mpnn_utils import (
        StructureDatasetPDB,
        _S_to_seq,
        _scores,
        parse_PDB,
        tied_featurize,
    )

    # preprocess_mol may report failure as None or (None, None).
    processed = preprocess_mol(pdb_code=inp, filepath=file)
    pdb_path, mol_index = processed if processed is not None else (None, None)
    if pdb_path is None:
        # A bare string cannot fill the 10 output components; show it in the UI.
        raise gr.Error("Error processing PDB")
    model, device = setup_proteinmpnn(
        model_name=model_name, backbone_noise=backbone_noise
    )

    designed_chain_list = _parse_chain_letters(designed_chain)
    fixed_chain_list = _parse_chain_letters(fixed_chain)
    chain_list = list(set(designed_chain_list + fixed_chain_list))

    batch_size = 1  # structure copies per forward pass
    max_length = 20000  # max sequence length
    omit_AAs = "X"  # letters that must not be generated
    pssm_multi = 0.0  # 0.0 = do not use a PSSM
    pssm_threshold = 0.0
    pssm_log_odds_flag = 0
    pssm_bias_flag = 0
    # Slider values may arrive as floats; range() needs an int.
    num_batches = int(num_seqs) // batch_size
    temperatures = [sampling_temp]
    alphabet = "ACDEFGHIKLMNPQRSTVWYX"
    omit_AAs_np = np.array([aa in omit_AAs for aa in alphabet]).astype(np.float32)
    bias_AAs_np = np.zeros(len(alphabet))
    # Optional ProteinMPNN inputs this app never supplies.
    pssm_dict = None
    omit_AA_dict = None
    bias_by_res_dict = None

    if atomsel == "":
        fixed_positions_dict, selected_residues = None, []
    else:
        fixed_positions_dict, selected_residues = make_fixed_positions_dict(
            atomsel, mol_index
        )

    pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
    dataset_valid = StructureDatasetPDB(
        pdb_dict_list, truncate=None, max_length=max_length
    )
    tied_positions_dict = (
        make_tied_positions_for_homomers(pdb_dict_list) if homomer else None
    )
    chain_id_dict = {pdb_dict_list[0]["name"]: (designed_chain_list, fixed_chain_list)}

    message = ""
    seq_list = []
    seq_recovery = []
    seq_score = []
    all_probs_list = []
    all_log_probs_list = []
    with torch.no_grad():
        for prot in dataset_valid:
            batch_clones = [copy.deepcopy(prot) for _ in range(batch_size)]
            (
                X,
                S,
                mask,
                lengths,
                chain_M,
                chain_encoding_all,
                chain_list_list,
                visible_list_list,
                masked_list_list,
                masked_chain_length_list_list,
                chain_M_pos,
                omit_AA_mask,
                residue_idx,
                dihedral_mask,
                tied_pos_list_of_lists_list,
                pssm_coef,
                pssm_bias,
                pssm_log_odds_all,
                bias_by_res_all,
                tied_beta,
            ) = tied_featurize(
                batch_clones,
                device,
                chain_id_dict,
                fixed_positions_dict,
                omit_AA_dict,
                tied_positions_dict,
                pssm_dict,
                bias_by_res_dict,
            )
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()
            name_ = batch_clones[0]["name"]
            mask_for_loss = mask * chain_M * chain_M_pos
            chain_M_bool = chain_M.bool()

            # Score of the native sequence, reported in the header line.
            randn_1 = torch.randn(chain_M.shape, device=X.device)
            log_probs = model(
                X,
                S,
                mask,
                chain_M * chain_M_pos,
                residue_idx,
                chain_encoding_all,
                randn_1,
            )
            native_score = _scores(S, log_probs, mask_for_loss).cpu().data.numpy()

            for temp in temperatures:
                sample_kwargs = dict(
                    mask=mask,
                    temperature=temp,
                    omit_AAs_np=omit_AAs_np,
                    bias_AAs_np=bias_AAs_np,
                    chain_M_pos=chain_M_pos,
                    omit_AA_mask=omit_AA_mask,
                    pssm_coef=pssm_coef,
                    pssm_bias=pssm_bias,
                    pssm_multi=pssm_multi,
                    pssm_log_odds_flag=bool(pssm_log_odds_flag),
                    pssm_log_odds_mask=pssm_log_odds_mask,
                    pssm_bias_flag=bool(pssm_bias_flag),
                    bias_by_res=bias_by_res_all,
                )
                for j in range(num_batches):
                    randn_2 = torch.randn(chain_M.shape, device=X.device)
                    if tied_positions_dict is None:
                        sample_dict = model.sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                            **sample_kwargs,
                        )
                    else:
                        sample_dict = model.tied_sample(
                            X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                            tied_pos=tied_pos_list_of_lists_list[0],
                            tied_beta=tied_beta,
                            **sample_kwargs,
                        )
                    # Re-score the sample in its own decoding order.
                    S_sample = sample_dict["S"]
                    log_probs = model(
                        X,
                        S_sample,
                        mask,
                        chain_M * chain_M_pos,
                        residue_idx,
                        chain_encoding_all,
                        randn_2,
                        use_input_decoding_order=True,
                        decoding_order=sample_dict["decoding_order"],
                    )
                    scores = _scores(S_sample, log_probs, mask_for_loss).cpu().data.numpy()
                    all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
                    all_log_probs_list.append(log_probs.cpu().data.numpy())

                    for b_ix in range(batch_size):
                        masked_chain_length_list = masked_chain_length_list_list[b_ix]
                        masked_list = masked_list_list[b_ix]
                        seq_recovery_rate = torch.sum(
                            torch.sum(
                                torch.nn.functional.one_hot(S[b_ix], 21)
                                * torch.nn.functional.one_hot(S_sample[b_ix], 21),
                                axis=-1,
                            )
                            * mask_for_loss[b_ix]
                        ) / torch.sum(mask_for_loss[b_ix])

                        # Header with the native sequence, written only once.
                        if b_ix == 0 and j == 0 and temp == temperatures[0]:
                            native_seq = _join_chain_segments(
                                _S_to_seq(S[b_ix], chain_M[b_ix]),
                                masked_chain_length_list,
                                masked_list,
                            )
                            native_score_print = np.format_float_positional(
                                np.float32(native_score.mean()),
                                unique=False,
                                precision=4,
                            )
                            line = ">{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n".format(
                                name_,
                                native_score_print,
                                sorted(visible_list_list[0]),
                                sorted(masked_list_list[0]),
                                model_name,
                                native_seq,
                            )
                            message += f"{line}\n"

                        seq = _join_chain_segments(
                            _S_to_seq(S_sample[b_ix], chain_M[b_ix]),
                            masked_chain_length_list,
                            masked_list,
                        )
                        score_print = np.format_float_positional(
                            np.float32(scores[b_ix]), unique=False, precision=4
                        )
                        seq_rec_print = np.format_float_positional(
                            np.float32(seq_recovery_rate.detach().cpu().numpy()),
                            unique=False,
                            precision=4,
                        )
                        # Append the non-designed chains, ":"-separated, for folding.
                        chain_s = ""
                        if len(visible_list_list[0]) > 0:
                            not_designed = _S_to_seq(S[b_ix], ~chain_M_bool[b_ix])
                            labels = (
                                chain_encoding_all[b_ix][~chain_M_bool[b_ix]]
                                .detach()
                                .cpu()
                                .numpy()
                            )
                            for c in sorted(set(labels)):
                                chain_s += ":" + "".join(
                                    aa for aa, keep in zip(not_designed, labels == c) if keep
                                )
                        seq_recovery.append(seq_rec_print)
                        seq_score.append(score_print)
                        seq_list.append(seq + chain_s)
                        line = ">T={}, sample={}, score={}, seq_recovery={}\n{}\n".format(
                            temp, b_ix, score_print, seq_rec_print, seq
                        )
                        message += f"{line}\n"

    if fixed_positions_dict is not None:
        message += f"\nfixed positions:* {fixed_positions_dict['cleaned']} \n\n*uses CHAIN:[1..len(chain)] residue numbering"
    if not seq_list:
        raise gr.Error("No sequences were generated")
    # Sequences can still contain omitted letters (e.g. "X"). Strip every one
    # of them; the previous loop kept only the last replacement.
    strip_omitted = str.maketrans("", "", omit_AAs)
    seq_list = [s.translate(strip_omitted) for s in seq_list]

    all_probs_concat = np.concatenate(all_probs_list)
    all_log_probs_concat = np.concatenate(all_log_probs_list)
    mean_probs = all_probs_concat.mean(0).T
    mean_exp_log_probs = np.exp(all_log_probs_concat).mean(0).T
    np.savetxt("all_probs_concat.csv", mean_probs, delimiter=",")
    np.savetxt("all_log_probs_concat.csv", mean_exp_log_probs, delimiter=",")
    fig = px.imshow(
        mean_exp_log_probs,
        labels=dict(x="positions", y="amino acids", color="probability"),
        y=list(alphabet),
        template="simple_white",
    )
    fig.update_xaxes(side="top")
    fig_tadjusted = px.imshow(
        mean_probs,
        labels=dict(x="positions", y="amino acids", color="probability"),
        y=list(alphabet),
        template="simple_white",
    )
    fig_tadjusted.update_xaxes(side="top")

    seq_dict = {"seq_list": seq_list, "recovery": seq_recovery, "seq_score": seq_score}
    mol = structure_pred(seq_dict, pdb_path, selected_residues)
    print(seq_list)
    return (
        message,
        fig,
        fig_tadjusted,
        gr.File.update(value="all_log_probs_concat.csv", visible=True),
        gr.File.update(value="all_probs_concat.csv", visible=True),
        pdb_path,
        gr.Dropdown.update(choices=seq_list, value=seq_list[0], interactive=True),
        selected_residues,
        seq_dict,
        mol,
    )
def updateseq(seq, seq_dict, pdb_path, selected_residues):
    """Re-render the structure view for the sequence chosen in the dropdown.

    The sequence's position in ``seq_dict["seq_list"]`` is the index passed
    to ``structure_pred``; the viewer HTML it produces is returned.
    """
    position = seq_dict["seq_list"].index(seq)
    print(seq, position)
    return structure_pred(seq_dict, pdb_path, selected_residues, index=position)
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein | |
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37 | |
def convert_outputs_to_pdb(outputs):
    """Turn ESMFold model outputs into PDB-format strings, one per batch entry.

    Atom positions come from the last structure-module iteration (converted
    from atom14 to atom37), residue numbers are shifted to start at 1, and
    per-residue pLDDT is written into the B-factor column.
    """
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    arrays = {key: value.to("cpu").numpy() for key, value in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom_exists = arrays["atom37_atom_exists"]
    chain_index = arrays["chain_index"] if "chain_index" in arrays else None

    pdb_strings = []
    for b, aatype in enumerate(arrays["aatype"]):
        protein = OFProtein(
            aatype=aatype,
            atom_positions=atom37_positions[b],
            atom_mask=atom_exists[b],
            residue_index=arrays["residue_index"][b] + 1,
            b_factors=arrays["plddt"][b],
            chain_index=None if chain_index is None else chain_index[b],
        )
        pdb_strings.append(to_pdb(protein))
    return pdb_strings
def get_esmfold_local(sequence):
    """Fold *sequence* with a local ESMFold model (on GPU) and cache the PDB on disk.

    Chains are separated by "/" in *sequence*. Each break is replaced by a
    25-residue poly-G linker; every chain's position ids are offset by
    512 * chain index, and linker atoms are masked out of the written PDB.

    Returns:
        ``"outputs/<md5 of sequence>.pdb"``; an existing file is reused
        without running the model.
    """
    # hashlib was previously used here without ever being imported.
    import hashlib

    filename = "outputs/" + hashlib.md5(str.encode(sequence)).hexdigest() + ".pdb"
    if os.path.exists(filename):
        print("prediction already on disk")
        return filename

    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained(
        "facebook/esmfold_v1", low_cpu_mem_usage=True
    )
    model = model.cuda()
    # fp16 language model, TF32 matmuls, trunk chunk size 64.
    model.esm = model.esm.half()
    torch.backends.cuda.matmul.allow_tf32 = True
    model.trunk.set_chunk_size(64)

    n_breaks = sequence.count("/")
    position_id_offsets = []
    linker_mask = []
    for i, chain_seq in enumerate(sequence.split("/")):
        linker = 25 if i < n_breaks else 0
        position_id_offsets.extend([i * 512] * (len(chain_seq) + linker))
        linker_mask.extend([1] * len(chain_seq) + [0] * linker)
    linked_sequence = sequence.replace("/", "G" * 25)

    tokenized = tokenizer(
        [linked_sequence], return_tensors="pt", add_special_tokens=False
    )
    position_ids = torch.arange(len(linked_sequence), dtype=torch.long) + torch.tensor(
        position_id_offsets, dtype=torch.long
    )
    linker_mask = torch.tensor(linker_mask, dtype=torch.float32).unsqueeze(1)
    tokenized["position_ids"] = position_ids.unsqueeze(0)
    tokenized = {key: tensor.cuda() for key, tensor in tokenized.items()}
    with torch.no_grad():
        output = model(**tokenized)
    # Drop linker atoms from the written structure.
    output["atom37_atom_exists"] = output["atom37_atom_exists"] * linker_mask.to(
        output["atom37_atom_exists"].device
    )
    pdb = convert_outputs_to_pdb(output)
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w") as f:
        f.write("".join(pdb))
    print("local prediction", filename)
    return filename
def structure_pred(seq_dict, pdb, selectedResidues, index=0):
    """Fold sequence *index* of *seq_dict*, align it to *pdb* and return the viewer HTML.

    Returns an HTML warning instead when the sequence is longer than 400
    characters, or when it is multimeric ("/" present) and no local ESMFold
    is available.

    Args:
        seq_dict: dict with parallel lists ``seq_list``, ``recovery`` and
            ``seq_score``.
        pdb: path of the input structure used as alignment reference.
        selectedResidues: fixed residues to highlight in the viewer.
        index: which sequence of ``seq_list`` to fold.
    """
    all_seqs = seq_dict["seq_list"]
    seq = all_seqs[index]
    if len(seq) > 400:
        return """
<div class="p-4 mb-4 text-sm text-yellow-700 bg-orange-50 rounded-lg" role="alert">
<span class="font-medium">Sorry!</span> Currently only small proteins <400 aa can be predicted with the web api of ESMFold</div>
"""
    if "/" in seq and not is_local:
        return """
<div class="p-4 mb-4 text-sm text-yellow-700 bg-orange-50 rounded-lg" role="alert">
<span class="font-medium">Sorry!</span> Sequence is multimeric and no structure prediction is run. Use local copy of ESMFold to predict.</div>
"""
    pdb_file = get_esmfold_local(seq) if is_local else get_esmfold(seq)
    rms, input_pdb, aligned_pdb, plddts = align_structures(pdb, pdb_file, index)
    # Score/recovery must come from the selected sequence, not always entry 0.
    # NOTE(review): molecule() still renders the text of allSeqs[0]; check it.
    sequences = {
        0: {
            "Seq": index,
            "RMSD": f"{rms:.2f}",
            "Score": seq_dict["seq_score"][index],
            "Recovery": seq_dict["recovery"][index],
            "Mean pLDDT": f"{np.mean(plddts):.4f}",
        }
    }
    return molecule(
        input_pdb,
        aligned_pdb,
        len(all_seqs),
        len(seq),
        selectedResidues,
        all_seqs,
        sequences,
    )
def read_mol(molpath):
    """Return the full text of the file at *molpath* (e.g. a PDB file)."""
    # One read instead of quadratic line-by-line string concatenation.
    with open(molpath, "r") as fp:
        return fp.read()
def molecule( | |
input_pdb, aligned_pdb, lenSeqs, num_res, selectedResidues, allSeqs, sequences | |
): | |
print("mol updated") | |
print("filenames", input_pdb, aligned_pdb) | |
mol = read_mol(input_pdb) | |
options = "" | |
pred_mol = "[" | |
seqdata = "{" | |
selected = "selected" | |
for i in range(1): # lenSeqs): | |
seqdata += ( | |
str(sequences[i]["Seq"]) | |
+ ': { "score": ' | |
+ sequences[i]["Score"] | |
+ ', "rmsd": ' | |
+ sequences[i]["RMSD"] | |
+ ', "recovery": ' | |
+ sequences[i]["Recovery"] | |
+ ', "plddt": ' | |
+ sequences[i]["Mean pLDDT"] | |
+ ', "seq":"' | |
+ allSeqs[i] | |
+ '"}' | |
) | |
pred_mol += f"`{read_mol(aligned_pdb)}`" | |
selected = "" | |
# if i != lenSeqs - 1: | |
# pred_mol += "," | |
# seqdata += "," | |
pred_mol += "]" | |
seqdata += "}" | |
x = ( | |
"""<!DOCTYPE html> | |
<html> | |
<head> | |
<meta http-equiv="content-type" content="text/html; charset=UTF-8" /> | |
<link rel="stylesheet" href="https://unpkg.com/[email protected]/dist/flowbite.min.css" /> | |
<style> | |
body{ | |
font-family:sans-serif | |
} | |
.mol-container { | |
width: 100%; | |
height: 700px; | |
position: relative; | |
} | |
.space-x-2 > * + *{ | |
margin-left: 0.5rem; | |
} | |
.p-1{ | |
padding:0.5rem; | |
} | |
.w-4{ | |
width:1rem; | |
} | |
.h-4{ | |
height:1rem; | |
} | |
.mt-4{ | |
margin-top:1rem; | |
} | |
.mol-container select{ | |
background-image:None; | |
} | |
</style> | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script> | |
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script> | |
</head> | |
<body> | |
<div class="font-mono bg-gray-100 py-3 px-2 font-sm rounded"> | |
<code>> seq <span id="id"></span>, score <span id="score"></span>, RMSD <span id="seqrmsd"></span>, Recovery | |
<span id="recovery"></span>, pLDDT <span id="plddt"></span></code><br> | |
<p id="seqText" class="max-w-4xl font-xs block" style="word-break: break-all;"> | |
</p> | |
</div> | |
<div id="container" class="mol-container"></div> | |
<div class="flex items-center"> | |
<div class="px-4 pt-2"> | |
<label for="sidechain" class="relative inline-flex items-center mb-4 cursor-pointer "> | |
<input id="sidechain" type="checkbox" class="sr-only peer"> | |
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div> | |
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show side chains</span> | |
</label> | |
</div> | |
<div class="px-4 pt-2"> | |
<label for="startstructure" class="relative inline-flex items-center mb-4 cursor-pointer "> | |
<input id="startstructure" type="checkbox" class="sr-only peer" checked> | |
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div> | |
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show input structure</span> | |
</label> | |
</div> | |
<button type="button" class="text-gray-900 bg-white hover:bg-gray-100 border border-gray-200 focus:ring-4 focus:outline-none focus:ring-gray-100 font-medium rounded-lg text-sm px-5 py-2.5 text-center inline-flex items-center dark:focus:ring-gray-600 dark:bg-gray-800 dark:border-gray-700 dark:text-white dark:hover:bg-gray-700 mr-2 mb-2" id="download"> | |
<svg class="w-6 h-6 mr-2 -ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path></svg> | |
Download predicted structure | |
</button> | |
</div> | |
<div class="text-sm"> | |
<div> RMSD ESMFold vs. native: <span id="rmsd"></span> Å computed using CEAlign on the aligned fragment</div> | |
</div> | |
<div class="text-sm flex items-start"> | |
<div class="w-1/2"> | |
<div class="font-medium mt-4 flex items-center space-x-2"><b>AF2 model of redesigned sequence</b></div> | |
<div>ESMFold model confidence:</div> | |
<div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(0, 83, 214);"> </span><span class="legendlabel">Very high | |
(pLDDT > 90)</span></div> | |
<div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(101, 203, 243);"> </span><span class="legendlabel">Confident | |
(90 > pLDDT > 70)</span></div> | |
<div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(255, 219, 19);"> </span><span class="legendlabel">Low (70 > | |
pLDDT > 50)</span></div> | |
<div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(255, 125, 69);"> </span><span class="legendlabel">Very low | |
(pLDDT < 50)</span></div> | |
<div class="row column legendDesc"> ESMFold produces a per-residue confidence | |
score (pLDDT) between 0 and 100. Some regions below 50 pLDDT may be unstructured in isolation. | |
</div> | |
</div> | |
<div class="w-1/2"> | |
<div class="font-medium mt-4 flex items-center space-x-2"><b>Input structure </b><span class="w-4 h-4 bg-gray-300 inline-flex" ></span></div> | |
<div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color:hotpink" > </span><span class="legendlabel">Fixed positions</span></div> | |
</div> | |
</div> | |
<script> | |
function drawStructures(i, selectedResidues) { | |
$("#rmsd").text(seqs[i]["rmsd"]) | |
$("#seqText").text(seqs[i]["seq"]) | |
$("#seqrmsd").text(seqs[i]["rmsd"]) | |
$("#id").text(i) | |
$("#score").text(seqs[i]["score"]) | |
$("#recovery").text(seqs[i]["recovery"]) | |
$("#plddt").text(seqs[i]["plddt"]) | |
viewer = $3Dmol.createViewer(element, config); | |
viewer.addModel(data[0], "pdb"); | |
viewer.addModel(pdb, "pdb"); | |
viewer.getModel(1).setStyle({}, { cartoon: { colorscheme: { prop: "resi", map: colors } } }) | |
viewer.getModel(0).setStyle({}, { cartoon: { colorfunc: colorAlpha } }); | |
viewer.zoomTo(); | |
viewer.render(); | |
viewer.zoom(0.8, 2000); | |
viewer.getModel(0).setHoverable({}, true, | |
function (atom, viewer, event, container) { | |
if (!atom.label) { | |
atom.label = viewer.addLabel(atom.resn + atom.resi + " pLDDT=" + atom.b, { position: atom, backgroundColor: "mintcream", fontColor: "black" }); | |
} | |
}, | |
function (atom, viewer) { | |
if (atom.label) { | |
viewer.removeLabel(atom.label); | |
delete atom.label; | |
} | |
} | |
); | |
} | |
let viewer = null; | |
let voldata = null; | |
let element = null; | |
let config = null; | |
let currentIndex = """ | |
+ str(sequences[i]["Seq"]) | |
+ """; | |
let seqs = """ | |
+ seqdata | |
+ """ | |
let data = """ | |
+ pred_mol | |
+ """ | |
let pdb = `""" | |
+ mol | |
+ """` | |
var selectedResidues = """ | |
+ f"{selectedResidues}" | |
+ """ | |
//AlphaFold code from https://gist.github.com/piroyon/30d1c1099ad488a7952c3b21a5bebc96 | |
let colorAlpha = function (atom) { | |
if (atom.b < 0.50) { | |
return "OrangeRed"; | |
} else if (atom.b < 0.70) { | |
return "Gold"; | |
} else if (atom.b < 0.90) { | |
return "MediumTurquoise"; | |
} else { | |
return "Blue"; | |
} | |
}; | |
let colors = {} | |
for (let i=0; i<""" | |
+ str(num_res) | |
+ """;i++){ | |
if (selectedResidues.includes(i)){ | |
colors[i]="hotpink" | |
}else{ | |
colors[i]="lightgray" | |
}} | |
let colorFixedSidechain = function(atom){ | |
if (selectedResidues.includes(atom.resi)){ | |
return "hotpink" | |
}else if (atom.elem == "O"){ | |
return "red" | |
}else if (atom.elem == "N"){ | |
return "blue" | |
}else if (atom.elem == "S"){ | |
return "yellow" | |
}else{ | |
return "lightgray" | |
} | |
} | |
$(document).ready(function () { | |
element = $("#container"); | |
config = { backgroundColor: "white" }; | |
//viewer.ui.initiateUI(); | |
drawStructures(currentIndex, selectedResidues) | |
$("#sidechain").change(function () { | |
if (this.checked) { | |
BB = ["C", "O", "N"] | |
if ($("#startstructure").prop("checked")) { | |
viewer.getModel(0).setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorscheme: "WhiteCarbon", radius: 0.3}, cartoon: { colorfunc: colorAlpha }}); | |
viewer.getModel(1).setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorfunc:colorFixedSidechain, radius: 0.3}, cartoon: {colorscheme:{prop:"resi",map:colors} }}); | |
}else{ | |
viewer.getModel(0).setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorscheme: "WhiteCarbon", radius: 0.3}, cartoon: { colorfunc: colorAlpha }}); | |
viewer.getModel(1).setStyle(); | |
} | |
viewer.render() | |
} else { | |
if ($("#startstructure").prop("checked")) { | |
viewer.getModel(0).setStyle({cartoon: { colorfunc: colorAlpha }}); | |
viewer.getModel(1).setStyle({cartoon: {colorscheme:{prop:"resi",map:colors} }}); | |
}else{ | |
viewer.getModel(0).setStyle({cartoon: { colorfunc: colorAlpha }}); | |
viewer.getModel(1).setStyle(); | |
} | |
viewer.render() | |
} | |
}); | |
$("#seq").change(function () { | |
drawStructures(this.value, selectedResidues) | |
currentIndex = this.value | |
$("#sidechain").prop( "checked", false ); | |
$("#startstructure").prop( "checked", true ); | |
}); | |
$("#startstructure").change(function () { | |
if (this.checked) { | |
$("#sidechain").prop( "checked", false ); | |
viewer.getModel(1).setStyle({},{cartoon: {colorscheme:{prop:"resi",map:colors} } }) | |
viewer.getModel(0).setStyle({}, { cartoon: { colorfunc: colorAlpha } }); | |
viewer.render() | |
} else { | |
$("#sidechain").prop( "checked", false ); | |
viewer.getModel(1).setStyle({},{}) | |
viewer.getModel(0).setStyle({}, { cartoon: { colorfunc: colorAlpha } }); | |
viewer.render() | |
} | |
}); | |
$("#download").click(function () { | |
download("outputs/esm_fold_prediction_"+currentIndex+".pdb", data[0]); | |
}) | |
}); | |
function download(filename, text) { | |
var element = document.createElement("a"); | |
element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(text)); | |
element.setAttribute("download", filename); | |
element.style.display = "none"; | |
document.body.appendChild(element); | |
element.click(); | |
document.body.removeChild(element); | |
} | |
</script> | |
</body></html>""" | |
) | |
return f"""<iframe style="width: 100%; height: 1300px" name="result" allow="midi; geolocation; microphone; camera; | |
display-capture; encrypted-media;" sandbox="allow-modals allow-forms | |
allow-scripts allow-same-origin allow-popups | |
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" | |
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""" | |
def set_examples(example):
    """Unpack a clicked example row into values for the example components.

    ``example`` must hold exactly eight entries, in the same order as the
    components of the examples ``gr.Dataset``: label, input structure,
    designed chain(s), fixed chain(s), homomer flag, number of sequences,
    sampling temperature and fixed-position atom selection.

    Returns:
        A list in that same order. The number of sequences and the sampling
        temperature are wrapped in ``gr.Slider.update`` / ``gr.Radio.update``;
        every other entry is passed through unchanged.
    """
    (
        label, inp, designed_chain, fixed_chain, homomer,
        num_seqs, sampling_temp, atomsel,
    ) = example
    plain_fields = [label, inp, designed_chain, fixed_chain, homomer]
    slider_value = gr.Slider.update(value=num_seqs)
    radio_value = gr.Radio.update(value=sampling_temp)
    return plain_fields + [slider_value, radio_value, atomsel]
import hashlib | |
def get_esmfold(sequence):
    """Fold ``sequence`` with the ESM Atlas API and return the PDB file path.

    Predictions are cached on disk at ``outputs/<md5 of sequence>.pdb``. If
    that file already exists, no request is made.

    Chain separators written as ``/`` are converted to ``:`` before hashing
    and submission.

    Args:
        sequence: amino-acid sequence string.

    Returns:
        The relative path of the cached PDB file.

    Raises:
        requests.HTTPError: if the API answers with an error status. Nothing
            is written to the cache in that case, so a later call retries
            instead of reusing an error page as a "prediction".
    """
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
    }
    sequence = sequence.replace("/", ":")
    filename = "outputs/" + hashlib.md5(str.encode(sequence)).hexdigest() + ".pdb"
    if os.path.exists(filename):
        print("prediction already on disk")
        return filename

    response = requests.post(
        "https://api.esmatlas.com/foldSequence/v1/pdb/",
        headers=headers,
        data=sequence,
        # NOTE(review): TLS verification stays disabled as in the original;
        # re-enable it once the endpoint's certificate is known to validate.
        verify=False,
    )
    # Fail before touching the cache: otherwise an error body would be
    # stored under this hash and returned on every later call.
    response.raise_for_status()
    pdb_string = response.content.decode("utf-8")

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w", encoding="utf-8") as f:
        f.write(pdb_string)
    print("retrieved prediction", filename)
    return filename
# Gradio UI: input/settings tabs, example rows, output tabs and event wiring.
proteinMPNN = gr.Blocks()

with proteinMPNN:
    gr.Markdown("# ProteinMPNN + ESMFold")
    gr.Markdown(
        """This model takes as input a protein structure and based on its backbone predicts new sequences that will fold into that backbone.
It will then run [ESMFold](https://esmatlas.com/about) by MetaAI to predict the structures of the designed sequences and align each predicted structure with the original backbone.
**Note: there is a 400 residue limit in this version, and multimeric structures can only be predicted locally. Follow the [README](https://huggingface.co/spaces/simonduerr/ProteinMPNNESM/blob/main/README.md) for instructions on how to run locally.**
"""
    )

    with gr.Tabs():
        with gr.TabItem("Input"):
            inp = gr.Textbox(
                placeholder="PDB Code or upload file below", label="Input structure"
            )
            file = gr.File(file_count="single")

        with gr.TabItem("Settings"):
            with gr.Row():
                designed_chain = gr.Textbox(value="A", label="Designed chain")
                fixed_chain = gr.Textbox(
                    placeholder="Use commas to fix multiple chains", label="Fixed chain"
                )
            with gr.Row():
                num_seqs = gr.Slider(
                    minimum=1, maximum=15, value=1, step=1, label="Number of sequences"
                )
                sampling_temp = gr.Radio(
                    choices=[0.1, 0.15, 0.2, 0.25, 0.3],
                    value=0.1,
                    label="Sampling temperature",
                )
            gr.Markdown(
                """ Sampling temperature for amino acids, `T=0.0` means taking argmax, `T>>1.0` means sample randomly. Suggested values `0.1, 0.15, 0.2, 0.25, 0.3`. Higher values will lead to more diversity.
"""
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=[
                        "v_48_002",
                        "v_48_010",
                        "v_48_020",
                        "v_48_030",
                    ],
                    label="Model",
                    value="v_48_020",
                )
                backbone_noise = gr.Dropdown(
                    choices=[0, 0.02, 0.10, 0.20, 0.30], label="Backbone noise", value=0,
                )
            with gr.Row():
                homomer = gr.Checkbox(value=False, label="Homomer?")
                gr.Markdown(
                    "For correct symmetric tying, the lengths of homomer chains should be the same."
                )
            gr.Markdown("## Fixed positions")
            gr.Markdown(
                """You can fix important positions in the protein. Residue IDs (resid) should be specified with the same numbering as in the input PDB file. The fixed residues will be highlighted in the output.
The [VMD selection](http://www.ks.uiuc.edu/Research/vmd/vmd-1.9.2/ug/node89.html) syntax is used. You can also select based on ligands or chains in the input structure to specify interfaces to be fixed.
- <code>within 5 of resid 94</code> All residues that have at least one atom closer than 5 Å to any atom of residue 94
- <code>name CA and within 5 of resid 94</code> All residues that have a CA atom closer than 5 Å to any atom of residue 94
- <code>resid 94 96 119</code> Residues 94, 96 and 119
- <code>within 5 of resname ZN</code> All residues with any atom <5 Å of zinc ion
- <code>chain A and within 5 of chain B </code> All residues of chain A that are part of the interface with chain B
- <code>protein and within 5 of nucleic </code> All residues that bind to DNA (if present in structure)
- <code>not (chain A and within 5 of chain B) </code> Only modify residues that are in the interface with the fixed chain, not further away
- <code>chain A or (chain B and sasa < 20) </code> Keep chain A and all core residues of chain B fixed
- <code>pLDDT >70 </code> Redesign all residues with low pLDDT
Note that <code>sasa</code> and <code>pLDDT</code> selectors modify default VMD behavior. SASA is calculated using moleculekit and written to the mass attribute. Selections based on mass do not work.
pLDDT is an alias for beta, it only works correctly with structures that contain the appropriate values in the beta column of the PDB file. """
            )
            atomsel = gr.Textbox(
                placeholder="Specify atom selection ", label="Fixed positions",
                api_name="fixed_positions",
            )

    btn = gr.Button("Run")
    # Hidden component that receives the example title from the Dataset row.
    label = gr.Textbox(label="Label", visible=False)

    # Each row matches the components of `examples` below, in order.
    samples = [["Monomer design", "6MRR", "A", "", False, 2, 0.1, ""]]
    if is_local:
        samples.extend(
            [
                ["Homomer design", "1O91", "A,B,C", "", True, 2, 0.1, ""],
                [
                    "Redesign of Homomer to Heteromer",
                    "3HTN",
                    "A,B",
                    "C",
                    False,
                    2,
                    0.1,
                    "",
                ],
                [
                    "Redesign of MID1 scaffold keeping binding site fixed",
                    "3V1C",
                    "A,B",
                    "",
                    False,
                    2,
                    0.1,
                    "within 5 of resname ZN",
                ],
                [
                    "Redesign of DNA binding protein",
                    "3JRD",
                    "A,B",
                    "",
                    False,
                    2,
                    0.1,
                    "within 8 of nucleic",
                ],
                [
                    "Surface Redesign of miniprotein",
                    "7JZM",
                    "A,B",
                    "",
                    False,
                    2,
                    0.1,
                    "chain B or (chain A and sasa < 20)",
                ],
            ]
        )
    examples = gr.Dataset(
        components=[
            label,
            inp,
            designed_chain,
            fixed_chain,
            homomer,
            num_seqs,
            sampling_temp,
            atomsel,
        ],
        samples=samples,
    )

    gr.Markdown("# Output")

    with gr.Tabs():
        with gr.TabItem("Designed sequences"):
            chosen_seq = gr.Dropdown(
                choices=[],
                label="Select a sequence for validation",
            )
            mol = gr.HTML()
            out = gr.Textbox(label="Fasta Output")
        with gr.TabItem("Amino acid probabilities"):
            plot = gr.Plot()
            all_log_probs = gr.File(visible=False)
        with gr.TabItem("T adjusted probabilities"):
            gr.Markdown("Sampling temperature adjusted amino acid probabilities")
            plot_tadjusted = gr.Plot()
            all_probs = gr.File(visible=False)

    # Session state shared between the "Run" handler and the sequence viewer.
    tempFile = gr.Variable()
    selectedResidues = gr.Variable()
    seq_dict = gr.Variable()

    btn.click(
        fn=update,
        inputs=[
            inp,
            file,
            designed_chain,
            fixed_chain,
            homomer,
            num_seqs,
            sampling_temp,
            model_name,
            backbone_noise,
            atomsel,
        ],
        outputs=[
            out,
            plot,
            plot_tadjusted,
            all_log_probs,
            all_probs,
            tempFile,
            chosen_seq,
            selectedResidues,
            seq_dict,
            mol,
        ],
        api_name="proteinmpnn",
    )
    chosen_seq.change(
        updateseq,
        inputs=[chosen_seq, seq_dict, tempFile, selectedResidues],
        outputs=mol,
    )
    examples.click(fn=set_examples, inputs=examples, outputs=examples.components)
    gr.Markdown(
        """Citation: **Robust deep learning based protein sequence design using ProteinMPNN** <br>
Justas Dauparas, Ivan Anishchenko, Nathaniel Bennett, Hua Bai, Robert J. Ragotte, Lukas F. Milles, Basile I. M. Wicky, Alexis Courbet, Robbert J. de Haas, Neville Bethel, Philip J. Y. Leung, Timothy F. Huddy, Sam Pellock, Doug Tischer, Frederick Chan, Brian Koepnick, Hannah Nguyen, Alex Kang, Banumathi Sankaran, Asim Bera, Neil P. King, David Baker <br>
Science Vol 378, Issue 6615, pp. 49-56; doi: [10.1126/science.add2187](https://doi.org/10.1126/science.add2187) <br><br> Server built by [@simonduerr](https://twitter.com/simonduerr) and hosted by Huggingface"""
    )

proteinMPNN.launch()