Spaces:
Runtime error
Runtime error
| import sys | |
| sys.path.append('./diffab-repo') | |
| import os | |
| import shutil | |
| import pandas as pd | |
| import yaml | |
| import subprocess | |
| import streamlit as st | |
| import stmol | |
| import py3Dmol | |
| import tempfile | |
| import re | |
| import abnumber | |
| import gzip | |
| import tarfile | |
| import torch | |
| from tqdm.auto import tqdm | |
| from Bio import PDB | |
| from collections import OrderedDict | |
| from diffab.tools.renumber import renumber as renumber_antibody | |
| from diffab.tools.renumber.run import ( | |
| biopython_chain_to_sequence, | |
| assign_number_to_sequence, | |
| ) | |
# Option key shown in the CDR multiselect -> short CDR label.
# The labels double as the CDR column names of the results table.
CDR_OPTIONS = OrderedDict([
    ('H_CDR1', 'H1'),
    ('H_CDR2', 'H2'),
    ('H_CDR3', 'H3'),
    ('L_CDR1', 'L1'),
    ('L_CDR2', 'L2'),
    ('L_CDR3', 'L3'),
])
# Design-mode key -> human-readable label shown in the "Mode" radio button.
DESIGN_MODES = OrderedDict([
    ('denovo', 'De novo design'),
    ('denovo_dock', 'De novo design (with HDOCK)'),
    ('opt', 'Optimization'),
    ('fixbb', 'Fix-backbone'),
])
# Design-mode key -> YAML template loaded by get_config().
# Both de novo modes (with and without docking) share the multi-CDR template.
MODE_CONFIG = dict(
    denovo='./configs/test/codesign_multicdrs.yml',
    denovo_dock='./configs/test/codesign_multicdrs.yml',
    opt='./configs/test/abopt_singlecdr.yml',
    fixbb='./configs/test/fixbb.yml',
)
GPU_AVAILABLE = torch.cuda.is_available()

# Default slider value for "Number of samples": 5 with a GPU, 1 on CPU
# (presumably because CPU sampling is slow — confirm).
if GPU_AVAILABLE:
    DEFAULT_NUM_SAMPLES = 5
else:
    DEFAULT_NUM_SAMPLES = 1

# Default slider value for "Number of docking poses".
DEFAULT_NUM_DOCKS = 3
def dict_to_func(d):
    """Return a one-argument callable that looks its argument up in *d*.

    Used as a Streamlit ``format_func`` to show a label for each option key.
    The lookup happens at call time, so later changes to *d* are visible, and
    a missing key raises ``KeyError`` exactly like ``d[key]``.
    """
    return lambda key: d[key]
def get_config(save_dir, mode, cdrs, num_samples=5, optimization_step=4):
    """Build the sampling config for *mode* and save it to ``save_dir/design.yml``.

    The YAML template for *mode* (from ``MODE_CONFIG``) is loaded. Its
    ``sampling`` section is then overridden:

    - ``cdrs`` is set to *cdrs*.
    - ``num_samples`` is set to *num_samples*.
    - ``optimize_steps`` is set to a single-element list ``[optimization_step]``.

    Returns:
        ``(cfg, save_path)``: the modified config dict and the path it was
        written to.
    """
    template_path = MODE_CONFIG[mode]
    with open(template_path, 'r') as template:
        cfg = yaml.safe_load(template)

    sampling = cfg['sampling']
    sampling.update({
        'cdrs': cdrs,
        'num_samples': num_samples,
        'optimize_steps': [optimization_step],
    })

    save_path = os.path.join(save_dir, 'design.yml')
    with open(save_path, 'w') as out:
        yaml.dump(cfg, out)
    return cfg, save_path
def run_design(pdb_path, config_path, output_dir, docking, display_widget, num_docks=DEFAULT_NUM_DOCKS):
    """Run the DiffAb design script in a subprocess, streaming its log.

    Args:
        pdb_path: Input structure. With *docking* it is passed to
            ``design_dock.py`` as ``--antigen``; otherwise it is passed to
            ``design_pdb.py`` as the positional input.
        config_path: Sampling config YAML (as written by ``get_config``).
        output_dir: Passed as ``--out_root``. Any existing
            ``<output_dir>/design`` directory is deleted before the run.
        docking: Use ``design_dock.py`` (with ``--num_docks``) instead of
            ``design_pdb.py``.
        display_widget: Object whose ``.code(text)`` is called after each line
            of output, with the last 10 lines of the combined stdout/stderr.
        num_docks: Number of docking poses (only used when *docking*).

    Returns:
        The subprocess exit code (0 on success). The original returned
        ``None``, so callers that ignore the result are unaffected.
    """
    if docking:
        cmd = [
            'python', 'design_dock.py',
            '--antigen', str(pdb_path),
            '--config', str(config_path),
            '--num_docks', str(num_docks),
        ]
    else:
        cmd = ['python', 'design_pdb.py', str(pdb_path), '--config', str(config_path)]
    cmd += [
        '--batch_size', '1',
        '--out_root', str(output_dir),
        '--device', 'cuda' if GPU_AVAILABLE else 'cpu',
    ]

    result_dir = os.path.join(output_dir, 'design')
    if os.path.exists(result_dir):
        shutil.rmtree(result_dir)

    # Only the last 10 lines are ever displayed, so keep just those instead of
    # re-splitting an ever-growing buffer on every line.
    tail = []
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact. Binary pipes cannot be line-buffered, so bufsize
    # is left at its default; readline() below still yields one line at a time.
    with subprocess.Popen(
        cmd,
        env=os.environ.copy(),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=os.getcwd(),
    ) as proc:
        try:
            for raw_line in iter(proc.stdout.readline, b''):
                # splitlines() also breaks on '\r', so carriage-return progress
                # updates show up as separate lines. errors='replace' avoids
                # crashing on non-UTF-8 output.
                tail.extend(raw_line.decode(errors='replace').splitlines())
                del tail[:-10]
                display_widget.code('\n'.join(tail))
        except BaseException:
            # If reading/displaying fails or the script run is interrupted,
            # don't leave the design job running with nobody draining its pipe.
            proc.kill()
            raise
    # Popen.__exit__ has closed stdout and waited, so returncode is set.
    return proc.returncode
def renumber_antibody_cached(in_pdb, out_pdb, file_id):
    """Renumber the antibody in *in_pdb*, writing the result to *out_pdb*.

    Delegates to ``renumber_antibody(in_pdb, out_pdb, return_other_chains=True)``
    and returns its result unchanged. ``main`` unpacks that result as
    ``(H_chains, L_chains, Ag_chains)``.

    NOTE(review): despite the name, no caching is applied here and *file_id*
    is unused. It is presumably meant as a cache key for a caching decorator;
    confirm.
    """
    renumbered = renumber_antibody(
        in_pdb,
        out_pdb,
        return_other_chains=True,
    )
    return renumbered
def _find_design_pdbs(result_dir):
    """Return sorted ``(group_name, file_name, path)`` for every ``NNNN.pdb`` under *result_dir*.

    ``group_name`` is the basename of the directory containing the file.
    Sorting gives a stable order, since ``os.walk`` order is not guaranteed.
    """
    pattern = re.compile(r'^\d\d\d\d\.pdb$')
    found = []
    for root, _dirs, files in os.walk(result_dir):
        gname = os.path.basename(root)
        for fname in files:
            if pattern.match(fname):
                found.append((gname, fname, os.path.join(root, fname)))
    return sorted(found)


def _cdr_sequences(model):
    """Return CDR sequences ``{'H1': ..., ..., 'L3': ...}`` for the chains of a Bio.PDB model.

    Heavy chains fill H1-H3; light chains (type 'L' or 'K') fill L1-L3.
    Entries stay ``None`` when no chain of that type is numbered successfully.
    """
    cdrs = dict.fromkeys(('H1', 'H2', 'H3', 'L1', 'L2', 'L3'))
    for chain in model:
        try:
            seq, _reslist = biopython_chain_to_sequence(chain)
            _numbers, abchain = assign_number_to_sequence(seq)
            if abchain.chain_type == 'H':
                cdrs['H1'] = abchain.cdr1_seq
                cdrs['H2'] = abchain.cdr2_seq
                cdrs['H3'] = abchain.cdr3_seq
            elif abchain.chain_type in ('L', 'K'):
                cdrs['L1'] = abchain.cdr1_seq
                cdrs['L2'] = abchain.cdr2_seq
                cdrs['L3'] = abchain.cdr3_seq
        except abnumber.ChainParseError:
            # Best effort: chains that cannot be numbered (presumably the
            # antigen) are skipped.
            pass
    return cdrs


def _write_tarball(tar_path, records):
    """Pack each record's ``fpath`` into a gzip tarball at *tar_path*, stored under its ``name``."""
    with tarfile.open(tar_path, 'w:gz') as tar:
        for record in records:
            # tar.add opens and closes the file itself (no leaked handles).
            tar.add(record['fpath'], arcname=record['name'])


def gather_results(result_dir):
    """Collect the designed structures under *result_dir* and bundle them.

    Each ``NNNN.pdb`` file found recursively becomes one record named
    ``<parent-dir>_<file>``, holding its CDR sequences (``H1``..``L3``). All
    files are also packed into ``<result_dir>/generated.tar.gz`` under those
    names.

    Returns:
        ``(records, fpath_to_name)``:

        - ``records`` is a DataFrame with columns ``name``, ``H1``..``L3``,
          ``gname``, ``fname`` and ``fpath``. The columns are present even
          when no file is found.
        - ``fpath_to_name`` maps each file path to its record name.
    """
    parser = PDB.PDBParser(QUIET=True)
    records = []
    fpath_to_name = {}
    for gname, fname, fpath in tqdm(_find_design_pdbs(result_dir)):
        name = f"{gname}_{fname}"
        model = parser.get_structure(name, fpath)[0]
        records.append({
            'name': name,
            **_cdr_sequences(model),
            'gname': gname, 'fname': fname, 'fpath': fpath,
        })
        fpath_to_name[fpath] = name

    # The design directory may not exist (e.g. if the design run failed);
    # create it so the (possibly empty) archive can still be written.
    os.makedirs(result_dir, exist_ok=True)
    _write_tarball(os.path.join(result_dir, 'generated.tar.gz'), records)

    columns = ['name', 'H1', 'H2', 'H3', 'L1', 'L2', 'L3', 'gname', 'fname', 'fpath']
    return pd.DataFrame(records, columns=columns), fpath_to_name
def _design_choices(H_chain, L_chain):
    """Return ``(cdr_options, cdr_default, mode_options)`` for the detected antibody chains.

    - Antigen only: all six CDRs, heavy CDRs preselected, docking-based de
      novo design only.
    - Heavy only or light only: that chain's three CDRs.
    - Heavy and light: all six CDRs with the heavy CDRs preselected.
    """
    heavy = ['H_CDR1', 'H_CDR2', 'H_CDR3']
    light = ['L_CDR1', 'L_CDR2', 'L_CDR3']
    antibody_modes = ['denovo', 'opt', 'fixbb']
    if H_chain is None and L_chain is None:
        return heavy + light, list(heavy), ['denovo_dock']
    if L_chain is None:
        return list(heavy), list(heavy), antibody_modes
    if H_chain is None:
        return list(light), list(light), antibody_modes
    return heavy + light, list(heavy), antibody_modes


def main():
    """Streamlit entry point for the DiffAb demo; runs top to bottom on every rerun.

    1. Upload a PDB (antigen alone or antibody-antigen complex). It is saved
       in a per-session temp directory and renumbered.
    2. Choose the design mode, CDRs and sample counts in a form.
    3. On submit, write the config and run the design subprocess once.
    4. Show the result table, a download button and a 3D viewer.
    """
    # Called first: Streamlit documents set_page_config as needing to come
    # before other st.* calls.
    st.set_page_config(layout="wide")

    # Temporary workspace directory, reused across reruns of this session.
    if 'tempdir_path' not in st.session_state:
        st.session_state.tempdir_path = tempfile.mkdtemp(prefix='streamlit')
    tempdir_path = st.session_state.tempdir_path

    st.markdown(
        "# DiffAb \n\n"
        "Antigen-Specific Antibody Design and Optimization with Diffusion-Based Generative Models for Protein Structures (NeurIPS 2022) \n\n"
        "[[Paper](https://www.biorxiv.org/content/10.1101/2022.07.10.499510.abstract)] "
        "[[Code](https://github.com/luost26/diffab)]"
    )
    left_col, right_col = st.columns(2)

    # Step 1: Upload a PDB, or offer the example files for download.
    with left_col:
        uploaded_file = st.file_uploader(
            'Antigen structure or antibody-antigen complex',
        )
        if uploaded_file is None:
            with st.expander('Download examples', expanded=True):
                with open('./data/examples/7DK2_AB_C.pdb', 'r') as f:
                    st.download_button(
                        'RBD + Antibody Complex',
                        data=f,
                        file_name='RBD_AbAg.pdb',
                    )
                with open('./data/examples/Omicron_RBD.pdb', 'r') as f:
                    st.download_button(
                        'RBD Antigen Only',
                        data=f,
                        file_name='RBD_AgOnly.pdb',
                    )
                st.text('Please upload the downloaded PDB file to run the demo.')

    if 'submit' not in st.session_state:
        st.session_state.submit = False
    if 'done' not in st.session_state:
        st.session_state.done = False

    if uploaded_file is not None:
        # Step 1.2: Save and renumber the uploaded PDB.
        pdb_path = os.path.join(tempdir_path, 'structure.pdb')
        renum_path = os.path.join(tempdir_path, 'structure_renumber.pdb')
        # Write the raw bytes so non-UTF-8 uploads don't raise UnicodeDecodeError.
        with open(pdb_path, 'wb') as f:
            f.write(uploaded_file.getvalue())
        # The upload's identifier attribute is `file_id` in some Streamlit
        # versions and `id` in others; accept either.
        upload_id = getattr(uploaded_file, 'file_id', getattr(uploaded_file, 'id', None))
        H_chains, L_chains, Ag_chains = renumber_antibody_cached(
            in_pdb=pdb_path,
            out_pdb=renum_path,
            file_id=upload_id,
        )
        H_chain = H_chains[0] if H_chains else None
        L_chain = L_chains[0] if L_chains else None
        # Without any antibody chain, the antibody must be docked onto the antigen.
        docking = H_chain is None and L_chain is None

        # Step 2: Design options.
        with left_col:
            st.dataframe(pd.DataFrame({
                'Heavy': {'Chain': H_chain},
                'Light': {'Chain': L_chain},
                'Antigen': {'Chain': ','.join(Ag_chains)},
            }), use_container_width=True)
            cdr_options, cdr_default, mode_options = _design_choices(H_chain, L_chain)
            with st.form('design_form'):
                design_mode = st.radio(
                    'Mode',
                    mode_options,
                    format_func=dict_to_func(DESIGN_MODES),
                )
                cdr_choices = st.multiselect(
                    'CDRs',
                    cdr_options,
                    default=cdr_default,
                    format_func=dict_to_func(CDR_OPTIONS),
                )
                if docking:
                    num_docks = st.slider(
                        'Number of docking poses',
                        min_value=1, max_value=10, value=DEFAULT_NUM_DOCKS,
                    )
                else:
                    num_docks = 0
                num_designs = st.slider(
                    'Number of samples',
                    min_value=1, max_value=10, value=DEFAULT_NUM_SAMPLES,
                )
                submit = st.form_submit_button('Run')
        # `submit` is True only on the rerun where the button was clicked, so
        # remember it in session state. A new click forces a fresh design run.
        st.session_state.submit = st.session_state.submit or submit
        if submit:
            st.session_state.done = False

    # The remaining steps need both an uploaded file and a submission. Checking
    # the upload here also keeps the result display from running without the
    # placeholders below.
    if uploaded_file is None or not st.session_state.submit:
        return

    # Step 3: Prepare the configuration and run the design (once per submission).
    _, config_path = get_config(
        save_dir=tempdir_path,
        mode=design_mode,
        cdrs=cdr_choices,
        num_samples=num_designs,
    )
    result_dir = os.path.join(tempdir_path, 'design')
    with right_col:
        result_molecule_display = st.empty()
        result_select_widget = st.empty()
        result_table_display = st.empty()
        result_download_btn = st.empty()
        output_display = st.empty()
        if not st.session_state.done:
            returncode = run_design(
                pdb_path=renum_path,
                config_path=config_path,
                output_dir=tempdir_path,
                docking=docking,
                display_widget=output_display,
                num_docks=num_docks,
            )
            # A falsy value means success (0), or a run_design that reports no code.
            if returncode:
                st.error(f'The design script exited with code {returncode}.')
            st.session_state.done = True
            # Cleared first so a failure while gathering can't leave stale
            # results from an earlier submission on display.
            st.session_state.results = None
            st.session_state.results = gather_results(result_dir)

    # Step 4: Show the results.
    if not st.session_state.done:
        return
    results = st.session_state.results
    if results is None or results[0].empty:
        result_table_display.warning('No designed structures were found.')
        return
    df_results, fpath_to_name = results
    df_cols = ['name'] + list(CDR_OPTIONS.values())
    result_table_display.dataframe(df_results[df_cols], use_container_width=True)
    display_pdb_path = result_select_widget.selectbox(
        label="Visualize",
        options=df_results['fpath'],
        format_func=dict_to_func(fpath_to_name),
    )
    with open(os.path.join(result_dir, 'generated.tar.gz'), 'rb') as f:
        result_download_btn.download_button(
            label="Download PDBs",
            data=f,
            file_name="generated.tar.gz",
        )
    # Fall back to the first design if nothing is selected or the selected
    # file no longer exists.
    if display_pdb_path is None or not os.path.exists(display_pdb_path):
        display_pdb_path = df_results['fpath'].iloc[0]
    with open(display_pdb_path, 'r') as f:
        pdb_str = f.read()
    xyzview = py3Dmol.view(width=380, height=380)
    xyzview.addModelsAsFrames(pdb_str)
    xyzview.setStyle({'cartoon': {'color': 'spectrum'}})
    xyzview.zoomTo()
    with result_molecule_display:
        stmol.showmol(xyzview, width=380, height=380)


if __name__ == '__main__':
    main()