Emaad committed on
Commit 61dc572
1 Parent(s): 1a1a419

Update app.py

Files changed (1):
  1. app.py +59 -70
app.py CHANGED
@@ -1,29 +1,23 @@
- import os
  import gradio as gr
- from prediction import run_image_prediction
+ from prediction import run_sequence_prediction
  import torch
  import torchvision.transforms as T
  from celle.utils import process_image
- from PIL import Image
- from matplotlib import pyplot as plt
  from celle_main import instantiate_from_config
  from omegaconf import OmegaConf

-
  class model:
      def __init__(self):
          self.model = None
          self.model_name = None

-     def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
+     def gradio_demo(self, model_name, sequence_input, image):
          device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
          if self.model_name != model_name:
              self.model_name = model_name
              model_ckpt_path = f"CELL-E_2-Image_Prediction/models/{model_name}.ckpt"
              model_config_path = f"CELL-E_2-Image_Prediction/models/{model_name}.yaml"
-
-
+
              # Load model config and set ckpt_path if not provided in config
              config = OmegaConf.load(model_config_path)
              if config["model"]["params"]["ckpt_path"] is None:
@@ -42,115 +36,110 @@ class model:
              self.model = torch.compile(self.model,mode='reduce-overhead')

              os.chdir(base_path)
-
-
+
+
          if "Finetuned" in model_name:
              dataset = "OpenCell"

          else:
              dataset = "HPA"
+
+
+         nucleus_image = image['image'].convert('L')
+         protein_image = image['mask'].convert('L')
+
+         to_tensor = T.ToTensor()
+         nucleus_tensor = to_tensor(nucleus_image)
+         protein_tensor = to_tensor(protein_image)
+         stacked_images = torch.stack([nucleus_tensor, protein_tensor], dim=0)
+         processed_images = process_image(stacked_images, dataset)
+
+         nucleus_image = processed_images[0].unsqueeze(0)
+         protein_image = processed_images[1].unsqueeze(0)
+         protein_image = protein_image > 0
+         protein_image = 1.0 * protein_image
+
+         print(f'{protein_image.sum()}')
+

-         nucleus_image = process_image(nucleus_image, dataset, "nucleus")
-         if protein_image:
-             protein_image = process_image(protein_image, dataset, "protein")
-             protein_image = protein_image > torch.median(protein_image)
-             protein_image = protein_image[0, 0]
-             protein_image = protein_image * 1.0
-         else:
-             protein_image = torch.ones((256, 256))
-
-         threshold, heatmap = run_image_prediction(
+         formatted_predicted_sequence = run_sequence_prediction(
              sequence_input=sequence_input,
              nucleus_image=nucleus_image,
-             model=self.model,
+             protein_image=protein_image,
+             model_ckpt_path=self.model,
              device=device,
          )
+
+         return T.ToPILImage()(protein_image), T.ToPILImage()(nucleus_image), formatted_predicted_sequence

-         # Plot the heatmap
-         plt.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
-         plt.axis("off")
-
-         # Save the plot to a temporary file
-         plt.savefig("temp.png", bbox_inches="tight", dpi=256)
-
-         # Open the temporary file as a PIL image
-         heatmap = Image.open("temp.png")
-
-         return (
-             T.ToPILImage()(nucleus_image[0, 0]),
-             T.ToPILImage()(protein_image),
-             T.ToPILImage()(threshold),
-             heatmap,
-         )
-
- base_class = model()

  with gr.Blocks(theme='gradio/soft') as demo:
      gr.Markdown("Select the prediction model.")
      gr.Markdown(
-         "CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF."
+         "- CELL-E_2_HPA_2560 is a good general purpose model for various cell types using ICC-IF."
      )
      gr.Markdown(
-         "CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good more live-cell predictions on HEK cells."
+         "- CELL-E_2_OpenCell_2560 is trained on OpenCell and is good more live-cell predictions on HEK cells."
      )
      with gr.Row():
          model_name = gr.Dropdown(
-             ["CELL-E_2_HPA_480", "CELL-E_2_HPA_Finetuned_480"],
-             value="CELL-E_2_HPA_480",
+             ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
+             value="CELL-E_2_HPA_2560",
              label="Model Name",
          )
      with gr.Row():
          gr.Markdown(
-             "Input the desired amino acid sequence. GFP is shown below by default."
+             "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include ```<mask>``` for a prediction to be run."
          )

      with gr.Row():
          sequence_input = gr.Textbox(
-             value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
+             value="M<mask><mask><mask><mask><mask>SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
              label="Sequence",
          )
      with gr.Row():
          gr.Markdown(
-             "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)"
+             "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucelus image."
          )
-         gr.Markdown("The protein image is optional and is just used for display.")

      with gr.Row().style(equal_height=True):
          nucleus_image = gr.Image(
-             type="pil",
-             label="Nucleus Image",
+             source="upload",
+             tool="sketch",
+             invert_colors=True,
+             label="Nucleus Image",
+             interactive=True,
              image_mode="L",
+             type="pil"
          )

-         protein_image = gr.Image(type="pil", label="Protein Image (Optional)")
-
-     with gr.Row():
-         gr.Markdown("Image predictions are show below.")

      with gr.Row().style(equal_height=True):
-         nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")
-
-         protein_threshold_image = gr.Image(
-             type="pil", label="Protein Threshold Image", image_mode="L"
+         nucleus_crop = gr.Image(
+             label="Nucleus Image (Crop)",
+             image_mode="L",
+             type="pil"
          )

-         predicted_threshold_image = gr.Image(
-             type="pil", label="Predicted Threshold image", image_mode="L"
+         mask = gr.Image(
+             label="Threshold Image",
+             image_mode="L",
+             type="pil"
          )
+     with gr.Row():
+         gr.Markdown("Sequence predictions are show below.")
+
+     with gr.Row().style(equal_height=True):
+         predicted_sequence = gr.Textbox(label='Predicted Sequence')
+

-         predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")
      with gr.Row():
          button = gr.Button("Run Model")

-     inputs = [model_name, sequence_input, nucleus_image, protein_image]
+     inputs = [model_name, sequence_input, nucleus_image]

-     outputs = [
-         nucleus_image_crop,
-         protein_threshold_image,
-         predicted_threshold_image,
-         predicted_heatmap,
-     ]
+     outputs = [mask, nucleus_crop, predicted_sequence]

-     button.click(base_class.gradio_demo, inputs, outputs)
+     button.click(gradio_demo, inputs, outputs)

- demo.launch(enable_queue=True)
+ demo.launch(enable_queue=True)
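
For reference, a minimal, self-contained sketch (not part of the commit) of the tensor plumbing the new gradio_demo performs on the sketch-tool output: the uploaded nucleus image and the user-drawn mask are converted to tensors, stacked, and the mask is binarized with a > 0 threshold. The blank PIL images are hypothetical stand-ins, and the calls to celle.utils.process_image and run_sequence_prediction shown in the diff above are intentionally omitted.

import torch
import torchvision.transforms as T
from PIL import Image

to_tensor = T.ToTensor()

# Stand-ins for image["image"] and image["mask"] returned by gr.Image(tool="sketch").
nucleus_pil = Image.new("L", (256, 256), color=32)
mask_pil = Image.new("L", (256, 256), color=0)

nucleus_tensor = to_tensor(nucleus_pil)                                # shape (1, 256, 256)
protein_tensor = to_tensor(mask_pil)                                   # shape (1, 256, 256)
stacked_images = torch.stack([nucleus_tensor, protein_tensor], dim=0)  # shape (2, 1, 256, 256)

# Binarize the drawn mask the same way the diff does: any nonzero stroke pixel
# becomes 1.0, everything else stays 0.0.
protein_binary = 1.0 * (stacked_images[1] > 0)

print(stacked_images.shape, protein_binary.sum())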