# File-viewer header captured together with the source (not Python code):
# xuyingli
# Update app.py
# 1c470a9
# raw
# history blame
# 22.2 kB
import streamlit as st
import torch
import esm
import matplotlib.pyplot as plt
from myscaledb import Client
import random
from collections import Counter
from tqdm import tqdm
from statistics import mean
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from stmol import *
import py3Dmol
# from streamlit_3Dmol import component_3dmol
import esm
import scipy
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.svm import SVC, SVR
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression, SGDRegressor
from sklearn.pipeline import Pipeline
from streamlit.components.v1 import html
def init_esm():
    """Load the pretrained ESM MSA Transformer in inference mode.

    Returns:
        A ``(model, alphabet)`` pair as produced by
        ``esm.pretrained.esm_msa1b_t12_100M_UR50S()``, with ``eval()``
        already applied to the model.
    """
    model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    return model.eval(), alphabet
@st.experimental_singleton(show_spinner=False)
def init_db():
    """Open the database connection used by the app.

    Connection settings are read from ``st.secrets`` (``DB_URL``, ``USER``
    and ``PASSWD``).

    Returns:
        meta_field: An empty dict; the caller stores it in session state.
        client: The connected ``Client`` instance.

    Raises:
        ConnectionError: If the database does not report itself as alive.
    """
    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not client.is_alive():
        raise ConnectionError("Database connection is not alive")
    meta_field = {}
    # Return the connected instance, not the ``Client`` class itself.
    return meta_field, client
def perdict_contact_visualization(seq, model, batch_converter):
    """Plot the contact map the model predicts for a single sequence.

    Args:
        seq: Protein sequence string; also used as the figure title.
        model: Callable invoked as
            ``model(tokens, repr_layers=[12], return_contacts=True)`` whose
            result has a ``"contacts"`` entry.
        batch_converter: Callable turning ``[(label, sequence)]`` into a
            ``(labels, strings, tokens)`` triple.

    Returns:
        The matplotlib figure showing the ``len(seq) x len(seq)`` corner of
        the predicted contact map.
    """
    data = [("protein1", seq)]
    _, _, batch_tokens = batch_converter(data)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12], return_contacts=True)

    # The batch holds exactly one sequence, so only the first contact map is
    # relevant. The averaged per-sequence representation that used to be
    # computed here was never used and has been removed.
    attention_contacts = results["contacts"][0]

    fig, ax = plt.subplots()
    ax.matshow(attention_contacts[: len(seq), : len(seq)])
    fig.suptitle(seq)
    return fig
def visualize_3D_Coordinates(coords):
    """Draw a chain of 3D points as a colour-graded line plot.

    Args:
        coords: Sized sequence of points; items 0, 1 and 2 of each point are
            read as x, y and z.

    Returns:
        The matplotlib figure. Consecutive points are joined by one segment
        each, coloured by the segment's position along the chain.
    """
    xs = [point[0] for point in coords]
    ys = [point[1] for point in coords]
    zs = [point[2] for point in coords]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title('3D coordinates of $C_{b}$ backbone structure')

    total = len(coords)
    for start in range(total - 1):
        stop = start + 2
        # One two-point segment per step so each gets its own colour.
        ax.plot(
            xs[start:stop], ys[start:stop], zs[start:stop],
            color=plt.cm.viridis(start / total),
            marker='o',
        )
    return fig
def esm_search(model, sequnce, batch_converter, top_k=5):
    """Look up stored sequences whose embedding is close to the query's.

    Args:
        model: Callable invoked as
            ``model(tokens, repr_layers=[12], return_contacts=True)`` whose
            result has ``["representations"][12]``.
        sequnce: Query protein sequence string.
        batch_converter: Callable turning ``[(label, sequence)]`` into a
            ``(labels, strings, tokens)`` triple.
        top_k: Accepted for API compatibility. It is not applied: the query
            below always requests ``topK=500`` and every distinct sequence
            that comes back is returned.

    Returns:
        The distinct ``seq`` values of the fetched rows, in the order the
        database returned them.
    """
    data = [("protein1", sequnce)]
    _, _, batch_tokens = batch_converter(data)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12], return_contacts=True)
    token_representations = results["representations"][12]
    # First entry along each of the first three axes; presumably the first
    # token of the first alignment row -- TODO confirm against the model.
    token_list = token_representations.tolist()[0][0][0]

    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )
    # The only interpolated value is a list of floats computed above, never
    # raw user text.
    query = (
        "SELECT seq, distance('topK=500')(representations, "
        f"{token_list})as dist FROM default.esm_protein_indexer_768"
    )
    rows = client.fetch(query)

    # De-duplicate while keeping the database's order. ``set`` would scramble
    # it (differently in every process), and callers display the first few
    # entries as the best matches.
    return list(dict.fromkeys(row['seq'] for row in rows))
def KNN_search(sequence):
    """Estimate activity as the mean over the nearest stored embeddings.

    The sequence is embedded with ESM-2 (``esm2_t33_650M_UR50D``, layer 33),
    the database is asked for the ``topK=10`` closest rows, and the mean of
    their ``activity`` values is returned.

    Args:
        sequence: Protein sequence string.

    Returns:
        The arithmetic mean of the ``activity`` values of the fetched rows.

    Raises:
        ValueError: If the database returns no rows (previously this
            surfaced as a ``ZeroDivisionError``).
    """
    # NOTE(review): the model is loaded again on every call; consider
    # caching it if this becomes a bottleneck.
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model.eval()

    _, _, batch_tokens = batch_converter([("protein1", sequence)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    # Embedding at position 0 of the only sequence in the batch.
    token_list = results["representations"][33].tolist()[0][0]

    client = Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )
    # The only interpolated value is a list of floats computed above.
    query = (
        "SELECT activity, distance('topK=10')(representations, "
        f"{token_list})as dist FROM default.esm_protein_indexer"
    )
    activities = [row['activity'] for row in client.fetch(query)]
    if not activities:
        raise ValueError("No neighbours returned for the given sequence")
    return sum(activities) / len(activities)
def train_test_split_PCA(dataset):
    """Load saved embeddings with their labels and split them 80/20.

    Each FASTA header supplies the label (the text after its last ``|``,
    parsed as a float) and the name of the ``.pt`` file that holds the
    embedding (``mean_representations[34]``).

    Args:
        dataset: Not used; the input paths are fixed inside the function.

    Returns:
        ``(Xs_train, Xs_test, ys_train, ys_test)`` from
        ``train_test_split`` with ``train_size=0.8`` and ``random_state=42``.
    """
    FASTA_PATH = '/root/xuying_experiments/esm-main/P62593.fasta'
    EMB_PATH = '/root/xuying_experiments/esm-main/P62593_reprs'

    labels = []
    features = []
    for header, _seq in esm.data.read_fasta(FASTA_PATH):
        labels.append(float(header.rsplit('|', 1)[-1]))
        saved = torch.load(f'{EMB_PATH}/{header}.pt')
        features.append(saved['mean_representations'][34])

    feature_matrix = torch.stack(features, dim=0).numpy()
    X_fit, X_eval, y_fit, y_eval = train_test_split(
        feature_matrix, labels, train_size=0.8, random_state=42
    )
    return X_fit, X_eval, y_fit, y_eval
def PCA_visual(Xs_train, ys_train=None):
    """Scatter-plot the first two principal components of the embeddings.

    Args:
        Xs_train: 2-D array-like of embeddings, one row per sample. PCA is
            fitted with 60 components, so the data must allow that many.
        ys_train: Optional per-sample values used to colour the points.
            The function used to read an undefined ``ys_train`` name and
            failed with ``NameError``; it is now an explicit argument.
            When omitted, the points are drawn in a single colour and no
            colour bar is added.

    Returns:
        The matplotlib figure.
    """
    num_pca_components = 60
    pca = PCA(num_pca_components)
    Xs_train_pca = pca.fit_transform(Xs_train)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_title('Visualize Embeddings')
    sc = ax.scatter(
        Xs_train_pca[:, 0], Xs_train_pca[:, 1], c=ys_train, marker='.'
    )
    ax.set_xlabel('PCA first principal component')
    ax.set_ylabel('PCA second principal component')
    # A colour bar is only meaningful when the points carry values.
    if ys_train is not None:
        plt.colorbar(sc, ax=ax, label='Variant Effect')
    return fig
def KNN_trainings(Xs_train, Xs_test, ys_train, ys_test):
    """Grid-search a PCA + k-nearest-neighbours regressor and report the best runs.

    Args:
        Xs_train: Training features passed to ``GridSearchCV.fit``.
        Xs_test: Not used by this function.
        ys_train: Training targets passed to ``GridSearchCV.fit``.
        ys_test: Not used by this function.

    Returns:
        A DataFrame with the five best-ranked parameter combinations
        (by ``rank_test_score``) and the columns ``param_model``,
        ``params``, ``param_model__algorithm``, ``mean_test_score`` and
        ``rank_test_score``.
    """
    num_pca_components = 60
    search_space = [
        {
            'model': [KNeighborsRegressor()],
            'model__n_neighbors': [5, 10],
            'model__weights': ['uniform', 'distance'],
            'model__algorithm': ['ball_tree', 'kd_tree', 'brute'],
            'model__leaf_size': [15, 30],
            'model__p': [1, 2],
        }
    ]
    pipeline = Pipeline(
        steps=(
            ('pca', PCA(num_pca_components)),
            ('model', KNeighborsRegressor()),
        )
    )

    # Only one estimator family is searched, so there is a single pass.
    print(KNeighborsRegressor)
    search = GridSearchCV(
        estimator=pipeline,
        param_grid=search_space,
        scoring='r2',
        verbose=1,
        n_jobs=-1,  # use all available cores
    )
    search.fit(Xs_train, ys_train)

    cv_table = pd.DataFrame.from_dict(search.cv_results_)
    best_five = cv_table.sort_values('rank_test_score').head(5)
    wanted_columns = [
        'param_model',
        'params',
        'param_model__algorithm',
        'mean_test_score',
        'rank_test_score',
    ]
    return best_five[wanted_columns]
# Stylesheet link that pulls the Roboto font family into the page.
_FONT_STYLESHEET = """
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
/>
"""
st.markdown(_FONT_STYLESHEET, unsafe_allow_html=True)

# Introductory text shown in the info banner at the top of the page.
_INTRO_MESSAGE = """
Evolutionary-scale prediction of atomic level protein structure
ESM is a high-capacity Transformer trained with protein sequences \
as input. After training, the secondary and tertiary structure, \
function, homology and other information of the protein are in the feature representation output by the model.\
Check out https://esmatlas.com/ for more information.
We have 120k proteins features stored in our database.
The app uses the [MyScale](MyScale Database) to store and query protein sequence
using vector search.
"""
messages = [_INTRO_MESSAGE]
@st.experimental_singleton(show_spinner=False)
def init_random_query():
    """Create a random query vector and an independent copy of it.

    Returns:
        A pair of equal lists of ``DIMS`` floats drawn by ``np.random.rand``;
        the second list is a separate copy of the first.
    """
    # NOTE(review): DIMS is not defined anywhere in the visible part of this
    # file -- confirm where it is supposed to come from.
    query_vector = np.random.rand(DIMS).tolist()
    return query_vector, list(query_vector)
# ---------------------------------------------------------------------------
# Page body.
#
# The first-run and rerun paths used to be two diverged copies of the same UI.
# The first-run copy passed ``esm_search`` as its own batch converter, read an
# undefined ``result_temp_coords`` and ignored the input in "activity
# prediction". Both paths now share the code below, and the model is stored in
# session state as soon as it is loaded.
# ---------------------------------------------------------------------------

# Sequences filled in by the example buttons.
CAS9_EXAMPLE = 'GSGHMDKKYSIGLAIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRILYLQEIFSNEMAKV'
PETASE_EXAMPLE = 'MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ'

# PDB entries shown next to the search results.
EXAMPLE_PDB_IDS = ['1A2C', '1BML', '1D5M', '1D5X', '1D5Z', '1D6E', '1DEE', '1E9F', '1FC2', '1FCC', '1G4U', '1GZS', '1HE1', '1HEZ', '1HQR', '1HXY', '1IBX', '1JBU', '1JWM', '1JWS']
# Inclusive (low, high) index bounds into EXAMPLE_PDB_IDS, one per result tab.
RESULT_TAB_PDB_RANGES = [(14, 18), (0, 4), (4, 8), (4, 8), (4, 8)]

CONTACT_EXPLANATION = """<span style="word-wrap:break-word;">Contact prediction is based on a logistic regression over the model's attention maps. This methodology is based on ICLR 2021 paper, Transformer protein language models are unsupervised structure learners. (Rao et al. 2020)The MSA Transformer (ESM-MSA-1) takes a multiple sequence alignment (MSA) as input, and uses the tied row self-attention maps in the same way.</span>
"""

PDB_EXPLANATION = """
A PDB ID is a unique 4-character code for each entry in the Protein Data Bank. The first character must be a number between 1 and 9, and the remaining three characters can be letters or numbers.
see https://www.rcsb.org/ for more information.
"""


def _sequence_input():
    """Render the sequence text box and example buttons; return the chosen sequence."""
    sequence = st.text_input('protein sequence', '')
    st.write('Try an example:')
    # Render both buttons before reading them so neither disappears when the
    # other one is clicked.
    use_cas9 = st.button('Cas9 Enzyme')
    use_petase = st.button('PETase')
    if use_cas9:
        return CAS9_EXAMPLE
    if use_petase:
        return PETASE_EXAMPLE
    return sequence


def _render_contact_prediction(model, batch_converter):
    """Render the 'self-contact prediction' mode."""
    sequence = _sequence_input()
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    st.pyplot(perdict_contact_visualization(sequence, model, batch_converter))
    expander = st.expander("See explanation")
    expander.markdown(CONTACT_EXPLANATION, unsafe_allow_html=True)


def _render_database_search(model, batch_converter):
    """Render the 'search the database' mode."""
    sequence = _sequence_input()
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    result_temp_seq = esm_search(model, sequence, batch_converter, top_k=10)
    st.text('search result (top 5): ')

    # Fewer than five hits used to raise IndexError; show what is available.
    top_hits = result_temp_seq[:len(RESULT_TAB_PDB_RANGES)]
    if not top_hits:
        st.write('No matching sequences were found.')
        return

    tabs = st.tabs([str(rank) for rank in range(1, len(top_hits) + 1)])
    for tab, hit, (low, high) in zip(tabs, top_hits, RESULT_TAB_PDB_RANGES):
        with tab:
            st.write(hit)
            # NOTE(review): the structure is a random pick from a fixed list
            # and is not derived from the matched sequence.
            protein = EXAMPLE_PDB_IDS[random.randint(low, high)]
            xyzview = py3Dmol.view(query='pdb:' + protein)
            xyzview.setStyle({'stick': {'color': 'spectrum'}})
            showmol(xyzview, height=500, width=800)


def _render_activity_prediction():
    """Render the 'activity prediction' mode."""
    st.markdown('we predict the biological activity of mutations of a protein, using fixed embeddings from ESM.')
    sequence = _sequence_input()
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    res_knn = KNN_search(sequence)
    st.subheader('KNN predictor result')
    st.markdown("Activity prediction: " + str(res_knn))


def _render_pdb_viewer():
    """Render the 'PDB viewer' mode."""
    id_PDB = st.text_input('enter PDB ID', '')
    residues_marker = st.text_input('residues class', '')
    st.write('Try an example:')
    if st.button('PDB ID: 1A2C / residues class: ALA'):
        id_PDB = '1A2C'
        residues_marker = 'ALA'
    st.subheader('PDB viewer')
    # Only build a viewer once an ID has been entered.
    if id_PDB:
        viewer = render_pdb(id=id_PDB)
        if residues_marker:
            viewer = render_pdb_resn(viewer=viewer, resn_lst=[residues_marker])
        showmol(viewer)
    expander = st.expander("See explanation")
    expander.markdown(PDB_EXPLANATION)


with st.spinner("Connecting DB..."):
    st.session_state.meta, client = init_db()

with st.spinner("Loading Models..."):
    # Load the ESM model once per session and keep it in session state.
    if 'xq' not in st.session_state:
        model, alphabet = init_esm()
        st.session_state['xq'] = model
        st.session_state['batch'] = alphabet.get_batch_converter()
        st.session_state.query_num = 0

model = st.session_state['xq']
batch_converter = st.session_state['batch']

if st.session_state.query_num < len(messages):
    msg = messages[0]
else:
    msg = messages[-1]

with st.container():
    st.title("Evolutionary Scale Modeling")
    # Placeholders reserved at the top of the page; the first one holds the
    # intro banner.
    start = [st.empty() for _ in range(9)]
    start[0].info(msg)
    option = st.selectbox(
        'Application options',
        ('self-contact prediction', 'search the database', 'activity prediction', 'PDB viewer'),
    )
    st.session_state.db_name_ref = 'default.esm_protein'

    if option == 'self-contact prediction':
        _render_contact_prediction(model, batch_converter)
    elif option == 'search the database':
        _render_database_search(model, batch_converter)
    elif option == 'activity prediction':
        _render_activity_prediction()
    elif option == 'PDB viewer':
        _render_pdb_viewer()