import logging

import gradio as gr
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from PIL import Image

from celle.utils import process_image
from prediction import run_sequence_prediction

logger = logging.getLogger(__name__)


def _download_model_files(model_name):
    """Download the checkpoint and config files for ``model_name`` from the Hub.

    Args:
        model_name: Repository name under the ``HuangLab`` organisation.

    Returns:
        Tuple ``(checkpoint_path, config_path)`` of local file paths.
    """
    repo_id = f"HuangLab/{model_name}"
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    # The returned paths of these VQGAN configs are not used directly; they are
    # downloaded so they exist in the local cache next to config.yaml.
    # NOTE(review): presumably the model loader resolves them from there — confirm.
    for extra_config in ("nucleus_vqgan.yaml", "threshold_vqgan.yaml"):
        hf_hub_download(repo_id=repo_id, filename=extra_config)
    return checkpoint_path, config_path


def _dataset_for_model(model_name):
    """Return the dataset name passed to ``process_image`` for ``model_name``."""
    # The original check only looked for "Finetuned", which neither dropdown
    # choice contains, so the OpenCell model was silently given HPA
    # preprocessing. The "Finetuned" match is kept for backward compatibility.
    if "OpenCell" in model_name or "Finetuned" in model_name:
        return "OpenCell"
    return "HPA"


def _prepare_images(image, dataset):
    """Turn the sketch-tool payload into nucleus and binary protein tensors.

    Args:
        image: Dict from the sketch ``gr.Image`` with PIL images under
            ``"image"`` (the uploaded nucleus image) and ``"mask"`` (the
            user's drawing).
        dataset: Dataset name forwarded to ``process_image``.

    Returns:
        Tuple ``(nucleus_tensor, protein_mask)``. The protein mask contains only
        0.0 / 1.0 values.
    """
    to_tensor = T.ToTensor()
    nucleus = to_tensor(image["image"].convert("L"))
    drawing = to_tensor(image["mask"].convert("L"))

    # Both images are processed together so they receive the same processing
    # (e.g. the same crop).
    processed = process_image(torch.stack([nucleus, drawing], dim=0), dataset)

    nucleus_tensor = processed[0].unsqueeze(0)
    # Any non-zero drawn pixel counts as localization; ``1.0 *`` turns the
    # boolean tensor into a float tensor.
    protein_mask = 1.0 * (processed[1].unsqueeze(0) > 0)
    return nucleus_tensor, protein_mask


def gradio_demo(model_name, sequence_input, image):
    """Run CELL-E 2 sequence prediction for the Gradio interface.

    Args:
        model_name: Hugging Face repository name under ``HuangLab``.
        sequence_input: Amino-acid sequence entered in the textbox.
        image: Payload from the nucleus sketch ``gr.Image``. This is ``None``
            when nothing was uploaded.

    Returns:
        ``(threshold_image, nucleus_crop_image, predicted_sequence)``, in the
        order of the ``mask``, ``nucleus_crop`` and ``predicted_sequence``
        outputs.

    Raises:
        gr.Error: If no nucleus image was uploaded.
    """
    # Fail fast with a user-visible message instead of a TypeError, and before
    # any model files are downloaded.
    if image is None:
        raise gr.Error("Please upload a nucleus image before running the model.")

    model_ckpt_path, model_config_path = _download_model_files(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    nucleus_image, protein_image = _prepare_images(
        image, _dataset_for_model(model_name)
    )
    logger.debug("Protein mask pixel count: %s", protein_image.sum())

    formatted_predicted_sequence = run_sequence_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        protein_image=protein_image,
        model_ckpt_path=model_ckpt_path,
        model_config_path=model_config_path,
        device=device,
    )

    to_pil = T.ToPILImage()
    return to_pil(protein_image), to_pil(nucleus_image), formatted_predicted_sequence


with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown(
        "- CELL-E_2_HPA_2560 is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- CELL-E_2_OpenCell_2560 is trained on OpenCell and is better suited for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        # NOTE(review): the empty six-backtick span looks like a token that was
        # stripped from the text (possibly a mask token) — confirm the intended wording.
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include `````` for a prediction to be run."
        )
    with gr.Row():
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )
    with gr.Row().style(equal_height=True):
        # The sketch tool yields a dict with "image" and "mask" keys, which
        # gradio_demo unpacks.
        nucleus_image = gr.Image(
            source="upload",
            tool="sketch",
            invert_colors=True,
            label="Nucleus Image",
            interactive=True,
            image_mode="L",
            type="pil",
        )
    with gr.Row().style(equal_height=True):
        nucleus_crop = gr.Image(
            label="Nucleus Image (Crop)",
            image_mode="L",
            type="pil",
        )
        mask = gr.Image(
            label="Threshold Image",
            image_mode="L",
            type="pil",
        )
    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")
    with gr.Row().style(equal_height=True):
        predicted_sequence = gr.Textbox(label='Predicted Sequence')
    with gr.Row():
        button = gr.Button("Run Model")

    inputs = [model_name, sequence_input, nucleus_image]
    # Order must match gradio_demo's return tuple.
    outputs = [mask, nucleus_crop, predicted_sequence]

    button.click(gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)