Spaces:
Running
on
L4
Running
on
L4
import requests | |
import logging | |
import time | |
import os | |
import tarfile | |
from tqdm import tqdm | |
import random | |
# Module-level logger shared by the MMseqs2 client code in this file.
logger = logging.getLogger(__name__)

# tqdm bar layout: left label, the bar itself, an "n/total" counter,
# then elapsed and estimated remaining time.
TQDM_BAR_FORMAT = (
    "{l_bar}{bar}| {n_fmt}/{total_fmt} "
    "[elapsed: {elapsed} remaining: {remaining}]"
)
""" | |
Copyright notice: Code to run mmseqs2 was borrowed from ColabFold (c) 2021 Sergey Ovchinnikov under MIT License | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
""" | |
def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
                use_templates=False, filter=None, pairing_strategy="greedy",
                host_url="https://api.colabfold.com",
                user_agent="HF Space simonduerr/boltz-1 [email protected]"):
    """Query the ColabFold MMseqs2 API and return MSAs (and optionally templates).

    Each unique sequence is submitted once. The result archive is downloaded
    to ``f"{prefix}_{mode}/out.tar.gz"``, extracted into that directory, and
    the a3m records are regrouped per input sequence.

    Args:
        x: A single sequence (``str``) or a list of sequences. Duplicates are
            submitted once and fanned back out in input order.
        prefix: Path prefix of the working directory. Files are written to
            ``f"{prefix}_{mode}"``, which is created (with parents) if missing.
        use_env: Select an ``env*`` server mode. Also read
            ``bfd.mgnify30.metaeuk30.smag30.a3m`` in addition to ``uniref.a3m``.
        use_filter: Choose the filtered mode (``env``/``all``) rather than
            ``env-nofilter``/``nofilter``.
        use_templates: Also download the template hits listed in ``pdb70.m8``.
        filter: Legacy alias. When not ``None`` it overrides ``use_filter``.
        pairing_strategy: Accepted for interface compatibility; not used here.
        host_url: Base URL of the MMseqs2 server.
        user_agent: Sent as the ``User-Agent`` header. An empty string only
            logs a warning.
            NOTE(review): the default contains "[email protected]", which
            looks like web-scraper email-obfuscation residue rather than a
            real contact address -- confirm the intended value.

    Returns:
        A list with one a3m string per entry of ``x``, in input order.
        When ``use_templates`` is true, returns a tuple
        ``(a3m_lines, template_paths)``. ``template_paths[i]`` is the
        template directory for query ``i``, or ``None`` if it had no hits.

    Raises:
        Exception: The server reported ``ERROR`` or ``MAINTENANCE``.
        Exception: Any non-timeout exception from ``requests`` is re-raised
            once a single request has failed more than 5 times.
    """
    submission_endpoint = "ticket/msa"
    api_error_msg = ('MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. '
                     'If error persists, please try again an hour later.')

    headers = {}
    if user_agent != "":
        headers['User-Agent'] = user_agent
    else:
        logger.warning("No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future.")

    # Non-timeout failures of a single request tolerated before re-raising.
    max_retries = 5

    def _request_with_retries(send, timeout_msg, error_msg):
        """Call ``send()`` until it returns a response.

        Timeouts are retried immediately and indefinitely. Any other exception
        is retried ``max_retries`` times, 5 s apart, and is then re-raised.
        """
        # The counter lives outside the loop so the retry limit is actually reached.
        error_count = 0
        while True:
            try:
                return send()
            except requests.exceptions.Timeout:
                logger.warning(timeout_msg)
            except Exception as e:
                error_count += 1
                if error_count > max_retries:
                    raise
                logger.warning(f"{error_msg} Retrying... ({error_count}/{max_retries})")
                logger.warning(f"Error: {e}")
                time.sleep(5)

    def _parse_json(res):
        """Decode a JSON reply; a non-JSON body becomes ``{"status": "ERROR"}``."""
        try:
            return res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            return {"status": "ERROR"}

    def _extract_tar(tar, dest):
        """Extract every member of the open archive *tar* into *dest*.

        The archives come from a remote server. Where this Python provides the
        ``"data"`` extraction filter, it is used so that members or links
        ending up outside *dest* (and special files) are refused.
        """
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)

    def submit(seqs, mode, N=101):
        """Submit *seqs* as FASTA records numbered from *N*; return the decoded reply."""
        query = "".join(f">{n}\n{seq}\n" for n, seq in enumerate(seqs, start=N))
        # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
        # "good practice to set connect timeouts to slightly larger than a multiple of 3"
        res = _request_with_retries(
            lambda: requests.post(f'{host_url}/{submission_endpoint}',
                                  data={'q': query, 'mode': mode},
                                  timeout=6.02, headers=headers),
            "Timeout while submitting to MSA server. Retrying...",
            "Error while fetching result from MSA server.",
        )
        return _parse_json(res)

    def status(ID):
        """Poll ticket *ID*; return the decoded status reply."""
        res = _request_with_retries(
            lambda: requests.get(f'{host_url}/ticket/{ID}', timeout=6.02, headers=headers),
            "Timeout while fetching status from MSA server. Retrying...",
            "Error while fetching result from MSA server.",
        )
        return _parse_json(res)

    def download(ID, path):
        """Download the result archive of ticket *ID* and write it to *path*."""
        res = _request_with_retries(
            lambda: requests.get(f'{host_url}/result/download/{ID}', timeout=6.02, headers=headers),
            "Timeout while fetching result from MSA server. Retrying...",
            "Error while fetching result from MSA server.",
        )
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # compatibility to old option
    if filter is not None:
        use_filter = filter

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    # define path
    path = f"{prefix}_{mode}"
    os.makedirs(path, exist_ok=True)
    tar_gz_file = f'{path}/out.tar.gz'

    # Deduplicate while keeping first-seen order (dicts preserve insertion
    # order). Ms maps every input sequence to the FASTA number its unique copy
    # is submitted under.
    N = 101
    seqs_unique = list(dict.fromkeys(seqs))
    first_index = {seq: i for i, seq in enumerate(seqs_unique)}
    Ms = [N + first_index[seq] for seq in seqs]

    TIME_ESTIMATE = 150 * len(seqs_unique)
    with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
        while True:
            pbar.set_description("SUBMIT")

            # Resubmit the job until the server accepts it.
            out = submit(seqs_unique, mode, N)
            while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                sleep_time = 5 + random.randint(0, 5)
                logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                time.sleep(sleep_time)
                out = submit(seqs_unique, mode, N)

            if out["status"] == "ERROR":
                raise Exception(api_error_msg)
            if out["status"] == "MAINTENANCE":
                raise Exception('MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

            # Wait for the job to finish; only RUNNING time advances the bar.
            ID, TIME = out["id"], 0
            pbar.set_description(out["status"])
            while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                t = 5 + random.randint(0, 5)
                logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
                time.sleep(t)
                out = status(ID)
                pbar.set_description(out["status"])
                if out["status"] == "RUNNING":
                    TIME += t
                    pbar.update(n=t)

            if out["status"] == "COMPLETE":
                if TIME < TIME_ESTIMATE:
                    pbar.update(n=(TIME_ESTIMATE - TIME))
                break
            if out["status"] == "ERROR":
                raise Exception(api_error_msg)
            # Any other final status (e.g. RATELIMIT while polling): resubmit.

    # Download results. The archive is fetched fresh on every call, so it is
    # always extracted; skipping extraction when a3m files already exist would
    # return stale results from an earlier run that used the same prefix.
    download(ID, tar_gz_file)
    a3m_files = [f"{path}/uniref.a3m"]
    if use_env:
        a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")
    with tarfile.open(tar_gz_file) as tar_gz:
        _extract_tar(tar_gz, path)

    # templates
    if use_templates:
        # pdb70.m8: whitespace-separated hits; column 0 is the query number,
        # column 1 the template id.
        templates = {}
        with open(f"{path}/pdb70.m8", "r") as m8:
            for line in m8:
                p = line.rstrip().split()
                templates.setdefault(int(p[0]), []).append(p[1])

        template_paths = {}
        for k, TMPL in templates.items():
            TMPL_PATH = f"{path}/templates_{k}"
            os.makedirs(TMPL_PATH, exist_ok=True)
            url = f"{host_url}/template/{','.join(TMPL[:20])}"
            response = _request_with_retries(
                lambda url=url: requests.get(url, stream=True, timeout=6.02, headers=headers),
                "Timeout while submitting to template server. Retrying...",
                "Error while fetching result from template server.",
            )
            try:
                with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
                    _extract_tar(tar, TMPL_PATH)
            finally:
                # Release the streamed connection back to the pool.
                response.close()
            cs219_index = f"{TMPL_PATH}/pdb70_cs219.ffindex"
            if not os.path.lexists(cs219_index):  # directory may be reused across runs
                os.symlink("pdb70_a3m.ffindex", cs219_index)
            with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f:
                f.write("")
            template_paths[k] = TMPL_PATH

    # Gather a3m lines per query number. A NUL byte marks the start of a new
    # query's block; the next ">" header after it carries that query number.
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        with open(a3m_file, "r") as handle:
            for line in handle:
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                    a3m_lines.setdefault(M, [])
                a3m_lines[M].append(line)

    # Return results in the caller's original order (duplicates included).
    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
    if use_templates:
        return a3m_lines, [template_paths.get(n) for n in Ms]
    return a3m_lines