File size: 5,520 Bytes
a888fd4
548170b
a888fd4
548170b
 
 
 
 
a888fd4
 
548170b
 
a888fd4
 
 
 
548170b
a888fd4
 
548170b
a888fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548170b
a888fd4
548170b
2245755
548170b
 
a888fd4
548170b
 
a888fd4
548170b
 
 
a888fd4
 
548170b
 
 
 
a888fd4
548170b
 
 
 
a888fd4
548170b
 
 
 
a888fd4
548170b
a888fd4
548170b
 
 
a888fd4
 
548170b
 
 
a888fd4
754cf17
548170b
a888fd4
548170b
 
a888fd4
 
 
 
 
b3933a0
a888fd4
 
 
548170b
a888fd4
548170b
 
 
a888fd4
548170b
a888fd4
 
 
 
 
 
548170b
a888fd4
548170b
a888fd4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import gradio as gr
from prediction import run_image_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from PIL import Image
from matplotlib import pyplot as plt
from celle_main import instantiate_from_config
from omegaconf import OmegaConf


class model:
    """Gradio callback wrapper that caches the most recently loaded CELL-E 2 model.

    Selecting a different ``model_name`` in the UI triggers a reload; repeated
    requests for the same name reuse ``self.model``.
    """

    # Directory holding the ``<model_name>.ckpt`` / ``<model_name>.yaml`` pairs.
    MODEL_DIR = "CELL-E_2-Image_Prediction/models"

    def __init__(self):
        # Both stay None until the first successful load, and are only ever
        # updated together (in ``_load_model``), so they cannot disagree.
        self.model = None
        self.model_name = None

    def _load_model(self, model_name, device):
        """Build, move to ``device`` and ``torch.compile`` the model ``model_name``.

        The cache attributes are assigned only after everything succeeded, so a
        failed load leaves the previous (consistent) cache state untouched.
        """
        model_ckpt_path = f"{self.MODEL_DIR}/{model_name}.ckpt"
        model_config_path = f"{self.MODEL_DIR}/{model_name}.yaml"

        # Load the model config and fall back to the bundled checkpoint when the
        # config does not name one. ``.get`` avoids an error on a missing key.
        # The path is made absolute because the working directory is changed
        # below, which would break a relative path resolved during instantiation.
        config = OmegaConf.load(model_config_path)
        if config["model"]["params"].get("ckpt_path") is None:
            config["model"]["params"]["ckpt_path"] = os.path.abspath(model_ckpt_path)

        # Clear these sub-model paths before instantiation (presumably their
        # weights are expected to come from ckpt_path — verify).
        config["model"]["params"]["condition_model_path"] = None
        config["model"]["params"]["vqgan_model_path"] = None

        base_path = os.getcwd()
        # Instantiate with the checkpoint directory as CWD (presumably so
        # relative paths inside the config resolve there — verify).
        os.chdir(os.path.dirname(model_ckpt_path))
        try:
            loaded = instantiate_from_config(config.model).to(device)
            loaded = torch.compile(loaded, mode="reduce-overhead")
        finally:
            # Restore the working directory even when loading fails.
            os.chdir(base_path)

        self.model = loaded
        self.model_name = model_name

    @staticmethod
    def _render_heatmap(heatmap):
        """Render ``heatmap`` with the rainbow colormap and return it as a PIL image.

        A dedicated figure is created and closed on every call, so repeated
        requests neither draw on top of each other nor accumulate open figures.
        The PNG is written to memory rather than a shared file on disk.
        """
        import io

        fig, ax = plt.subplots()
        try:
            ax.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
            ax.axis("off")
            buffer = io.BytesIO()
            fig.savefig(buffer, format="png", bbox_inches="tight", dpi=256)
        finally:
            plt.close(fig)

        buffer.seek(0)
        image = Image.open(buffer)
        # Force decoding now; Image.open is lazy.
        image.load()
        return image

    def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
        """Run a prediction for the Gradio UI.

        Args:
            model_name: Name of a checkpoint/config pair in ``MODEL_DIR``. Names
                containing "Finetuned" are preprocessed as "OpenCell" data, all
                others as "HPA" data.
            sequence_input: Amino-acid sequence passed to ``run_image_prediction``.
            nucleus_image: Required nucleus image as delivered by Gradio.
            protein_image: Optional protein image; it is not passed to the
                model and is only returned for display.

        Returns:
            Tuple of four PIL images: the processed nucleus image, the protein
            image binarized at its median (an all-ones 256x256 placeholder when
            no protein image is given), the predicted threshold image, and the
            rendered heatmap.

        Raises:
            ValueError: If no nucleus image was supplied.
        """
        if nucleus_image is None:
            raise ValueError("A nucleus image is required to run the model.")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if self.model_name != model_name:
            self._load_model(model_name, device)

        dataset = "OpenCell" if "Finetuned" in model_name else "HPA"

        nucleus_image = process_image(nucleus_image, dataset, "nucleus")
        if protein_image is not None:
            protein_image = process_image(protein_image, dataset, "protein")
            # Binarize at the median and drop the two leading dims for display.
            protein_image = protein_image > torch.median(protein_image)
            protein_image = protein_image[0, 0]
            protein_image = protein_image * 1.0
        else:
            protein_image = torch.ones((256, 256))

        threshold, heatmap = run_image_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            model=self.model,
            device=device,
        )

        return (
            T.ToPILImage()(nucleus_image[0, 0]),
            T.ToPILImage()(protein_image),
            T.ToPILImage()(threshold),
            self._render_heatmap(heatmap),
        )

# Single shared callback holder: the loaded model is cached across requests.
base_class = model()

# Build the Gradio UI: model selection, inputs, and the four output images.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown(
        "CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_480", "CELL-E_2_HPA_Finetuned_480"],
            value="CELL-E_2_HPA_480",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )

    with gr.Row():
        # Default value is the GFP amino-acid sequence.
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)"
        )
        gr.Markdown("The protein image is optional and is just used for display.")

    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(
            type="pil",
            label="Nucleus Image",
            image_mode="L",
        )

        protein_image = gr.Image(type="pil", label="Protein Image (Optional)")

    with gr.Row():
        gr.Markdown("Image predictions are shown below.")

    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")

        protein_threshold_image = gr.Image(
            type="pil", label="Protein Threshold Image", image_mode="L"
        )

        predicted_threshold_image = gr.Image(
            type="pil", label="Predicted Threshold image", image_mode="L"
        )

        predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")
    with gr.Row():
        button = gr.Button("Run Model")

    # Argument order must match model.gradio_demo's parameters, and output
    # order must match its returned tuple.
    inputs = [model_name, sequence_input, nucleus_image, protein_image]

    outputs = [
        nucleus_image_crop,
        protein_threshold_image,
        predicted_threshold_image,
        predicted_heatmap,
    ]

    button.click(base_class.gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)