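"""Gradio demo for CELL-E_2 protein-localization sequence prediction.

The app downloads a CELL-E_2 checkpoint from the HuangLab Hugging Face Hub,
takes an amino-acid sequence containing <mask> tokens plus a nucleus image
with a hand-drawn localization mask, and returns the processed images
alongside the predicted sequence with the in-filled residues highlighted.
"""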
import os
import gradio as gr
from prediction import run_sequence_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from celle_main import instantiate_from_config
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

def bold_predicted_letters(input_string: str, output_string: str) -> str:
    """Wrap the residues predicted for each run of <mask> tokens in Markdown bold."""
    result = []
    i = j = 0  # i walks the masked input; j walks the predicted output
    input_string = input_string.upper()
    output_string = output_string.upper()

    while i < len(input_string):
        if input_string[i:i+6] == "<MASK>":
            # Collapse consecutive <mask> tokens into a single bold span.
            start_index = i
            end_index = i + 6
            while end_index < len(input_string) and input_string[end_index:end_index+6] == "<MASK>":
                end_index += 6

            num_masked = (end_index - start_index) // 6
            result.append("**" + output_string[j:j+num_masked] + "**")
            i = end_index
            j += num_masked
        else:
            result.append(input_string[i])
            i += 1
            if input_string[i-1] != "<":
                j += 1

    return "".join(result)

def diff_texts(string):
    """Convert **bold** spans into (char, label) pairs for gr.HighlightedText."""
    new_string = []
    bold = False

    for idx, letter in enumerate(string):
        # The leading '*' of each "**" marker toggles bold mode; the '*'
        # characters themselves are never emitted.
        if letter == '*' and string[min(idx + 1, len(string) - 1)] == '*':
            bold = not bold
        if letter != '*':
            new_string.append((letter, '+' if bold else None))

    return new_string
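
# Illustrative example of the conversion:
#   diff_texts("M**AV**K")  ->  [('M', None), ('A', '+'), ('V', '+'), ('K', None)]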
    
class model:
    """Caches the most recently used CELL-E_2 checkpoint between Gradio calls."""

    def __init__(self):
        self.model = None
        self.model_name = None
        self.model_path = None

    def gradio_demo(self, model_name, sequence_input, image):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Only download and instantiate when the selected checkpoint changes;
        # the previously cached checkpoint file is deleted to save disk space.
        if self.model_name != model_name:
            if self.model_path is not None:
                os.remove(self.model_path)
                del self.model
            self.model_name = model_name
            model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
            self.model_path = model_ckpt_path
            model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
            hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
            hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")

            # Load model config and set ckpt_path if not provided in config
            config = OmegaConf.load(model_config_path)
            if config["model"]["params"]["ckpt_path"] is None:
                config["model"]["params"]["ckpt_path"] = model_ckpt_path

            # Set condition_model_path and vqgan_model_path to None
            config["model"]["params"]["condition_model_path"] = None
            config["model"]["params"]["vqgan_model_path"] = None
            
            base_path = os.getcwd()

            os.chdir(os.path.dirname(model_ckpt_path))

            # Instantiate model from config, move to device, and compile
            self.model = instantiate_from_config(config.model).to(device)
            self.model = torch.compile(self.model, mode='max-autotune')

            os.chdir(base_path)

        # Preprocessing must match the training set: the OpenCell checkpoint
        # expects OpenCell-style images; everything else uses HPA settings.
        if "OpenCell" in model_name:
            dataset = "OpenCell"
        else:
            dataset = "HPA"

        # The uploaded image is the nucleus channel; the drawn mask is the
        # desired protein localization.
        nucleus_image = image['image'].convert('L')
        protein_image = image['mask'].convert('L')

        to_tensor = T.ToTensor()
        nucleus_image = to_tensor(nucleus_image).unsqueeze(0)
        protein_image = to_tensor(protein_image).unsqueeze(0)

        # Crop and normalize the nucleus channel to match the training data
        nucleus_image = process_image(nucleus_image, dataset, 'nucleus')

        formatted_predicted_sequence = run_sequence_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            protein_image=protein_image,
            model=self.model,
            device=device,
        )
        # Strip special tokens from the decoded sequence
        formatted_predicted_sequence = formatted_predicted_sequence[0]
        for token in ("<pad>", "<cls>", "<eos>"):
            formatted_predicted_sequence = formatted_predicted_sequence.replace(token, "")

        # Bold the in-filled residues, then convert to HighlightedText pairs
        formatted_predicted_sequence = bold_predicted_letters(sequence_input, formatted_predicted_sequence)
        formatted_predicted_sequence = diff_texts(formatted_predicted_sequence)
        return T.ToPILImage()(protein_image[0, 0]), T.ToPILImage()(nucleus_image[0, 0]), formatted_predicted_sequence

# A single shared instance so the loaded checkpoint persists across calls
base_class = model()

with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Inputs")
    gr.Markdown("Select the prediction model. **Note the first run may take ~2-3 minutes, but will take 3-4 seconds afterwards.**")
    gr.Markdown(
        "- ```CELL-E_2_HPA_2560``` is a good general-purpose model for various cell types imaged with ICC-IF."
    )
    gr.Markdown(
        "- ```CELL-E_2_OpenCell_2560``` is trained on OpenCell and is better suited for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
            value="CELL-E_2_HPA_2560",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include ```<mask>``` for a prediction to be run."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value="M<mask><mask><mask><mask><mask>SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is required. If the image is larger than 256 x 256, a random crop will be applied. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucleus image. Due to Gradio limitations, the mask is drawn in black."
        )

    with gr.Row(equal_height=True):
        nucleus_image = gr.ImageMask(
            label="Nucleus Image",
            interactive=True,
            image_mode="L",
            brush_color="#ffffff",
            type="pil",
        )
        
    with gr.Row():
        gr.Markdown("## Outputs")
        
    with gr.Row(equal_height=True):
        nucleus_crop = gr.Image(
            label="Nucleus Image (Crop)", 
            image_mode="L",
            type="pil"
        )

        mask = gr.Image(
            label="Threshold Image", 
            image_mode="L",
            type="pil"
        )
    with gr.Row():
        gr.Markdown("Sequence predictions are show below.")

    with gr.Row(equal_height=True):
        predicted_sequence = gr.HighlightedText(
            label="Predicted Sequence",
            combine_adjacent=True,
            show_legend=False,
            color_map={"+": "green"},
        )


    with gr.Row():
        button = gr.Button("Run Model")

        inputs = [model_name, sequence_input, nucleus_image]
        outputs = [mask, nucleus_crop, predicted_sequence]

        # Output order matches gradio_demo's return: (threshold, crop, sequence)
        button.click(base_class.gradio_demo, inputs, outputs)

demo.queue(max_size=1).launch()
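
# To run locally (assuming this file is saved as app.py alongside the repo's
# prediction.py and celle modules): `python app.py`, then open the printed
# local URL in a browser.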