Emaad's picture
Update app.py
95c9069
raw
history blame
6.95 kB
import os
import gradio as gr
from prediction import run_sequence_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from celle_main import instantiate_from_config
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
def bold_predicted_letters(input_string: str, output_string: str) -> str:
    """Return ``output_string`` with the letters predicted for masks in bold.

    ``input_string`` is walked token by token, where a token is either the
    six-character literal ``<mask>`` or a single character. Every token
    consumes exactly one character of ``output_string``; characters that
    line up with a ``<mask>`` token are wrapped in Markdown bold markers
    (``**x**``), all others are copied unchanged.

    The two strings are tracked with separate positions because a mask is
    six characters wide in the input but only one character wide in the
    output. Sharing a single index would drift out of alignment after the
    first mask and eventually run past the end of ``output_string``.

    Args:
        input_string: Sequence as typed by the user, possibly containing
            ``<mask>`` tokens.
        output_string: Predicted sequence with one character per input token.

    Returns:
        The Markdown-formatted sequence. If ``output_string`` has fewer
        characters than ``input_string`` has tokens, formatting stops once
        the output is exhausted instead of raising ``IndexError``.
    """
    mask_token = "<mask>"
    result = []
    in_pos = 0
    out_pos = 0
    while in_pos < len(input_string) and out_pos < len(output_string):
        letter = output_string[out_pos]
        if input_string.startswith(mask_token, in_pos):
            result.append("**" + letter + "**")
            in_pos += len(mask_token)
        else:
            result.append(letter)
            in_pos += 1
        out_pos += 1
    return "".join(result)
class model:
    """Holds the currently loaded network and serves the Gradio callback.

    The network is built lazily on the first request and rebuilt whenever a
    different model name is requested. Paths of files already fetched from
    the Hugging Face Hub are remembered per model name so they are not
    looked up again.
    """

    def __init__(self):
        # Network instance used for prediction; None until the first load.
        self.model = None
        # Name of the Hub model that ``self.model`` was built from.
        self.model_name = None
        # Maps model name -> [checkpoint path, config path].
        self.model_dict = {}

    def _load_model(self, model_name, device):
        """Fetch, instantiate and compile ``model_name`` on ``device``.

        ``self.model`` and ``self.model_name`` are only updated after every
        step has succeeded, so a failed download or instantiation leaves the
        previous state intact and the next request retries the load.
        """
        if model_name in self.model_dict:
            model_ckpt_path, model_config_path = self.model_dict[model_name]
        else:
            repo_id = f"HuangLab/{model_name}"
            model_ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
            model_config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
            # Return values are unused: these two files only need to exist
            # on disk. Presumably the model config refers to them by
            # relative path, hence the chdir below -- TODO confirm.
            hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
            hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")
            self.model_dict[model_name] = [model_ckpt_path, model_config_path]

        # Load model config and set ckpt_path if not provided in config
        config = OmegaConf.load(model_config_path)
        if config["model"]["params"]["ckpt_path"] is None:
            config["model"]["params"]["ckpt_path"] = model_ckpt_path

        # Set condition_model_path and vqgan_model_path to None
        config["model"]["params"]["condition_model_path"] = None
        config["model"]["params"]["vqgan_model_path"] = None

        base_path = os.getcwd()
        os.chdir(os.path.dirname(model_ckpt_path))
        try:
            # Instantiate model from config and move to device
            loaded = instantiate_from_config(config.model).to(device)
            loaded = torch.compile(loaded, mode='max-autotune')
        finally:
            # Always restore the working directory, even when loading fails.
            os.chdir(base_path)

        self.model = loaded
        self.model_name = model_name

    def gradio_demo(self, model_name, sequence_input, image):
        """Run sequence prediction for one Gradio request.

        Args:
            model_name: Name of the repository under ``HuangLab/`` on the
                Hugging Face Hub to load.
            sequence_input: Amino acid sequence text, forwarded unchanged to
                ``run_sequence_prediction`` and used to decide which
                predicted letters are shown in bold.
            image: Mapping with an ``'image'`` entry (nucleus picture) and a
                ``'mask'`` entry (drawn localization), each offering
                ``.convert('L')``.

        Returns:
            A tuple ``(protein image, nucleus image, formatted sequence)``
            where both images are PIL images and the sequence is a Markdown
            string.

        Raises:
            gr.Error: If no image was supplied.
        """
        if image is None:
            raise gr.Error("Please upload a nucleus image before running the model.")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if self.model_name != model_name:
            self._load_model(model_name, device)

        # NOTE(review): neither model name offered by the UI dropdown in this
        # file contains "Finetuned", so "HPA" is always selected here --
        # confirm whether the OpenCell model should use "OpenCell" instead.
        if "Finetuned" in model_name:
            dataset = "OpenCell"
        else:
            dataset = "HPA"

        to_tensor = T.ToTensor()
        nucleus_image = to_tensor(image['image'].convert('L'))
        protein_image = to_tensor(image['mask'].convert('L'))
        stacked_images = torch.stack([nucleus_image, protein_image], dim=0)
        processed_images = process_image(stacked_images, dataset)

        nucleus_image = processed_images[0].unsqueeze(0)
        protein_image = processed_images[1].unsqueeze(0)

        # Scale so the maximum becomes 1, then invert. The division is
        # skipped when the maximum is zero, which would otherwise fill the
        # image with NaNs.
        peak = torch.max(protein_image)
        if peak != 0:
            protein_image = protein_image / peak
        protein_image = 1 - protein_image

        formatted_predicted_sequence = run_sequence_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            protein_image=protein_image,
            model=self.model,
            device=device,
        )

        # Keep the first returned entry and strip the special tokens.
        formatted_predicted_sequence = formatted_predicted_sequence[0]
        for special_token in ("<pad>", "<cls>", "<eos>"):
            formatted_predicted_sequence = formatted_predicted_sequence.replace(special_token, "")

        print(sequence_input)
        print(formatted_predicted_sequence)

        formatted_predicted_sequence = bold_predicted_letters(sequence_input, formatted_predicted_sequence)

        to_pil = T.ToPILImage()
        return to_pil(protein_image[0, 0]), to_pil(nucleus_image[0, 0]), formatted_predicted_sequence
# Single shared instance so the loaded network persists between requests.
base_class = model()

with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Inputs")
    gr.Markdown("Select the prediction model. **Note the first run may take ~2-3 minutes, but will take 3-4 seconds afterwards.**")
    gr.Markdown(
        "- ```CELL-E_2_HPA_2560``` is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- ```CELL-E_2_OpenCell_2560``` is trained on OpenCell and is good for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include ```<mask>``` for a prediction to be run."
        )
    with gr.Row():
        sequence_input = gr.Textbox(
            value="M<mask><mask><mask><mask><mask>SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )
    with gr.Row().style(equal_height=True):
        # The sketch tool hands the callback both the uploaded picture and
        # the drawn overlay, which gradio_demo reads as 'image' and 'mask'.
        nucleus_image = gr.Image(
            source="upload",
            tool="sketch",
            label="Nucleus Image",
            interactive=True,
            image_mode="L",
            type="pil"
        )
    with gr.Row():
        gr.Markdown("## Outputs")
    with gr.Row().style(equal_height=True):
        nucleus_crop = gr.Image(
            label="Nucleus Image (Crop)",
            image_mode="L",
            type="pil"
        )
        mask = gr.Image(
            label="Threshold Image",
            image_mode="L",
            type="pil"
        )
    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")
    with gr.Row().style(equal_height=True):
        predicted_sequence = gr.Markdown(label='Predicted Sequence')
    with gr.Row():
        button = gr.Button("Run Model")

    # Output order matches the tuple returned by gradio_demo:
    # (protein image, nucleus image, formatted sequence).
    inputs = [model_name, sequence_input, nucleus_image]
    outputs = [mask, nucleus_crop, predicted_sequence]
    button.click(base_class.gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)