import os
import re
import gradio as gr
from prediction import run_sequence_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from celle_main import instantiate_from_config
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

# Mask token as it appears after the user's sequence has been upper-cased.
MASK_TOKEN = "<MASK>"

# Angle-bracket special tokens (presumably tokenizer tokens such as <cls>/<eos>/<pad>)
# stripped from the model output before it is aligned with the input.
# NOTE(review): assumes no such token stands in for a residue position in
# run_sequence_prediction's output -- confirm.
_SPECIAL_TOKEN_RE = re.compile(r"<[^<>]*>")


def bold_predicted_letters(input_string: str, output_string: str) -> str:
    """Echo ``input_string`` with every run of mask tokens replaced by bold predictions.

    Both strings are upper-cased first, so ``<mask>`` is matched case-insensitively.
    Each mask token consumes one character of ``output_string``, and each ordinary
    input character advances the output cursor by one, which keeps the two strings
    aligned. A run of consecutive masks is wrapped in a single ``**...**`` pair,
    which :func:`diff_texts` later turns into highlight labels.

    Args:
        input_string: The user's sequence, possibly containing ``<mask>`` tokens.
        output_string: The predicted sequence with special tokens removed.

    Returns:
        The upper-cased input with each masked run replaced by ``**PREDICTED**``.
    """
    token_len = len(MASK_TOKEN)
    input_string = input_string.upper()
    output_string = output_string.upper()

    result = []
    i = j = 0  # i: cursor into input_string, j: cursor into output_string
    while i < len(input_string):
        if input_string.startswith(MASK_TOKEN, i):
            # Swallow the whole run of consecutive mask tokens at once.
            end = i
            while input_string.startswith(MASK_TOKEN, end):
                end += token_len
            n_masked = (end - i) // token_len
            result.append("**" + output_string[j:j + n_masked] + "**")
            i = end
            j += n_masked
        else:
            char = input_string[i]
            result.append(char)
            i += 1
            # A bare "<" (not starting a mask token) does not advance the output
            # cursor. NOTE(review): kept from the original; intent unclear.
            if char != "<":
                j += 1
    return "".join(result)


def diff_texts(string):
    """Convert ``**``-delimited bold markup into ``gr.HighlightedText`` tuples.

    Args:
        string: Text produced by :func:`bold_predicted_letters`.

    Returns:
        A list of ``(character, label)`` pairs where ``label`` is ``"+"`` for
        characters inside a ``**...**`` span and ``None`` otherwise. The ``**``
        markers are not emitted, and neither is any stray single ``*``.
    """
    highlighted = []
    bold = False
    i = 0
    while i < len(string):
        if string.startswith("**", i):
            # Pair-aware toggle: an empty span "****" toggles on and off again.
            bold = not bold
            i += 2
            continue
        letter = string[i]
        if letter != "*":
            highlighted.append((letter, "+" if bold else None))
        i += 1
    return highlighted


class model:
    """Lazily loads one CELL-E 2 checkpoint at a time and serves the Gradio button.

    Switching to a different model name drops the currently loaded model and
    deletes the checkpoint file it was loaded from before downloading the next.
    """

    def __init__(self):
        self.model = None       # compiled model for self.model_name, or None
        self.model_name = None  # name of the loaded HuangLab/<name> repo, or None
        self.model_path = None  # local path of the loaded checkpoint, or None

    def _load_model(self, model_name, device):
        """Download ``HuangLab/{model_name}`` from the Hub and instantiate it on ``device``.

        State is only committed after instantiation succeeds. A failed load leaves
        ``model_name`` as None, so the next request retries the load instead of
        using a missing model.
        """
        # Release the previous model and its checkpoint file first.
        # NOTE(review): the Hub cache may expose this path as a symlink into a
        # blobs directory, in which case os.remove frees no disk space -- confirm.
        if self.model_path is not None:
            try:
                os.remove(self.model_path)
            except FileNotFoundError:
                pass
        self.model = None
        self.model_name = None
        self.model_path = None

        repo_id = f"HuangLab/{model_name}"
        model_ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
        model_config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
        # Fetched only so they sit next to the checkpoint; presumably config.yaml
        # refers to them by relative path (hence the chdir below).
        hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
        hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")

        # Load model config and set ckpt_path if not provided in config
        config = OmegaConf.load(model_config_path)
        if config["model"]["params"]["ckpt_path"] is None:
            config["model"]["params"]["ckpt_path"] = model_ckpt_path
        # Set condition_model_path and vqgan_model_path to None
        config["model"]["params"]["condition_model_path"] = None
        config["model"]["params"]["vqgan_model_path"] = None

        # Instantiate from the checkpoint's directory; always restore the
        # process-wide cwd, even if instantiation fails.
        base_path = os.getcwd()
        os.chdir(os.path.dirname(model_ckpt_path))
        try:
            loaded = instantiate_from_config(config.model).to(device)
            loaded = torch.compile(loaded, mode="max-autotune")
        finally:
            os.chdir(base_path)

        self.model = loaded
        self.model_name = model_name
        self.model_path = model_ckpt_path

    @staticmethod
    def _prepare_images(image, dataset):
        """Turn the ``gr.ImageMask`` value into preprocessed nucleus / protein tensors.

        Args:
            image: ImageMask value, a dict with the uploaded ``"image"`` and the
                user-drawn ``"mask"`` (PIL images), or None if nothing was uploaded.
            dataset: ``"HPA"`` or ``"OpenCell"``, forwarded to ``process_image``.

        Returns:
            ``(nucleus_image, protein_image)``, each with a leading batch axis
            added via ``unsqueeze(0)``.

        Raises:
            gr.Error: If no nucleus image was uploaded.
        """
        if image is None or image.get("image") is None:
            raise gr.Error("Please upload a nucleus image.")

        to_tensor = T.ToTensor()
        nucleus = to_tensor(image["image"].convert("L"))
        protein = to_tensor(image["mask"].convert("L"))
        processed = process_image(torch.stack([nucleus, protein], dim=0), dataset)
        nucleus = processed[0].unsqueeze(0)
        protein = processed[1].unsqueeze(0)

        # Rescale the drawn mask to a peak of 1. Skip this when the (cropped) mask
        # is empty, because dividing 0 by 0 would fill the tensor with NaNs.
        peak = torch.max(protein)
        if peak > 0:
            protein = protein / peak
        return nucleus, protein

    def gradio_demo(self, model_name, sequence_input, image):
        """Run a sequence prediction for the "Run Model" button.

        Args:
            model_name: Repo name under ``HuangLab`` (the dropdown value).
            sequence_input: Amino-acid sequence; ``<mask>`` tokens mark the
                positions to predict.
            image: ``gr.ImageMask`` value (see :meth:`_prepare_images`).

        Returns:
            ``(threshold_image, nucleus_crop, highlighted_sequence)``: two PIL
            images and a list of ``(char, label)`` tuples, in the order of the
            ``outputs`` list wired to the button.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.model_name != model_name:
            self._load_model(model_name, device)

        # The original check looked only for "Finetuned", which neither dropdown
        # choice contains, so the OpenCell model was preprocessed as HPA.
        if "OpenCell" in model_name or "Finetuned" in model_name:
            dataset = "OpenCell"
        else:
            dataset = "HPA"

        nucleus_image, protein_image = self._prepare_images(image, dataset)

        predictions = run_sequence_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            protein_image=protein_image,
            model=self.model,
            device=device,
        )
        predicted = _SPECIAL_TOKEN_RE.sub("", predictions[0])
        highlighted = diff_texts(bold_predicted_letters(sequence_input, predicted))

        return (
            T.ToPILImage()(protein_image[0, 0]),
            T.ToPILImage()(nucleus_image[0, 0]),
            highlighted,
        )


base_class = model()

with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Inputs")
    gr.Markdown(
        "Select the prediction model. **Note the first run may take ~2-3 minutes, "
        "but will take 3-4 seconds afterwards.**"
    )
    gr.Markdown(
        "- ```CELL-E_2_HPA_2560``` is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- ```CELL-E_2_OpenCell_2560``` is trained on OpenCell and is better suited for "
        "live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default. "
            "The sequence must include ```<mask>``` for a prediction to be run."
        )
    with gr.Row():
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be "
            "applied if larger. We provide default images in "
            "[images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). "
            "Draw the desired localization on top of the nucleus image."
        )
    with gr.Row(equal_height=True):
        # The drawn mask is used as the target protein localization.
        nucleus_image = gr.ImageMask(
            label="Nucleus Image",
            interactive=True,
            image_mode="L",
            brush_color="#ffffff",
            type="pil",
        )
    with gr.Row():
        gr.Markdown("## Outputs")
    with gr.Row(equal_height=True):
        nucleus_crop = gr.Image(
            label="Nucleus Image (Crop)",
            image_mode="L",
            type="pil",
        )
        mask = gr.Image(
            label="Threshold Image",
            image_mode="L",
            type="pil",
        )
    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")
    with gr.Row(equal_height=True):
        predicted_sequence = gr.HighlightedText(
            label="Predicted Sequence",
            combine_adjacent=True,
            show_legend=False,
            color_map={"+": "green"},
        )
    with gr.Row():
        button = gr.Button("Run Model")

    inputs = [model_name, sequence_input, nucleus_image]
    # Order must match gradio_demo's return tuple.
    outputs = [mask, nucleus_crop, predicted_sequence]
    button.click(base_class.gradio_demo, inputs, outputs)

demo.queue(max_size=1).launch()