Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import esm | |
import requests | |
import matplotlib.pyplot as plt | |
from myscaledb import Client | |
import random | |
from collections import Counter | |
from tqdm import tqdm | |
from statistics import mean | |
import biotite.structure.io as bsio | |
import torch | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
from stmol import * | |
import py3Dmol | |
# from streamlit_3Dmol import component_3dmol | |
import scipy | |
from sklearn.model_selection import GridSearchCV, train_test_split | |
from sklearn.decomposition import PCA | |
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor | |
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor | |
from sklearn.linear_model import LogisticRegression, SGDRegressor | |
from sklearn.pipeline import Pipeline | |
from streamlit.components.v1 import html | |
def init_esm():
    """Load the pretrained ESM-MSA-1b Transformer.

    Returns:
        A ``(model, alphabet)`` pair from
        ``esm.pretrained.esm_msa1b_t12_100M_UR50S()``, with the model already
        switched to evaluation mode.
    """
    model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    # Module.eval() returns the module itself, so the result can be returned directly.
    return model.eval(), alphabet
def init_db():
    """Initialize the MyScale database connection.

    Connection parameters come from Streamlit secrets ``DB_URL``, ``USER`` and
    ``PASSWD``.

    Returns:
        meta_field: An empty dict (the app stores it as ``st.session_state.meta``).
        client: The connected ``Client`` instance.

    Raises:
        ConnectionError: If the server does not pass the ``is_alive()`` check.
    """
    client = Client(
        url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
    # An explicit check instead of `assert`, which is stripped under `python -O`.
    if not client.is_alive():
        raise ConnectionError("MyScale database is not reachable at the configured DB_URL")
    meta_field = {}
    # Return the connected instance (previously the `Client` class was returned by mistake).
    return meta_field, client
def perdict_contact_visualization(seq, model, batch_converter):
    """Plot the model's predicted residue-residue contact map for one sequence.

    Args:
        seq: Amino-acid sequence to analyse.
        model: ESM model that supports ``return_contacts=True``.
        batch_converter: Alphabet batch converter matching *model*.

    Returns:
        A matplotlib Figure showing the contact map (``matshow``), cropped to
        ``len(seq) x len(seq)``.
    """
    _, _, batch_tokens = batch_converter([("protein1", seq)])
    # Only the contact head is needed; the former per-sequence mean representation
    # was computed and never used, so no layer representations are requested.
    with torch.no_grad():
        results = model(batch_tokens, return_contacts=True)
    contacts = results["contacts"][0]
    fig, ax = plt.subplots()
    ax.matshow(contacts[: len(seq), : len(seq)])
    return fig
def visualize_3D_Coordinates(coords):
    """Draw a 3-D polyline through a list of points.

    Each consecutive pair of points is drawn as its own segment, coloured along
    the viridis colormap by position in the chain.

    Args:
        coords: Sequence of points, each indexable as ``p[0], p[1], p[2]``.

    Returns:
        The matplotlib Figure containing the 3-D axes.
    """
    xs = [point[0] for point in coords]
    ys = [point[1] for point in coords]
    zs = [point[2] for point in coords]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title('3D coordinates of $C_{b}$ backbone structure')

    total = len(coords)
    for start in range(total - 1):
        stop = start + 2
        segment_color = plt.cm.viridis(start / total)
        ax.plot(xs[start:stop], ys[start:stop], zs[start:stop],
                color=segment_color, marker='o')
    return fig
def render_mol(pdb):
    """Render a PDB-format string as a spinning, spectrum-coloured cartoon.

    Args:
        pdb: Structure text in PDB format.
    """
    view = py3Dmol.view()
    view.addModel(pdb, 'pdb')
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    view.setBackgroundColor('white')
    # Fit the model to the viewport, then zoom in 2x over an 800 ms animation.
    view.zoomTo()
    view.zoom(2, 800)
    view.spin(True)
    showmol(view, height=500, width=800)
def esm_search(model, sequnce, batch_converter, top_k=5):
    """Find stored sequences whose embedding is closest to the query's.

    The query is embedded with layer 12 of *model*. A vector search is then run
    against ``default.esm_protein_indexer_768``.

    Args:
        model: ESM model used to embed the query (the app passes the MSA Transformer).
        sequnce: Query amino-acid sequence (misspelt name kept for compatibility).
        batch_converter: Alphabet batch converter matching *model*.
        top_k: Maximum number of distinct sequences to return.

    Returns:
        A list of up to *top_k* distinct sequences, nearest first.
    """
    _, _, batch_tokens = batch_converter([("protein1", sequnce)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12])
    # [0][0][0] selects a single token vector. For the MSA Transformer this is
    # presumably batch 0, first MSA row, first (BOS) token, sized to match the
    # *_768 table -- TODO confirm.
    token_list = results["representations"][12].tolist()[0][0][0]

    client = Client(
        url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
    # Ask for a generous candidate pool because several rows may share a sequence.
    candidate_pool = 500
    rows = client.fetch(
        f"SELECT seq, distance('topK={candidate_pool}')(representations, {token_list}) "
        "as dist FROM default.esm_protein_indexer_768"
    )
    # The SQL has no ORDER BY, and the old list(set(...)) discarded any order
    # anyway. Rank by distance (smaller = closer), then de-duplicate while
    # keeping that order.
    ranked = sorted(rows, key=lambda row: row['dist'])
    unique_seqs = list(dict.fromkeys(row['seq'] for row in ranked))
    return unique_seqs[:top_k]
def show_protein_structure(sequence):
    """Fold a sequence with the ESM Atlas API and render the predicted structure.

    The returned PDB text is written to ``predicted.pdb``, reloaded with biotite
    to read its B-factor column, and rendered with ``render_mol``.

    Args:
        sequence: Amino-acid sequence to fold.

    Returns:
        The mean B-factor of the predicted structure, rounded to 4 decimals.
        ESMFold is understood to store per-residue confidence (pLDDT) in this
        column -- confirm. Callers may ignore the value.

    Raises:
        requests.HTTPError: If the folding service answers with an error status.
    """
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
    }
    response = requests.post('https://api.esmatlas.com/foldSequence/v1/pdb/', headers=headers, data=sequence)
    # Fail clearly instead of writing an HTTP error page out as a PDB file.
    response.raise_for_status()
    pdb_string = response.content.decode('utf-8')
    with open('predicted.pdb', 'w', encoding='utf-8') as f:
        f.write(pdb_string)
    struct = bsio.load_structure('predicted.pdb', extra_fields=["b_factor"])
    b_value = round(struct.b_factor.mean(), 4)
    render_mol(pdb_string)
    return b_value
# Share one copy of the ~650M-parameter ESM-2 model across Streamlit reruns (and
# sessions) when this Streamlit version has a resource cache. Otherwise fall back
# to loading it on every call, as before.
_resource_cache = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda fn: fn)
)


@_resource_cache
def _load_esm2_t33_650m():
    """Load ESM-2 (t33, 650M) and its alphabet with the model in eval mode."""
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    return model, alphabet


def KNN_search(sequence, top_k=10):
    """Predict activity as the mean activity of the nearest stored proteins.

    The query is embedded with layer 33 of ESM-2 (650M). A vector search is then
    run against ``default.esm_protein_indexer``, and the ``activity`` values of
    the returned rows are averaged.

    Args:
        sequence: Amino-acid sequence to score.
        top_k: Number of neighbours requested from the vector search.

    Returns:
        The arithmetic mean of the neighbours' ``activity`` values.

    Raises:
        ValueError: If the search returns no rows, so there is nothing to average.
    """
    model, alphabet = _load_esm2_t33_650m()
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein1", sequence)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])
    # Presumably (batch, tokens, dim), so [0][0] is the query's first (BOS)
    # token vector -- TODO confirm this matches how the table was built.
    token_list = results["representations"][33].tolist()[0][0]

    client = Client(
        url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
    rows = client.fetch(
        f"SELECT activity, distance('topK={int(top_k)}')(representations, {token_list}) "
        "as dist FROM default.esm_protein_indexer"
    )
    activities = [row['activity'] for row in rows]
    if not activities:
        raise ValueError("vector search returned no neighbours; cannot average their activity")
    return sum(activities) / len(activities)
def train_test_split_PCA(
    dataset=None,
    *,
    fasta_path='/root/xuying_experiments/esm-main/P62593.fasta',
    emb_path='/root/xuying_experiments/esm-main/P62593_reprs',
    repr_layer=34,
    train_size=0.8,
    random_state=42,
):
    """Load precomputed per-variant embeddings and split them into train/test sets.

    Despite the name, no PCA is applied here. The splits are meant to be fed to
    PCA-based visualisation and training.

    Args:
        dataset: Unused; accepted only for backward compatibility.
        fasta_path: FASTA file whose headers end in ``|<scaled effect>``. Each
            header also names its embedding file ``<emb_path>/<header>.pt``.
        emb_path: Directory holding the ``.pt`` embedding files.
        repr_layer: Key into each file's ``mean_representations`` dict.
        train_size: Fraction of samples placed in the training split.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        ``(Xs_train, Xs_test, ys_train, ys_test)``. Xs are numpy arrays stacked
        from the embeddings; ys are floats parsed from the FASTA headers.

    Raises:
        ValueError: If *fasta_path* contains no records.
    """
    ys = []
    embeddings = []
    for header, _seq in esm.data.read_fasta(fasta_path):
        ys.append(float(header.split('|')[-1]))
        embs = torch.load(f'{emb_path}/{header}.pt')
        embeddings.append(embs['mean_representations'][repr_layer])
    if not embeddings:
        # torch.stack([]) would fail with an opaque error; say what is wrong.
        raise ValueError(f"no records found in {fasta_path}")
    Xs = torch.stack(embeddings, dim=0).numpy()
    Xs_train, Xs_test, ys_train, ys_test = train_test_split(
        Xs, ys, train_size=train_size, random_state=random_state)
    return Xs_train, Xs_test, ys_train, ys_test
def PCA_visual(Xs_train, ys_train=None):
    """Scatter-plot the first two principal components of the training embeddings.

    Args:
        Xs_train: 2-D array of embeddings, shaped (samples, features).
        ys_train: Optional per-sample values used to colour the points. When
            given, a colorbar labelled 'Variant Effect' is added. (The original
            body read an undefined global ``ys_train`` and raised NameError.)

    Returns:
        The matplotlib Figure.
    """
    # Cap at 60 components as before, but never exceed what PCA can fit
    # (min(samples, features)); the plot only uses the first two.
    num_pca_components = min(60, *np.shape(Xs_train))
    pca = PCA(num_pca_components)
    Xs_train_pca = pca.fit_transform(Xs_train)
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_title('Visualize Embeddings')
    sc = ax.scatter(Xs_train_pca[:, 0], Xs_train_pca[:, 1], c=ys_train, marker='.')
    ax.set_xlabel('PCA first principal component')
    ax.set_ylabel('PCA second principal component')
    if ys_train is not None:
        fig.colorbar(sc, ax=ax, label='Variant Effect')
    return fig
def KNN_trainings(Xs_train, Xs_test, ys_train, ys_test):
    """Grid-search a PCA(60) -> KNeighborsRegressor pipeline on the training split.

    Only *Xs_train* and *ys_train* are used. The test split is accepted but not
    evaluated here.

    Returns:
        A DataFrame with the five best parameter combinations by cross-validated
        r2 rank. Columns: ``param_model``, ``params``, ``param_model__algorithm``,
        ``mean_test_score``, ``rank_test_score``.
    """
    search_space = [{
        'model': [KNeighborsRegressor()],
        'model__n_neighbors': [5, 10],
        'model__weights': ['uniform', 'distance'],
        'model__algorithm': ['ball_tree', 'kd_tree', 'brute'],
        'model__leaf_size': [15, 30],
        'model__p': [1, 2],
    }]
    pipeline = Pipeline(steps=(
        ('pca', PCA(60)),
        ('model', KNeighborsRegressor()),
    ))

    # Single estimator family: log its class, then run one exhaustive search.
    print(KNeighborsRegressor)
    search = GridSearchCV(
        estimator=pipeline,
        param_grid=search_space,
        scoring='r2',
        verbose=1,
        n_jobs=-1,  # use all available cores
    )
    search.fit(Xs_train, ys_train)

    cv_table = pd.DataFrame.from_dict(search.cv_results_)
    best_five = cv_table.sort_values('rank_test_score')[:5]
    wanted = ['param_model', 'params', 'param_model__algorithm', 'mean_test_score', 'rank_test_score']
    return best_five[wanted]
# Inject a stylesheet link for the Roboto web font into the page.
st.markdown("""
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
/>
""", unsafe_allow_html=True)

# Intro text shown in the info box at the top of the app. A trailing backslash
# joins the next line inside the literal, so each one keeps a space before it
# (previously "model.\" produced "model.Check out").
messages = [
    """
Evolutionary-scale prediction of atomic level protein structure
ESM is a high-capacity Transformer trained with protein sequences \
as input. After training, the secondary and tertiary structure, \
function, homology and other information of the protein are in the feature representation output by the model. \
Check out https://esmatlas.com/ for more information.
We have features for 120k proteins stored in our database.
The app uses MyScale to store and query protein sequences
using vector search.
"""
]
def init_random_query(dims=768):
    """Return a random query vector together with an independent copy of it.

    The original body read an undefined global ``DIMS`` (NameError on any call),
    so the dimension is now a parameter. The default of 768 matches the
    ``esm_protein_indexer_768`` table used by the search.

    Args:
        dims: Length of the vector; entries are drawn uniformly from [0, 1).

    Returns:
        A tuple ``(xq, xq_copy)`` of two equal but distinct lists of floats.
    """
    xq = np.random.rand(dims).tolist()
    return xq, xq.copy()
with st.spinner("Connecting DB..."):
    db_meta, client = init_db()
    st.session_state.meta = db_meta

with st.spinner("Loading Models..."):
    # Load the MSA Transformer only for a session that has not stored it under
    # 'xq' yet; its batch converter is kept in session state for later reruns.
    if 'xq' not in st.session_state:
        model, alphabet = init_esm()
        batch_converter = alphabet.get_batch_converter()
        st.session_state.batch = batch_converter
        st.session_state['query_num'] = 0
# One-click example inputs offered by the sequence-based options.
CAS9_EXAMPLE_SEQ = 'GSGHMDKKYSIGLAIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRILYLQEIFSNEMAKV'
PETASE_EXAMPLE_SEQ = 'MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ'

# Selectbox labels. The same tuple is used on every run, so the widget keeps its
# selection across reruns (the old fresh/else branches used different labels).
function_list = (
    'self-contact prediction',
    'search the database',
    'activity prediction',
    'PDB viewer',
)


def _sequence_input(show_examples_hint=True):
    """Render the sequence text box and example buttons; return the chosen sequence.

    Returns the example sequence whose button was clicked on this run, otherwise
    the text-box contents ('' when empty).
    """
    sequence = st.text_input('protein sequence', '')
    if show_examples_hint:
        st.write('Try an example:')
    # Render both buttons on every run; an if/elif on st.button() hid 'PETase'
    # on the run where 'Cas9 Enzyme' was clicked.
    cas9_clicked = st.button('Cas9 Enzyme')
    petase_clicked = st.button('PETase')
    if cas9_clicked:
        return CAS9_EXAMPLE_SEQ
    if petase_clicked:
        return PETASE_EXAMPLE_SEQ
    return sequence


def _show_search_results(sequences, max_shown=5):
    """Show up to *max_shown* sequences in numbered tabs, each with its folded structure."""
    shown = list(sequences)[:max_shown]
    if not shown:
        st.warning('No similar proteins found.')
        return
    # Build only as many tabs as there are hits (fixed indices [0]..[4] raised IndexError).
    tabs = st.tabs([str(rank) for rank in range(1, len(shown) + 1)])
    for tab, seq in zip(tabs, shown):
        with tab:
            st.write(seq)
            show_protein_structure(seq)


if 'xq' not in st.session_state:
    # Fresh session: `model` was loaded just above under this same condition.
    # Store it now so every later rerun (each widget interaction) reuses it;
    # previously some options never stored it and the model reloaded every rerun.
    st.session_state['xq'] = model

if st.session_state.get('query_num', 0) < len(messages):
    msg = messages[0]
else:
    msg = messages[-1]

with st.container():
    st.title("Evolutionary Scale Modeling")
    st.info(msg)

option = st.selectbox('Application options', function_list)
st.session_state.db_name_ref = 'default.esm_protein'

if option == function_list[0]:
    sequence = _sequence_input(show_examples_hint=False)
    if sequence:
        st.write('you have entered: ', sequence)
        st.pyplot(perdict_contact_visualization(sequence, st.session_state['xq'], st.session_state['batch']))
        expander = st.expander("See explanation")
        expander.markdown(
            """<span style="word-wrap:break-word;">Contact prediction is based on a logistic regression over the model's attention maps. This methodology is based on ICLR 2021 paper, Transformer protein language models are unsupervised structure learners. (Rao et al. 2020) The MSA Transformer (ESM-MSA-1) takes a multiple sequence alignment (MSA) as input, and uses the tied row self-attention maps in the same way.</span>
""", unsafe_allow_html=True)
elif option == function_list[1]:
    sequence = _sequence_input()
    if sequence:
        st.write('you have entered: ', sequence)
        # Pass the real batch converter (the fresh-start path passed esm_search itself).
        result_temp_seq = esm_search(st.session_state['xq'], sequence, st.session_state['batch'], top_k=10)
        st.text('search result (top 5): ')
        _show_search_results(result_temp_seq)
elif option == function_list[2]:
    st.markdown('we predict the biological activity of mutations of a protein, using fixed embeddings from ESM.')
    sequence = _sequence_input()
    if sequence:
        st.write('you have entered: ', sequence)
        res_knn = KNN_search(sequence)
        st.subheader('KNN predictor result')
        st.markdown("Activity prediction: " + str(res_knn))
elif option == function_list[3]:
    id_PDB = st.text_input('enter PDB ID', '')
    residues_marker = st.text_input('residues class', '')
    st.write('Try an example:')
    if st.button('PDB ID: 1A2C / residues class: ALA'):
        id_PDB = '1A2C'
        residues_marker = 'ALA'
    st.subheader('PDB viewer')
    # Only fetch/render once an ID is known; an empty ID gives render_pdb nothing to load.
    if id_PDB:
        viewer = render_pdb(id=id_PDB)
        if residues_marker:
            viewer = render_pdb_resn(viewer=viewer, resn_lst=[residues_marker])
        showmol(viewer)
    expander = st.expander("See explanation")
    expander.markdown("""
A PDB ID is a unique 4-character code for each entry in the Protein Data Bank. The first character must be a number between 1 and 9, and the remaining three characters can be letters or numbers.
see https://www.rcsb.org/ for more information.
""")