"""Streamlit demo for ESM protein language models backed by a MyScale vector DB.

Pages offered by the app:

* self-contact prediction -- contact map from the MSA Transformer,
* search the database     -- nearest-neighbour sequence search,
* activity prediction     -- KNN average over stored activity values,
* PDB viewer              -- 3D rendering of a PDB entry via stmol / py3Dmol.

NOTE(review): this file was reconstructed from a whitespace-collapsed copy;
statement nesting was inferred from syntax and should be confirmed against
the original layout.
"""

# Import order is kept as in the original file (duplicates removed) because
# the star import below may (re)bind names depending on its position.
import streamlit as st
import torch
import esm
import matplotlib.pyplot as plt
from myscaledb import Client
import random
from collections import Counter
from tqdm import tqdm
from statistics import mean
import numpy as np
import pandas as pd
import seaborn as sns
from stmol import *  # provides showmol, render_pdb, render_pdb_resn
import py3Dmol
import scipy
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.svm import SVC, SVR
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression, SGDRegressor
from sklearn.pipeline import Pipeline
from streamlit.components.v1 import html


# --------------------------------------------------------------------------
# Constants
# --------------------------------------------------------------------------

# Example sequences offered as one-click inputs on every sequence page.
CAS9_EXAMPLE = 'GSGHMDKKYSIGLAIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRILYLQEIFSNEMAKV'
PETASE_EXAMPLE = 'MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ'

# PDB entries used to illustrate search hits.
EXAMPLE_PDB_IDS = [
    '1A2C', '1BML', '1D5M', '1D5X', '1D5Z', '1D6E', '1DEE', '1E9F', '1FC2', '1FCC',
    '1G4U', '1GZS', '1HE1', '1HEZ', '1HQR', '1HXY', '1IBX', '1JBU', '1JWM', '1JWS',
]

# Inclusive index ranges into EXAMPLE_PDB_IDS, one per result tab (tabs 1..5).
_TAB_PDB_INDEX_RANGES = ((14, 18), (0, 4), (4, 8), (4, 8), (4, 8))

# Number of search hits shown as tabs.
_MAX_RESULT_TABS = 5

# Length of the random query vector built by init_random_query().
# NOTE(review): the original file referenced DIMS without defining it; 768 is
# assumed from the "_768" suffix of the search table name -- confirm.
DIMS = 768

# Default inputs for train_test_split_PCA().
DEFAULT_FASTA_PATH = '/root/xuying_experiments/esm-main/P62593.fasta'
DEFAULT_EMB_PATH = '/root/xuying_experiments/esm-main/P62593_reprs'

_CONTACT_EXPLANATION = """Contact prediction is based on a logistic regression over the model's attention maps.
This methodology is based on ICLR 2021 paper, Transformer protein language models are unsupervised structure learners. (Rao et al. 2020)
The MSA Transformer (ESM-MSA-1) takes a multiple sequence alignment (MSA) as input, and uses the tied row self-attention maps in the same way.
"""

_PDB_EXPLANATION = """
A PDB ID is a unique 4-character code for each entry in the Protein Data Bank.
The first character must be a number between 1 and 9, and the remaining three characters can be letters or numbers.
see https://www.rcsb.org/ for more information.
"""

# Banner text shown at the top of the app (rendered with st.info).
messages = [
    """
Evolutionary-scale prediction of atomic level protein structure

ESM is a high-capacity Transformer trained with protein sequences \
as input. After training, the secondary and tertiary structure, \
function, homology and other information of the protein are in the feature representation output by the model. \
Check out https://esmatlas.com/ for more information.

We have 120k proteins features stored in our database.

The app uses the [MyScale](MyScale Database) to store and query protein sequence using vector search.
"""
]


# --------------------------------------------------------------------------
# Resources: models and database
# --------------------------------------------------------------------------

@st.experimental_singleton(show_spinner=False)
def init_esm():
    """Load the MSA Transformer used for contact prediction and search.

    Cached with ``st.experimental_singleton`` so the weights are loaded once
    per server process instead of once per script rerun.

    Returns:
        (model, alphabet): the model in eval mode and its alphabet.
    """
    msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    msa_transformer = msa_transformer.eval()
    return msa_transformer, msa_transformer_alphabet


@st.experimental_singleton(show_spinner=False)
def _load_esm2():
    """Load the ESM-2 model used by KNN_search once per server process."""
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    return model, alphabet


def _connect():
    """Open a new database connection using the credentials in st.secrets."""
    return Client(
        url=st.secrets["DB_URL"],
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
    )


@st.experimental_singleton(show_spinner=False)
def init_db():
    """Initialize the database connection.

    Returns:
        meta_field: an empty dict, stored by the app in ``st.session_state.meta``
        client: the connected database client

    Raises:
        ConnectionError: if the server does not report itself alive.
    """
    client = _connect()
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if not client.is_alive():
        raise ConnectionError("database connection is not alive")
    meta_field = {}
    # The original returned the ``Client`` class here instead of the instance.
    return meta_field, client


@st.experimental_singleton(show_spinner=False)
def init_random_query():
    """Return a random query vector of length DIMS and an independent copy."""
    xq = np.random.rand(DIMS).tolist()
    return xq, xq.copy()


# --------------------------------------------------------------------------
# Model inference and plotting
# --------------------------------------------------------------------------

def perdict_contact_visualization(seq, model, batch_converter):
    """Plot the contact map the model predicts for a single sequence.

    Args:
        seq: protein sequence string.
        model: model accepting ``repr_layers`` / ``return_contacts`` and
            returning a dict with a ``"contacts"`` entry.
        batch_converter: callable turning ``[(label, seq)]`` into
            ``(labels, strs, tokens)``.

    Returns:
        The matplotlib figure showing the ``len(seq) x len(seq)`` contact map.
    """
    data = [("protein1", seq)]
    _labels, _strs, batch_tokens = batch_converter(data)

    # Inference only (on CPU); no gradients needed.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12], return_contacts=True)

    # Only one sequence is submitted, so its contact map is entry 0.
    contacts = results["contacts"][0]
    fig, ax = plt.subplots()
    ax.matshow(contacts[: len(seq), : len(seq)])
    fig.suptitle(seq)
    return fig


def visualize_3D_Coordinates(coords):
    """Draw a 3D polyline through ``coords``, coloured by position along the chain.

    Args:
        coords: sequence of points; items 0, 1 and 2 of each point are used
            as x, y and z.

    Returns:
        The matplotlib figure.
    """
    xs = [point[0] for point in coords]
    ys = [point[1] for point in coords]
    zs = [point[2] for point in coords]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title('3D coordinates of $C_{b}$ backbone structure')

    total = len(coords)
    # One segment per consecutive pair so each can carry its own colour.
    for i in range(total - 1):
        ax.plot(
            xs[i:i + 2], ys[i:i + 2], zs[i:i + 2],
            color=plt.cm.viridis(i / total),
            marker='o',
        )
    return fig


def esm_search(model, sequnce, batch_converter, top_k=5):
    """Search the database for sequences whose embedding is close to the query.

    Args:
        model: model returning layer-12 representations.
        sequnce: query protein sequence.
        batch_converter: callable turning ``[(label, seq)]`` into tokens.
        top_k: maximum number of unique sequences to return.

    Returns:
        Up to ``top_k`` unique sequences, in the order the database returned
        them (presumably nearest first -- TODO confirm the ordering guarantee
        of the ``distance`` function).
    """
    data = [("protein1", sequnce)]
    _labels, _strs, batch_tokens = batch_converter(data)

    # Only embeddings are needed here, so contact prediction is not requested.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12])
    token_representations = results["representations"][12]
    # Embedding at index [0][0][0]; the original comments state token 0 is
    # the beginning-of-sequence token.
    token_list = token_representations[0, 0, 0].tolist()

    client = _connect()
    # The vector literal comes from model output, not from user text.
    result = client.fetch(
        "SELECT seq, distance('topK=500')(representations, "
        f"{token_list})as dist FROM default.esm_protein_indexer_768"
    )

    # dict.fromkeys de-duplicates while keeping row order (set() did not).
    unique_seqs = list(dict.fromkeys(row['seq'] for row in result))
    return unique_seqs[:top_k]


def KNN_search(sequence):
    """Predict activity as the mean activity of the 10 nearest stored proteins.

    Args:
        sequence: query protein sequence.

    Returns:
        The arithmetic mean of the ``activity`` values of the returned rows.

    Raises:
        ValueError: if the database returns no rows.
    """
    model, alphabet = _load_esm2()
    batch_converter = alphabet.get_batch_converter()
    data = [("protein1", sequence)]
    _labels, _strs, batch_tokens = batch_converter(data)

    # Only embeddings are needed here, so contact prediction is not requested.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])
    token_representations = results["representations"][33]
    # Embedding of token 0 of the single submitted sequence.
    token_list = token_representations[0, 0].tolist()

    client = _connect()
    result = client.fetch(
        "SELECT activity, distance('topK=10')(representations, "
        f"{token_list})as dist FROM default.esm_protein_indexer"
    )

    activities = [row['activity'] for row in result]
    if not activities:
        raise ValueError("no neighbours returned for the query sequence")
    return sum(activities) / len(activities)


# --------------------------------------------------------------------------
# Offline training helpers (not wired into the UI)
# --------------------------------------------------------------------------

def train_test_split_PCA(dataset, fasta_path=DEFAULT_FASTA_PATH,
                         emb_path=DEFAULT_EMB_PATH, train_size=0.8):
    """Load precomputed embeddings and targets and split them into train/test.

    Args:
        dataset: unused; kept for backward compatibility.
        fasta_path: FASTA file whose headers end in ``|<float target>``.
        emb_path: directory holding one ``<header>.pt`` file per FASTA entry.
        train_size: fraction of samples placed in the training split.

    Returns:
        (Xs_train, Xs_test, ys_train, ys_test)
    """
    ys = []
    Xs = []
    for header, _seq in esm.data.read_fasta(fasta_path):
        ys.append(float(header.split('|')[-1]))
        # NOTE: torch.load unpickles the file -- only load trusted files.
        embs = torch.load(f'{emb_path}/{header}.pt')
        Xs.append(embs['mean_representations'][34])
    Xs = torch.stack(Xs, dim=0).numpy()

    Xs_train, Xs_test, ys_train, ys_test = train_test_split(
        Xs, ys, train_size=train_size, random_state=42)
    return Xs_train, Xs_test, ys_train, ys_test


def PCA_visual(Xs_train, ys_train=None):
    """Scatter-plot the first two principal components of ``Xs_train``.

    Args:
        Xs_train: 2D array of embeddings.
        ys_train: optional per-sample values used to colour the points. The
            original read an undefined global of this name; when omitted the
            points are drawn uncoloured and no colorbar is added.

    Returns:
        The matplotlib figure.
    """
    num_pca_components = 60
    pca = PCA(num_pca_components)
    Xs_train_pca = pca.fit_transform(Xs_train)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_title('Visualize Embeddings')
    sc = ax.scatter(Xs_train_pca[:, 0], Xs_train_pca[:, 1], c=ys_train, marker='.')
    ax.set_xlabel('PCA first principal component')
    ax.set_ylabel('PCA second principal component')
    if ys_train is not None:
        fig.colorbar(sc, ax=ax, label='Variant Effect')
    return fig


def KNN_trainings(Xs_train, Xs_test, ys_train, ys_test):
    """Grid-search a PCA + KNN regressor and report the five best settings.

    Args:
        Xs_train: training features.
        Xs_test: unused; kept for backward compatibility.
        ys_train: training targets.
        ys_test: unused; kept for backward compatibility.

    Returns:
        DataFrame with the five best-ranked parameter settings and the
        columns ``param_model``, ``params``, ``param_model__algorithm``,
        ``mean_test_score`` and ``rank_test_score``.
    """
    num_pca_components = 60
    knn_grid = [
        {
            'model': [KNeighborsRegressor()],
            'model__n_neighbors': [5, 10],
            'model__weights': ['uniform', 'distance'],
            'model__algorithm': ['ball_tree', 'kd_tree', 'brute'],
            'model__leaf_size': [15, 30],
            'model__p': [1, 2],
        }
    ]
    pipe = Pipeline(
        steps=[
            ('pca', PCA(num_pca_components)),
            ('model', KNeighborsRegressor()),
        ]
    )
    grid = GridSearchCV(
        estimator=pipe,
        param_grid=knn_grid,
        scoring='r2',
        verbose=1,
        n_jobs=-1,  # use all available cores
    )
    grid.fit(Xs_train, ys_train)

    cv_results = pd.DataFrame.from_dict(grid.cv_results_)
    best_five = pd.DataFrame(cv_results.sort_values('rank_test_score')[:5])
    return best_five[['param_model', 'params', 'param_model__algorithm',
                      'mean_test_score', 'rank_test_score']]


# --------------------------------------------------------------------------
# UI pages
# --------------------------------------------------------------------------

def _sequence_input(show_hint=True):
    """Render the sequence text box plus example buttons; return the chosen sequence.

    Both buttons are always rendered (the original ``if/elif`` hid the second
    button on the run in which the first was clicked). A clicked example
    button overrides the text box for this run.
    """
    sequence = st.text_input('protein sequence', '')
    if show_hint:
        st.write('Try an example:')
    cas9_clicked = st.button('Cas9 Enzyme')
    petase_clicked = st.button('PETase')
    if cas9_clicked:
        return CAS9_EXAMPLE
    if petase_clicked:
        return PETASE_EXAMPLE
    return sequence


def _show_random_structure(low, high):
    """Render a randomly chosen entry of EXAMPLE_PDB_IDS[low..high] (inclusive).

    NOTE(review): the structure is picked at random and is not derived from
    the search hit shown next to it -- this mirrors the original behaviour.
    """
    protein = EXAMPLE_PDB_IDS[random.randint(low, high)]
    xyzview = py3Dmol.view(query='pdb:' + protein)
    xyzview.setStyle({'stick': {'color': 'spectrum'}})
    showmol(xyzview, height=500, width=800)


def _page_contact_prediction():
    """Page: predict and plot the self-contact map of the entered sequence."""
    sequence = _sequence_input(show_hint=False)
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    st.pyplot(perdict_contact_visualization(
        sequence, st.session_state['xq'], st.session_state['batch']))
    expander = st.expander("See explanation")
    expander.markdown(_CONTACT_EXPLANATION, unsafe_allow_html=True)


def _page_search():
    """Page: vector search for similar sequences, one tab per hit."""
    sequence = _sequence_input()
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    hits = esm_search(
        st.session_state['xq'], sequence, st.session_state['batch'], top_k=10)
    st.text('search result (top 5): ')

    shown = hits[:_MAX_RESULT_TABS]
    if not shown:
        # The original indexed hits[0..4] unconditionally and crashed here.
        st.warning('no matching sequences found')
        return

    tabs = st.tabs([str(rank) for rank in range(1, len(shown) + 1)])
    for tab, hit, (low, high) in zip(tabs, shown, _TAB_PDB_INDEX_RANGES):
        with tab:
            st.write(hit)
            _show_random_structure(low, high)


def _page_activity_prediction():
    """Page: KNN-based activity prediction for the entered sequence."""
    st.markdown('we predict the biological activity of mutations of a protein, using fixed embeddings from ESM.')
    sequence = _sequence_input()
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    res_knn = KNN_search(sequence)
    st.subheader('KNN predictor result')
    st.markdown("Activity prediction: " + str(res_knn))


def _page_pdb_viewer():
    """Page: render a PDB entry, optionally highlighting one residue class."""
    id_PDB = st.text_input('enter PDB ID', '')
    residues_marker = st.text_input('residues class', '')
    st.write('Try an example:')
    if st.button('PDB ID: 1A2C / residues class: ALA'):
        id_PDB = '1A2C'
        residues_marker = 'ALA'

    st.subheader('PDB viewer')
    # Nothing to render until an ID has been entered.
    if id_PDB:
        if residues_marker:
            showmol(render_pdb_resn(viewer=render_pdb(id=id_PDB),
                                    resn_lst=[residues_marker]))
        else:
            showmol(render_pdb(id=id_PDB))

    expander = st.expander("See explanation")
    expander.markdown(_PDB_EXPLANATION)


# Select-box label -> page renderer; dict order is the order of the options.
_PAGES = {
    'self-contact prediction': _page_contact_prediction,
    'search the database': _page_search,
    'activity prediction': _page_activity_prediction,
    'PDB viewer': _page_pdb_viewer,
}


def main():
    """Render the app: connect, load the model once per session, dispatch the page."""
    st.markdown(""" """, unsafe_allow_html=True)

    with st.spinner("Connecting DB..."):
        st.session_state.meta, _client = init_db()

    with st.spinner("Loading Models..."):
        # Store the model immediately so every later rerun and every page
        # reads it from session state (keys 'xq' and 'batch').
        if 'xq' not in st.session_state:
            model, alphabet = init_esm()
            st.session_state['xq'] = model
            st.session_state['batch'] = alphabet.get_batch_converter()
            st.session_state.query_num = 0

    if st.session_state.query_num < len(messages):
        msg = messages[0]
    else:
        msg = messages[-1]

    with st.container():
        st.title("Evolutionary Scale Modeling")
        placeholders = [st.empty() for _ in range(9)]
        placeholders[0].info(msg)
        option = st.selectbox('Application options', tuple(_PAGES))
        st.session_state.db_name_ref = 'default.esm_protein'
        _PAGES[option]()


# `streamlit run` executes the script as __main__; the guard keeps the helper
# functions importable without starting the UI.
if __name__ == "__main__":
    main()