import hashlib
import os
import random
import re
import subprocess
import sys
import time
import warnings
from collections import Counter, OrderedDict, deque
from string import ascii_lowercase, ascii_uppercase

# ---------------------------------------------------------------------------
# One-time environment bootstrap. Everything in this section runs at import
# time and shells out with fixed (non user-controlled) command strings.
# ---------------------------------------------------------------------------
if not os.path.isfile("RF2_apr23.pt"):
    # send param download into background (note the trailing "&") so the
    # installs below can proceed while the weights are being fetched
    os.system(
        "(apt-get install aria2; aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/RF2_apr23.pt) &"
    )

if not os.path.isdir("RoseTTAFold2"):
    print("install RoseTTAFold2")
    os.system("git clone https://github.com/sokrypton/RoseTTAFold2.git")
    print(os.listdir("RoseTTAFold2"))
    os.system(
        "cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install ."
    )
    os.system(
        "wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py"
    )

    # install hhsuite
    print("install hhsuite")
    os.makedirs("hhsuite", exist_ok=True)
    os.system(
        "curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/"
    )
    print(os.listdir("hhsuite"))

# aria2c keeps a "<file>.aria2" control file next to the target while the
# download is in progress; block until it disappears.
if os.path.isfile("RF2_apr23.pt.aria2"):
    print("downloading RoseTTAFold2 params")
    while os.path.isfile("RF2_apr23.pt.aria2"):
        time.sleep(5)

# These must be set before the imports below: `parsers`/`predict` are found
# through the extended sys.path and `api` is the file fetched above.
os.environ["DGLBACKEND"] = "pytorch"
sys.path.append("RoseTTAFold2/network")
if "hhsuite" not in os.environ["PATH"]:
    os.environ["PATH"] += ":hhsuite/bin:hhsuite/scripts"

import matplotlib.pyplot as plt
import numpy as np
import torch
from parsers import parse_a3m
from api import run_mmseqs2
from Bio.PDB import *
from Bio.PDB.PDBExceptions import PDBConstructionWarning
import gradio as gr


def get_hash(x):
    """Return the hexadecimal SHA-1 digest of the string *x*."""
    return hashlib.sha1(x.encode()).hexdigest()


alphabet_list = list(ascii_uppercase + ascii_lowercase)

# Build the predictor only once per interpreter namespace.
if "pred" not in dir():
    from predict import Predictor

    print("compile RoseTTAFold2")
    model_params = "RF2_apr23.pt"
    if torch.cuda.is_available():
        pred = Predictor(model_params, torch.device("cuda:0"))
    else:
        print("WARNING: using CPU")
        pred = Predictor(model_params, torch.device("cpu"))


def get_unique_sequences(seq_list):
    """Return the distinct items of *seq_list*, keeping first-seen order."""
    return list(OrderedDict.fromkeys(seq_list))


def _run_hhfilter(in_path, out_path, seq_id, cov):
    """Filter the alignment *in_path* into *out_path* with ``hhfilter``.

    The command is passed as an argument list (no shell), so characters in
    the paths are never interpreted by a shell.

    Raises:
        subprocess.CalledProcessError: if ``hhfilter`` exits with a non-zero
            status.
        FileNotFoundError: if the ``hhfilter`` executable cannot be found.
    """
    subprocess.run(
        [
            "hhfilter",
            "-i", in_path,
            "-id", str(seq_id),
            "-cov", str(cov),
            "-o", out_path,
        ],
        check=True,
    )


def get_msa(seq, jobname, cov=50, id=90, max_msa=2048, mode="unpaired_paired"):
    """Build an MSA with MMseqs2 + hhfilter and write it to ``{jobname}/msa.a3m``.

    Args:
        seq: a single sequence string or a list of chain sequences. Identical
            chains are collapsed for the search and re-expanded (joined with
            ``/``) in every row that is written.
        jobname: output directory; intermediate files go to ``{jobname}/msa``.
        cov: value passed to ``hhfilter -cov``.
        id: value passed to ``hhfilter -id`` (name kept for compatibility even
            though it shadows the builtin).
        max_msa: upper bound used when adding unpaired rows. Paired rows are
            appended without checking this bound.
        mode: ``"unpaired"``, ``"paired"`` or ``"unpaired_paired"``.

    Raises:
        ValueError: if *mode* is not one of the supported values.
    """
    if mode not in ("unpaired", "paired", "unpaired_paired"):
        raise ValueError(f"unsupported MSA mode: {mode!r}")
    seqs = [seq] if isinstance(seq, str) else seq

    # collapse homooligomeric sequences
    counts = Counter(seqs)
    u_seqs = list(counts)
    u_nums = list(counts.values())

    def expand(columns):
        # expand homooligomeric sequences: repeat each per-chain column as
        # many times as that chain occurs, joined by "/"
        return "/".join("/".join([x] * num) for x, num in zip(columns, u_nums))

    first_seq = "/".join(x for x, num in zip(u_seqs, u_nums) for _ in range(num))
    msa = [first_seq]

    path = os.path.join(jobname, "msa")
    os.makedirs(path, exist_ok=True)

    if mode in ("paired", "unpaired_paired") and len(u_seqs) > 1:
        print("getting paired MSA")
        out_paired = run_mmseqs2(u_seqs, f"{path}/", use_pairing=True)
        # sequences[k] collects, for the k-th entry of every per-chain result,
        # the sequence lines in chain order.
        sequences = []
        for a3m_lines in out_paired:
            n = -1
            for line in a3m_lines.split("\n"):
                if not line:
                    continue
                if line.startswith(">"):
                    n += 1
                    if len(sequences) < n + 1:
                        sequences.append([])
                else:
                    sequences[n].append(line)

        # filter MSA: the chains are concatenated and the row index is stored
        # in the header (">n<k>") so surviving rows can be mapped back.
        with open(f"{path}/paired_in.a3m", "w") as handle:
            for n, sequence in enumerate(sequences):
                handle.write(f">n{n}\n{''.join(sequence)}\n")
        _run_hhfilter(f"{path}/paired_in.a3m", f"{path}/paired_out.a3m", id, cov)
        with open(f"{path}/paired_out.a3m", "r") as handle:
            for line in handle:
                if line.startswith(">"):
                    msa.append(expand(sequences[int(line[2:])]))

    if len(msa) < max_msa and (
        mode in ("unpaired", "unpaired_paired") or len(u_seqs) == 1
    ):
        print("getting unpaired MSA")
        out = run_mmseqs2(u_seqs, f"{path}/")
        gaps = ["-" * len(s) for s in u_seqs]
        queues = []
        for n, a3m_lines in enumerate(out):
            with open(f"{path}/in_{n}.a3m", "w") as handle:
                handle.write(a3m_lines)
            # filter
            _run_hhfilter(f"{path}/in_{n}.a3m", f"{path}/out_{n}.a3m", id, cov)
            rows = deque()
            with open(f"{path}/out_{n}.a3m", "r") as handle:
                for line in handle:
                    if line.startswith(">"):
                        continue
                    # the hit fills chain n; every other chain is all gaps
                    xs = list(gaps)
                    xs[n] = line.rstrip()
                    rows.append(expand(xs))
            queues.append(rows)

        # Interleave the per-chain alignments round-robin until max_msa rows
        # are collected or every queue is exhausted.
        remaining = sum(len(rows) for rows in queues)
        while len(msa) < max_msa and remaining > 0:
            for rows in queues:
                if rows:
                    msa.append(rows.popleft())
                    remaining -= 1
                if len(msa) == max_msa:
                    break

    with open(f"{jobname}/msa.a3m", "w") as handle:
        for n, sequence in enumerate(msa):
            handle.write(f">n{n}\n{sequence}\n")


def add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname):
    """Convert the best model to mmCIF and append pLDDT QA-metric records.

    Reads ``{jobname}/rf2_seed{best_seed}_00_pred.pdb``, saves it as
    ``..._pred.cif`` next to it and appends ``_ma_qa_metric*`` categories
    holding the global and per-residue pLDDT.

    Args:
        best_plddts: per-residue scores, indexed by the position of the
            residue in the model (counted across all chains). Written
            multiplied by 100.
        best_plddt: global score, written as given with three decimals.
        best_seed: seed that identifies the model files.
        jobname: directory containing the model files.

    Returns:
        None. The result is the file written to disk.
    """
    pdb_parser = PDBParser()
    warnings.filterwarnings("ignore", category=PDBConstructionWarning)
    structure = pdb_parser.get_structure(
        "pdb", f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
    )
    io = MMCIFIO()
    io.set_structure(structure)
    io.save(f"{jobname}/rf2_seed{best_seed}_00_pred.cif")

    # "#" starts a comment that runs to the end of the line in CIF, so each
    # item has to sit on its own line.
    header = f"""#
loop_
_ma_qa_metric.id
_ma_qa_metric.mode
_ma_qa_metric.name
_ma_qa_metric.software_group_id
_ma_qa_metric.type
1 global pLDDT 1 pLDDT
2 local pLDDT 1 pLDDT
#
_ma_qa_metric_global.metric_id 1
_ma_qa_metric_global.metric_value {best_plddt:.3f}
_ma_qa_metric_global.model_id 1
_ma_qa_metric_global.ordinal_id 1
#
loop_
_ma_qa_metric_local.label_asym_id
_ma_qa_metric_local.label_comp_id
_ma_qa_metric_local.label_seq_id
_ma_qa_metric_local.metric_id
_ma_qa_metric_local.metric_value
_ma_qa_metric_local.model_id
_ma_qa_metric_local.ordinal_id"""

    records = [header]
    # The score array covers the whole model, so the index must keep running
    # across chains instead of restarting at 0 for every chain.
    # NOTE(review): assumes best_plddts has one entry per residue of the
    # model in file order -- confirm against the predictor's npz output.
    position = 0
    for chain in structure[0]:
        for residue in chain:
            records.append(
                f"{chain.id} {residue.resname} {residue.id[1]} 2 "
                f"{best_plddts[position]*100:.2f} 1 {residue.id[1]}"
            )
            position += 1
    records.append("#")

    with open(f"{jobname}/rf2_seed{best_seed}_00_pred.cif", "a") as f:
        f.write("\n".join(records))


def predict(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
    mode="web",
):
    """Run RoseTTAFold2 for *sequence* and return the path of the best model.

    The input is normalised (upper-cased, ``/`` treated as ``:``, everything
    except ``A-Z`` and ``:`` dropped), replicated according to the symmetry,
    an MSA is prepared, and one model per seed in
    ``[random_seed, random_seed + num_models)`` is predicted. The model with
    the highest mean ``lddt`` is kept.

    Args:
        sequence: chain sequences separated by ``:`` or ``/``.
        jobname: prefix of the output directory; the symmetry string and a
            short hash of the final sequence are appended.
            NOTE(review): the value is used as a filesystem path as-is.
        sym: ``"X"``, ``"C"``, ``"D"``, ``"T"``, ``"O"`` or ``"I"``.
        order: symmetry order (ignored for ``T``/``O``/``I``).
        msa_concat_mode: forwarded to the predictor.
        msa_method: ``"mmseqs2"`` or ``"single_sequence"``.
        pair_mode: MSA pairing mode forwarded to ``get_msa``.
        collapse_identical: drop duplicate chains before replication.
        num_recycles: forwarded to the predictor as ``n_recycles``.
        use_mlm: when true the predictor gets ``msa_mask=0.15``, else ``0``.
        use_dropout: forwarded to the predictor as ``is_training``.
        max_msa: forwarded as ``nseqs``; ``nseqs_full`` is eight times this.
        random_seed: first seed.
        num_models: number of consecutive seeds to run.
        mode: ``"web"`` returns a ``.cif`` with pLDDT records appended, any
            other value returns the ``.pdb``.

    Returns:
        Path to the best model file.

    Raises:
        gr.Error: for over-long input on the hosted app, an empty sequence,
            an unknown ``sym``/``msa_method`` or ``num_models < 1``.
    """
    if os.path.exists("/home/user/app"):  # crude check if on spaces
        if len(sequence) > 600:
            raise gr.Error(
                f"Your sequence is too long ({len(sequence)}). "
                "Please use the full version of RoseTTAfold2 directly from GitHub."
            )
    # The UI delivers these as strings / floats.
    random_seed = int(random_seed)
    num_models = int(num_models)
    max_msa = int(max_msa)
    num_recycles = int(num_recycles)
    order = int(order)
    if num_models < 1:
        raise gr.Error("num_models must be at least 1.")

    max_extra_msa = max_msa * 8
    print("sequence", sequence)
    sequence = re.sub(r"[^A-Z:]", "", sequence.replace("/", ":").upper())
    sequence = re.sub(r":+", ":", sequence).strip(":")
    print("sequence", sequence)
    if not sequence:
        raise gr.Error("No valid residues found in the input sequence.")

    if sym in ("X", "C"):
        copies = order
    elif sym == "D":
        copies = order * 2
    elif sym in ("T", "O", "I"):
        copies = {"T": 12, "O": 24, "I": 60}[sym]
        order = ""
    else:
        raise gr.Error(f"Unknown symmetry {sym!r}; use one of X, C, D, T, O, I.")
    symm = sym + str(order)

    sequences = sequence.split(":")
    if collapse_identical:
        u_sequences = get_unique_sequences(sequences)
    else:
        u_sequences = sequences
    sequences = u_sequences * copies
    lengths = [len(s) for s in sequences]

    # TODO
    subcrop = 1000 if sum(lengths) > 1400 else -1

    sequence = "/".join(sequences)
    jobname = jobname + "_" + symm + "_" + get_hash(sequence)[:5]

    print(f"jobname: {jobname}")
    print(f"lengths: {lengths}")
    print("final_sequence", u_sequences)
    os.makedirs(jobname, exist_ok=True)

    if msa_method == "mmseqs2":
        get_msa(u_sequences, jobname, mode=pair_mode, max_msa=max_extra_msa)
    elif msa_method == "single_sequence":
        u_sequence = "/".join(u_sequences)
        with open(f"{jobname}/msa.a3m", "w") as a3m:
            a3m.write(f">{jobname}\n{u_sequence}\n")
    else:
        # "custom_a3m" (user-uploaded alignment) is not supported here.
        raise gr.Error(f"Unsupported msa_method {msa_method!r}.")

    mlm = 0.15 if use_mlm else 0
    best_plddt = None
    best_plddts = None
    best_seed = None
    for seed in range(random_seed, random_seed + num_models):
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        npz = f"{jobname}/rf2_seed{seed}_00.npz"
        print("MLM", mlm, use_mlm)
        pred.predict(
            inputs=[f"{jobname}/msa.a3m"],
            out_prefix=f"{jobname}/rf2_seed{seed}",
            symm=symm,
            ffdb=None,  # TODO (templates)
            n_recycles=num_recycles,
            msa_mask=mlm,
            msa_concat_mode=msa_concat_mode,
            nseqs=max_msa,
            nseqs_full=max_extra_msa,
            subcrop=subcrop,
            is_training=use_dropout,
        )
        # Read the scores once and close the archive again.
        with np.load(npz) as data:
            plddts = data["lddt"]
        plddt = plddts.mean()
        if best_plddt is None or plddt > best_plddt:
            best_plddt = plddt
            best_plddts = plddts
            best_seed = seed

    if mode == "web":
        # Mol* only displays AlphaFold plDDT if they are in a cif.
        add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname)
        return f"{jobname}/rf2_seed{best_seed}_00_pred.cif"
    # for api just return a pdb file
    return f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"


def predict_api(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
):
    """Run ``predict`` in ``"api"`` mode and return the PDB file's text.

    Takes the same arguments as ``predict`` (without ``mode``).
    """
    filename = predict(
        sequence,
        jobname,
        sym,
        order,
        msa_concat_mode,
        msa_method,
        pair_mode,
        collapse_identical,
        num_recycles,
        use_mlm,
        use_dropout,
        max_msa,
        random_seed,
        num_models,
        mode="api",
    )
    with open(filename) as fp:
        return fp.read()


def molecule(input_pdb, public_link):
    """Print *input_pdb* and the public file URL derived from *public_link*.

    NOTE(review): the remainder of this function looks truncated in the
    visible source -- `link` and `x` are assigned but never used and nothing
    is returned. The statements below are kept exactly as found; restore the
    original body from version control before relying on this function.
    """
    print(input_pdb)
    print(public_link + "/file=" + input_pdb)
    link = public_link + "/file=" + input_pdb
    x = (
        """
os.system('wget https://huggingface.co/spaces/simonduerr/rosettafold2/raw/main/rosettafold_pymol.py')
run rosettafold_pymol.py
rosettafold2 sequence, jobname, [sym, order, msa_concat_mode, msa_method, pair_mode, collapse_identical, num_recycles, use_mlm, use_dropout, max_msa, random_seed, num_models]
color_plddt jobname
"""
    )
# --- Gradio widgets and event wiring ----------------------------------------
# NOTE(review): `predict_web` and `rosettafold` are used below but are not
# defined in the visible part of this file (presumably a
# `with gr.Blocks() as rosettafold:` header and a `predict_web` wrapper were
# lost) -- confirm against the original source.

# Primary inputs.
sequence = gr.Textbox(
    label="sequence",
    value="PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK",
)
jobname = gr.Textbox(label="jobname", value="test")

# Advanced options, collapsed by default.
# NOTE(review): indentation was lost in this file; the extent of the
# accordion body (all option widgets, buttons excluded) is inferred.
with gr.Accordion("Additional settings", open=False):
    sym = gr.Textbox(label="sym", value="X")
    order = gr.Slider(label="order", value=1, step=1, minimum=1, maximum=12)
    msa_concat_mode = gr.Dropdown(
        label="msa_concat_mode",
        value="default",
        choices=["diag", "repeat", "default"],
    )
    # dont allow custom a3m for now , "custom_a3m"
    msa_method = gr.Dropdown(
        label="msa_method",
        value="single_sequence",
        choices=["mmseqs2", "single_sequence"],
    )
    pair_mode = gr.Dropdown(
        label="pair_mode",
        value="unpaired_paired",
        choices=["unpaired_paired", "paired", "unpaired"],
    )
    num_recycles = gr.Dropdown(
        label="num_recycles",
        value="6",
        choices=["0", "1", "3", "6", "12", "24"],
    )
    use_mlm = gr.Checkbox(label="use_mlm", value=False)
    use_dropout = gr.Checkbox(label="use_dropout", value=False)
    collapse_identical = gr.Checkbox(label="collapse_identical", value=False)
    max_msa = gr.Dropdown(
        label="max_msa",
        value="16",
        choices=["16", "32", "64", "128", "256", "512"],
    )
    random_seed = gr.Textbox(label="random_seed", value=0)
    num_models = gr.Dropdown(
        label="num_models",
        value="1",
        choices=["1", "2", "4", "8", "16", "32"],
    )

# The hidden button only exists to expose the "rosettafold2" API endpoint;
# the visible one drives the web view.
btn = gr.Button("Run", visible=False)
btn_web = gr.Button("Run")
output_plain = gr.HTML()
output = gr.HTML()

# Both handlers take the same widgets, in the positional order of the
# handler's parameters.
prediction_inputs = [
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
]

btn.click(
    fn=predict_api,
    inputs=prediction_inputs,
    outputs=output_plain,
    api_name="rosettafold2",
)
btn_web.click(
    fn=predict_web,
    inputs=prediction_inputs,
    outputs=output,
)

rosettafold.launch()