File size: 6,954 Bytes
a3de977
548170b
61dc572
548170b
 
 
a888fd4
 
65bf846
548170b
636ed51
 
 
 
 
 
 
 
 
 
 
 
a888fd4
 
 
 
217e274
548170b
61dc572
a888fd4
 
 
217e274
 
 
 
 
 
 
 
61dc572
a888fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc9b806
a888fd4
 
61dc572
 
a888fd4
 
 
 
 
61dc572
 
 
 
 
 
eb7610c
 
 
61dc572
 
 
 
f293154
 
a888fd4
61dc572
a888fd4
 
61dc572
c697272
a888fd4
 
9878cbc
 
636ed51
59e3bc3
 
61dc572
4877844
 
 
 
9878cbc
a888fd4
77e6565
548170b
2245755
46023a4
91fbf43
548170b
46023a4
548170b
 
46023a4
548170b
 
 
61dc572
 
548170b
 
 
 
61dc572
548170b
 
 
 
61dc572
548170b
 
 
 
61dc572
548170b
 
 
 
61dc572
 
 
 
548170b
61dc572
548170b
46023a4
 
 
 
548170b
61dc572
 
 
 
a888fd4
b3933a0
61dc572
 
 
 
a888fd4
61dc572
 
 
 
cdcb7a0
61dc572
548170b
 
 
 
61dc572
548170b
61dc572
548170b
77e6565
548170b
61dc572
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import gradio as gr
from prediction import run_sequence_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from celle_main import instantiate_from_config
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

def bold_predicted_letters(input_string: str, output_string: str) -> str:
    """Return *output_string* with the characters predicted for masks in bold.

    *input_string* is walked left to right. Every ``<mask>`` token in it
    corresponds to exactly one character of *output_string*; that character
    is wrapped in ``**`` (Markdown bold). Every other input character also
    corresponds to one output character, which is copied through unchanged.

    The two strings are tracked with separate cursors because a mask token
    is six characters long in the input but only one character long in the
    output.

    Args:
        input_string: The sequence as entered, possibly containing
            ``<mask>`` tokens.
        output_string: The predicted sequence, one character per input
            residue or mask token.

    Returns:
        The Markdown-formatted sequence. If *output_string* is shorter than
        the number of positions in *input_string*, formatting stops at the
        end of *output_string* instead of raising.
    """
    mask_token = "<mask>"
    result = []
    in_pos = 0   # cursor into input_string
    out_pos = 0  # cursor into output_string
    while in_pos < len(input_string) and out_pos < len(output_string):
        if input_string.startswith(mask_token, in_pos):
            result.append("**" + output_string[out_pos] + "**")
            in_pos += len(mask_token)
        else:
            result.append(output_string[out_pos])
            in_pos += 1
        # Each input position (mask or plain residue) consumes one output char.
        out_pos += 1
    return "".join(result)

class model:
    """Holds the currently loaded prediction model and serves the Gradio callback.

    The instance remembers which model name is loaded so repeated requests
    for the same model skip loading, and remembers the downloaded checkpoint
    and config paths per model name so switching back does not call
    ``hf_hub_download`` again.
    """

    def __init__(self):
        # Currently loaded (compiled) model; None until the first request.
        self.model = None
        # Name of the model held in ``self.model``.
        self.model_name = None
        # Maps model name -> [checkpoint path, config path].
        self.model_dict = {}

    def _load_model(self, model_name, device):
        """Download (if needed), instantiate and compile the named model.

        Args:
            model_name: Repository name under the ``HuangLab`` organisation.
            device: ``torch.device`` the model is moved to.

        Returns:
            The model produced by ``instantiate_from_config`` after
            ``.to(device)`` and ``torch.compile``.
        """
        if model_name not in self.model_dict:
            repo_id = f"HuangLab/{model_name}"
            model_ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
            model_config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
            # The return values of these two downloads are not used here.
            # NOTE(review): presumably the files are looked up relative to the
            # checkpoint directory during instantiation (hence the chdir
            # below) -- confirm.
            hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
            hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")
            self.model_dict[model_name] = [model_ckpt_path, model_config_path]
        else:
            model_ckpt_path, model_config_path = self.model_dict[model_name]

        # Load model config and set ckpt_path if not provided in config
        config = OmegaConf.load(model_config_path)
        if config["model"]["params"]["ckpt_path"] is None:
            config["model"]["params"]["ckpt_path"] = model_ckpt_path

        # Set condition_model_path and vqgan_model_path to None
        config["model"]["params"]["condition_model_path"] = None
        config["model"]["params"]["vqgan_model_path"] = None

        base_path = os.getcwd()
        os.chdir(os.path.dirname(model_ckpt_path))
        try:
            # Instantiate model from config and move to device
            loaded = instantiate_from_config(config.model).to(device)
            loaded = torch.compile(loaded, mode='max-autotune')
        finally:
            # Always restore the working directory, even when loading fails,
            # so a failed request does not change the process-wide cwd.
            os.chdir(base_path)
        return loaded

    def gradio_demo(self, model_name, sequence_input, image):
        """Run a sequence prediction for the Gradio UI.

        Args:
            model_name: Name of the model to use; a different name from the
                one currently loaded triggers a (re)load.
            sequence_input: Sequence text passed to
                ``run_sequence_prediction`` and used to decide which output
                letters are shown in bold.
            image: Mapping with ``'image'`` and ``'mask'`` entries, each an
                object with a ``convert('L')`` method (PIL-style images).

        Returns:
            A 3-tuple: a PIL image built from ``protein_image[0, 0]``, a PIL
            image built from ``nucleus_image[0, 0]``, and the predicted
            sequence as a Markdown string with predicted letters in bold.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.model_name != model_name:
            loaded = self._load_model(model_name, device)
            # Record the name only after a successful load; otherwise a failed
            # load would make the next request for the same name skip loading
            # and run with a stale or missing model.
            self.model = loaded
            self.model_name = model_name

        # NOTE(review): neither model name offered by the dropdown in this
        # file contains "Finetuned", so from that UI this always selects
        # "HPA" -- confirm whether the OpenCell model should map to "OpenCell".
        dataset = "OpenCell" if "Finetuned" in model_name else "HPA"

        nucleus_image = image['image'].convert('L')
        protein_image = image['mask'].convert('L')

        to_tensor = T.ToTensor()
        nucleus_image = to_tensor(nucleus_image)
        protein_image = to_tensor(protein_image)
        stacked_images = torch.stack([nucleus_image, protein_image], dim=0)
        processed_images = process_image(stacked_images, dataset)

        nucleus_image = processed_images[0].unsqueeze(0)
        protein_image = processed_images[1].unsqueeze(0)
        # Scale by the maximum value, then invert.
        # NOTE(review): if the maximum is 0 (e.g. nothing was drawn) this
        # division yields NaN/inf values -- confirm intended handling.
        protein_image = protein_image / torch.max(protein_image)
        protein_image = 1 - protein_image

        formatted_predicted_sequence = run_sequence_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            protein_image=protein_image,
            model=self.model,
            device=device,
        )

        # Keep the first returned entry and strip the special tokens from it.
        formatted_predicted_sequence = formatted_predicted_sequence[0]
        for special_token in ("<pad>", "<start>", "<end>"):
            formatted_predicted_sequence = formatted_predicted_sequence.replace(special_token, "")

        print(sequence_input)
        print(formatted_predicted_sequence)

        formatted_predicted_sequence = bold_predicted_letters(sequence_input, formatted_predicted_sequence)
        return T.ToPILImage()(protein_image[0, 0]), T.ToPILImage()(nucleus_image[0, 0]), formatted_predicted_sequence

# Single shared instance: keeps the loaded model between button clicks.
base_class = model()

# Build the UI. Components are rendered in the order they are created, so the
# statement order below defines the page layout.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Inputs")
    gr.Markdown("Select the prediction model. **Note the first run may take ~2-3 minutes, but will take 3-4 seconds afterwards.**")
    gr.Markdown(
        "- ```CELL-E_2_HPA_2560``` is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "- ```CELL-E_2_OpenCell_2560``` is trained on OpenCell and is good for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include ```<mask>``` for a prediction to be run."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value="M<mask><mask><mask><mask><mask>SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image."
        )

    with gr.Row().style(equal_height=True):
        # The sketch tool lets the user draw on the uploaded image; the
        # callback reads both the 'image' and 'mask' entries of the value.
        nucleus_image = gr.Image(
            source="upload",
            tool="sketch",
            label="Nucleus Image",
            interactive=True,
            image_mode="L",
            type="pil"
        )

    with gr.Row():
        gr.Markdown("## Outputs")

    with gr.Row().style(equal_height=True):
        nucleus_crop = gr.Image(
            label="Nucleus Image (Crop)",
            image_mode="L",
            type="pil"
        )

        mask = gr.Image(
            label="Threshold Image",
            image_mode="L",
            type="pil"
        )
    with gr.Row():
        gr.Markdown("Sequence predictions are shown below.")

    with gr.Row().style(equal_height=True):
        predicted_sequence = gr.Markdown(label='Predicted Sequence')


    with gr.Row():
        button = gr.Button("Run Model")

        # Order matches the parameters of model.gradio_demo.
        inputs = [model_name, sequence_input, nucleus_image]

        # Order matches the tuple returned by model.gradio_demo:
        # protein image first, then the nucleus image, then the sequence.
        outputs = [mask, nucleus_crop, predicted_sequence]

        button.click(base_class.gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)