File size: 5,764 Bytes
a3de977
548170b
61dc572
548170b
 
 
a888fd4
 
65bf846
548170b
a888fd4
 
 
 
548170b
61dc572
a888fd4
 
 
762f591
13b52c8
762f591
 
61dc572
a888fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc9b806
a888fd4
 
61dc572
 
a888fd4
 
 
 
 
61dc572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a888fd4
61dc572
a888fd4
 
61dc572
 
a888fd4
 
61dc572
 
a888fd4
77e6565
548170b
2245755
46023a4
 
548170b
46023a4
548170b
 
46023a4
548170b
 
 
61dc572
 
548170b
 
 
 
61dc572
548170b
 
 
 
61dc572
548170b
 
 
 
61dc572
548170b
 
 
 
61dc572
 
 
 
 
548170b
61dc572
548170b
46023a4
 
 
 
548170b
61dc572
 
 
 
a888fd4
b3933a0
61dc572
 
 
 
a888fd4
61dc572
 
 
 
 
 
548170b
 
 
 
61dc572
548170b
61dc572
548170b
77e6565
548170b
61dc572
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import gradio as gr
from prediction import run_sequence_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from celle_main import instantiate_from_config
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

class model:
    """Gradio callback holder that lazily loads and caches one CELL-E 2 model.

    Only the most recently requested model is kept. Re-running with the same
    model name reuses ``self.model`` and skips the download and
    ``torch.compile`` steps.
    """

    def __init__(self):
        # Compiled model currently in memory (None until the first run).
        self.model = None
        # Name of the model held in ``self.model``; only updated after a
        # successful load, so a failed load is retried on the next call.
        self.model_name = None

    @staticmethod
    def _select_dataset(model_name):
        """Return the ``process_image`` dataset key for ``model_name``.

        Names containing ``"Finetuned"`` or ``"OpenCell"`` use ``"OpenCell"``
        preprocessing; all other names use ``"HPA"``.
        """
        # NOTE: the UI dropdown offers "CELL-E_2_OpenCell_2560", which does not
        # contain "Finetuned"; matching "OpenCell" as well keeps that model
        # from silently receiving HPA preprocessing.
        if "Finetuned" in model_name or "OpenCell" in model_name:
            return "OpenCell"
        return "HPA"

    @staticmethod
    def _binarize(tensor):
        """Return a float tensor that is 1.0 where ``tensor > 0`` and 0.0 elsewhere."""
        return 1.0 * (tensor > 0)

    def _load_model(self, model_name, device):
        """Download, instantiate, compile and cache the Hub model ``model_name``.

        Files are fetched from the ``HuangLab/<model_name>`` repository. The
        cache (``self.model`` / ``self.model_name``) is only updated when every
        step succeeds, and the working directory is always restored.
        """
        repo_id = f"HuangLab/{model_name}"
        model_ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
        model_config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
        # The VQGAN configs are not read here; they are downloaded so they sit
        # in the same snapshot directory as the checkpoint (presumably the
        # model config refers to them by relative path -- confirm).
        hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
        hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")

        # Load model config and set ckpt_path if not provided in config.
        # ``.get`` also tolerates configs that omit the key entirely.
        config = OmegaConf.load(model_config_path)
        params = config["model"]["params"]
        if params.get("ckpt_path") is None:
            params["ckpt_path"] = model_ckpt_path

        # Set condition_model_path and vqgan_model_path to None
        params["condition_model_path"] = None
        params["vqgan_model_path"] = None

        # Instantiate from inside the checkpoint directory so relative paths
        # resolve there; always switch back, even if instantiation fails.
        base_path = os.getcwd()
        os.chdir(os.path.dirname(model_ckpt_path))
        try:
            loaded = instantiate_from_config(config.model).to(device)
            loaded = torch.compile(loaded, mode='max-autotune')
        finally:
            os.chdir(base_path)

        # Commit the cache only after everything above succeeded.
        self.model = loaded
        self.model_name = model_name

    def gradio_demo(self, model_name, sequence_input, image):
        """Run sequence prediction for the Gradio UI.

        Args:
            model_name: Hub repository name under ``HuangLab/``.
            sequence_input: amino-acid sequence passed to
                ``run_sequence_prediction``.
            image: Gradio sketch-image value; a dict whose ``"image"`` entry is
                the uploaded nucleus image and whose ``"mask"`` entry is the
                drawn localization.

        Returns:
            ``(threshold_image, nucleus_crop, formatted_predicted_sequence)``:
            the binarized drawn mask and the processed nucleus image as PIL
            images, plus the value returned by ``run_sequence_prediction``.

        Raises:
            gr.Error: if no nucleus image was uploaded.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Validate before the (slow) first-time model download.
        if image is None or image.get("image") is None:
            raise gr.Error("Uploading a nucleus image is necessary.")

        if self.model_name != model_name:
            self._load_model(model_name, device)

        dataset = self._select_dataset(model_name)

        nucleus_image = image['image'].convert('L')
        protein_image = image['mask'].convert('L')

        to_tensor = T.ToTensor()
        stacked_images = torch.stack(
            [to_tensor(nucleus_image), to_tensor(protein_image)], dim=0
        )
        processed_images = process_image(stacked_images, dataset)

        nucleus_image = processed_images[0].unsqueeze(0)
        # Any drawn (non-zero) pixel counts as part of the target localization.
        protein_image = self._binarize(processed_images[1].unsqueeze(0))

        # Debug output: number of pixels in the binarized localization mask.
        print(f'{protein_image.sum()}')

        formatted_predicted_sequence = run_sequence_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            protein_image=protein_image,
            # Despite the parameter name, the loaded model object is passed.
            model_ckpt_path=self.model,
            device=device,
        )

        to_pil = T.ToPILImage()
        return to_pil(protein_image), to_pil(nucleus_image), formatted_predicted_sequence

# Single shared callback holder so the loaded model is reused across requests.
base_class = model()

# Build the Gradio UI: model choice, sequence + nucleus/sketch inputs, and
# the threshold image, nucleus crop and predicted sequence as outputs.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Inputs")
    gr.Markdown("Select the prediction model. **Note the first run may take ~5-6 minutes, but will take 3-4 seconds afterwards.**")
    gr.Markdown(
        "- ```CELL-E_2_HPA_2560``` is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- ```CELL-E_2_OpenCell_2560``` is trained on OpenCell and is good for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include ```<mask>``` for a prediction to be run."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value="M<mask><mask><mask><mask><mask>SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )

    with gr.Row().style(equal_height=True):
        # The sketch tool yields a dict with the uploaded "image" and the
        # drawn "mask", which gradio_demo reads.
        nucleus_image = gr.Image(
            source="upload",
            tool="sketch",
            invert_colors=True,
            label="Nucleus Image",
            interactive=True,
            image_mode="L",
            type="pil"
        )

    with gr.Row():
        gr.Markdown("## Outputs")

    with gr.Row().style(equal_height=True):
        nucleus_crop = gr.Image(
            label="Nucleus Image (Crop)",
            image_mode="L",
            type="pil"
        )

        mask = gr.Image(
            label="Threshold Image",
            image_mode="L",
            type="pil"
        )
    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")

    with gr.Row().style(equal_height=True):
        predicted_sequence = gr.Textbox(label='Predicted Sequence')

    with gr.Row():
        button = gr.Button("Run Model")

        inputs = [model_name, sequence_input, nucleus_image]

        # Order must match gradio_demo's return tuple:
        # (threshold image, nucleus crop, predicted sequence).
        outputs = [mask, nucleus_crop, predicted_sequence]

        button.click(base_class.gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)