Spaces:
Runtime error
Runtime error
Jacob Silterra
commited on
Commit
·
6daa4cc
1
Parent(s):
74830dc
Have gradio app call inference.py
Browse files- Dockerfile +1 -1
- main.py +33 -26
- run_utils.py +85 -0
Dockerfile
CHANGED
|
@@ -3,7 +3,7 @@ FROM silterra/diffdock-pocket-dev
|
|
| 3 |
USER $APPUSER
|
| 4 |
WORKDIR $HOME/app
|
| 5 |
|
| 6 |
-
COPY --chown=$APPUSER . $HOME/app
|
| 7 |
|
| 8 |
# Expose port for web service
|
| 9 |
ENV PORT=7860
|
|
|
|
| 3 |
USER $APPUSER
|
| 4 |
WORKDIR $HOME/app
|
| 5 |
|
| 6 |
+
COPY --chown=$APPUSER: . $HOME/app
|
| 7 |
|
| 8 |
# Expose port for web service
|
| 9 |
ENV PORT=7860
|
main.py
CHANGED
|
@@ -1,45 +1,52 @@
|
|
| 1 |
-
import
|
| 2 |
-
import
|
| 3 |
-
|
| 4 |
-
if False:
|
| 5 |
-
import requests
|
| 6 |
-
from torchvision import transforms
|
| 7 |
-
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()
|
| 8 |
-
response = requests.get("https://git.io/JJkYN")
|
| 9 |
-
labels = response.text.split("\n")
|
| 10 |
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
inp = transforms.ToTensor()(inp).unsqueeze(0)
|
| 14 |
-
with torch.no_grad():
|
| 15 |
-
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
|
| 16 |
-
confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
|
| 17 |
-
return confidences
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
fi.write(f"args: {args}\n")
|
| 24 |
-
fi.write(f"kwargs: {kwargs}\n")
|
| 25 |
-
return output_file_path
|
| 26 |
|
| 27 |
|
| 28 |
def run():
|
| 29 |
iface = gr.Interface(
|
| 30 |
-
fn=
|
| 31 |
inputs=[
|
| 32 |
gr.File(label="Protein PDB", file_types=[".pdb"]),
|
| 33 |
gr.File(label="Ligand SDF", file_types=[".sdf"]),
|
| 34 |
-
gr.Number(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
gr.Checkbox(label="Keep Local Structures", value=True),
|
| 36 |
-
gr.Checkbox(label="Save
|
| 37 |
],
|
| 38 |
-
outputs=gr.File(label="Result")
|
| 39 |
)
|
| 40 |
|
| 41 |
-
iface.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
| 45 |
run()
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os.path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
import gradio as gr
|
| 5 |
|
| 6 |
+
import run_utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
+
def run_wrapper(protein_file, ligand_file, *args, **kwargs) -> str:
|
| 10 |
+
return run_utils.run_cli_command(protein_file.name, ligand_file.name,
|
| 11 |
+
*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def run():
|
| 15 |
iface = gr.Interface(
|
| 16 |
+
fn=run_wrapper,
|
| 17 |
inputs=[
|
| 18 |
gr.File(label="Protein PDB", file_types=[".pdb"]),
|
| 19 |
gr.File(label="Ligand SDF", file_types=[".sdf"]),
|
| 20 |
+
gr.Number(
|
| 21 |
+
label="Samples Per Complex",
|
| 22 |
+
value=1,
|
| 23 |
+
minimum=1,
|
| 24 |
+
maximum=100,
|
| 25 |
+
precision=0,
|
| 26 |
+
),
|
| 27 |
gr.Checkbox(label="Keep Local Structures", value=True),
|
| 28 |
+
gr.Checkbox(label="Save Visualisation", value=True),
|
| 29 |
],
|
| 30 |
+
outputs=gr.File(label="Result"),
|
| 31 |
)
|
| 32 |
|
| 33 |
+
iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def set_env_variables():
|
| 37 |
+
if "DiffDock-Pocket-Dir" not in os.environ:
|
| 38 |
+
work_dir = os.path.abspath(os.path.join("../DiffDock-Pocket"))
|
| 39 |
+
if os.path.exists(work_dir):
|
| 40 |
+
os.environ["DiffDock-Pocket-Dir"] = work_dir
|
| 41 |
+
else:
|
| 42 |
+
raise ValueError(f"DiffDock-Pocket-Dir {work_dir} not found")
|
| 43 |
+
|
| 44 |
+
if "LOG_LEVEL" not in os.environ:
|
| 45 |
+
os.environ["LOG_LEVEL"] = "INFO"
|
| 46 |
|
| 47 |
|
| 48 |
if __name__ == "__main__":
|
| 49 |
+
set_env_variables()
|
| 50 |
+
run_utils.configure_logging()
|
| 51 |
+
|
| 52 |
run()
|
run_utils.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import subprocess
|
| 5 |
+
import tempfile
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def configure_logging(level=None):
|
| 12 |
+
if level is None:
|
| 13 |
+
level = getattr(logging, os.environ.get("LOG_LEVEL", "INFO"))
|
| 14 |
+
|
| 15 |
+
# Note that this sets the universal logger,
|
| 16 |
+
# which includes other libraries.
|
| 17 |
+
logging.basicConfig(
|
| 18 |
+
level=level,
|
| 19 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 20 |
+
handlers=[
|
| 21 |
+
logging.StreamHandler(), # Outputs logs to stderr by default
|
| 22 |
+
# If you also want to log to a file, uncomment the following line:
|
| 23 |
+
# logging.FileHandler('my_app.log', mode='a', encoding='utf-8')
|
| 24 |
+
]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def run_cli_command(protein_path: str, ligand: str, samples_per_complex: int,
|
| 29 |
+
keep_local_structures: bool, save_visualisation: bool, work_dir=None):
|
| 30 |
+
|
| 31 |
+
if work_dir is None:
|
| 32 |
+
work_dir = os.environ.get("DiffDock-Pocket-Dir",
|
| 33 |
+
os.path.join(os.environ["HOME"], "DiffDock-Pocket"))
|
| 34 |
+
|
| 35 |
+
command = ["python3", "inference.py", f"--protein_path={protein_path}", f"--ligand={ligand}",
|
| 36 |
+
f"--samples_per_complex={samples_per_complex}"]
|
| 37 |
+
|
| 38 |
+
# Adding boolean arguments only if they are True
|
| 39 |
+
if keep_local_structures:
|
| 40 |
+
command.append("--keep_local_structures")
|
| 41 |
+
if save_visualisation:
|
| 42 |
+
command.append("--save_visualisation")
|
| 43 |
+
|
| 44 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 45 |
+
temp_dir_path = temp_dir
|
| 46 |
+
logging.debug(f"temp dir: {temp_dir}")
|
| 47 |
+
command.append(f"--out_dir={temp_dir_path}")
|
| 48 |
+
|
| 49 |
+
# Convert command list to string for printing
|
| 50 |
+
command_str = " ".join(command)
|
| 51 |
+
logging.info(f"Executing command: {command_str}")
|
| 52 |
+
|
| 53 |
+
# Running the command
|
| 54 |
+
try:
|
| 55 |
+
result = subprocess.run(
|
| 56 |
+
command, cwd=work_dir, check=False, text=True, capture_output=True, env=os.environ
|
| 57 |
+
)
|
| 58 |
+
logging.debug(f"Command output:\n{result.stdout}")
|
| 59 |
+
if result.stderr:
|
| 60 |
+
logging.error(f"Command error:\n{result.stderr}")
|
| 61 |
+
except subprocess.CalledProcessError as e:
|
| 62 |
+
logging.error(f"An error occurred while executing the command: {e}")
|
| 63 |
+
|
| 64 |
+
# Zip the output directory
|
| 65 |
+
# Generate a unique filename using a timestamp and a UUID
|
| 66 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 67 |
+
uuid_tag = str(uuid.uuid4())[0:8]
|
| 68 |
+
unique_filename = f"output_{timestamp}_{uuid_tag}"
|
| 69 |
+
zip_base_name = os.path.join("tmp", unique_filename)
|
| 70 |
+
full_zip_path = shutil.make_archive(zip_base_name, 'zip', temp_dir)
|
| 71 |
+
|
| 72 |
+
logging.debug(f"Directory '{temp_dir}' zipped to '{full_zip_path}'")
|
| 73 |
+
|
| 74 |
+
return full_zip_path
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if False and __name__ == "__main__":
|
| 78 |
+
# Testing code
|
| 79 |
+
work_dir = os.path.expanduser("~/Projects/DiffDock-Pocket")
|
| 80 |
+
os.environ["DiffDock-Pocket-Dir"] = work_dir
|
| 81 |
+
protein_path = os.path.join(work_dir, "example_data", "3dpf_protein.pdb")
|
| 82 |
+
ligand = os.path.join(work_dir, "example_data", "3dpf_ligand.sdf")
|
| 83 |
+
|
| 84 |
+
run_cli_command(protein_path, ligand, samples_per_complex=1,
|
| 85 |
+
keep_local_structures=True, save_visualisation=True)
|