Emaad commited on
Commit
ab753bf
1 Parent(s): a888fd4

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +10 -36
prediction.py CHANGED
@@ -1,16 +1,11 @@
1
  import os
2
- os.chdir('..')
3
- base_dir = os.getcwd()
4
  from dataloader import CellLoader
5
- from celle_main import instantiate_from_config
6
- from omegaconf import OmegaConf
7
 
8
- def run_sequence_prediction(
 
9
  sequence_input,
10
  nucleus_image,
11
- protein_image,
12
- model_ckpt_path,
13
- model_config_path,
14
  device
15
  ):
16
  """
@@ -22,7 +17,6 @@ def run_sequence_prediction(
22
  :param model_ckpt_path: Path to model checkpoint
23
  :param model_config_path: Path to model config
24
  """
25
-
26
  # Instantiate dataset object
27
  dataset = CellLoader(
28
  sequence_mode="embedding",
@@ -36,40 +30,20 @@ def run_sequence_prediction(
36
  threshold="median",
37
  )
38
 
39
- # Check if sequence is provided and valid
40
- if len(sequence_input) == 0:
41
- raise ValueError("Sequence must be provided.")
42
-
43
- if "<mask>" not in sequence_input:
44
- print("Warning: Sequence does not contain any masked positions to predict.")
45
-
46
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
47
  sequence = dataset.tokenize_sequence(sequence_input)
48
 
49
- # Load model config and set ckpt_path if not provided in config
50
- config = OmegaConf.load(model_config_path)
51
- if config["model"]["params"]["ckpt_path"] is None:
52
- config["model"]["params"]["ckpt_path"] = model_ckpt_path
53
-
54
- # Set condition_model_path and vqgan_model_path to None
55
- config["model"]["params"]["condition_model_path"] = None
56
- config["model"]["params"]["vqgan_model_path"] = None
57
-
58
- os.chdir(os.path.dirname(model_ckpt_path))
59
-
60
- # Instantiate model from config and move to device
61
- model = instantiate_from_config(config.model).to(device)
62
-
63
  # Sample from model using provided sequence and nucleus image
64
- _, predicted_sequence, _ = model.celle.sample_text(
65
  text=sequence.to(device),
66
  condition=nucleus_image.to(device),
67
- image=protein_image.to(device),
68
- force_aas=True,
69
  temperature=1,
70
  progress=False,
71
  )
72
-
73
- os.chdir(base_dir)
74
 
75
- return predicted_sequence
 
 
 
 
 
1
  import os
 
 
2
  from dataloader import CellLoader
 
 
3
 
4
+
5
+ def run_image_prediction(
6
  sequence_input,
7
  nucleus_image,
8
+ model,
 
 
9
  device
10
  ):
11
  """
 
17
  :param model_ckpt_path: Path to model checkpoint
18
  :param model_config_path: Path to model config
19
  """
 
20
  # Instantiate dataset object
21
  dataset = CellLoader(
22
  sequence_mode="embedding",
 
30
  threshold="median",
31
  )
32
 
 
 
 
 
 
 
 
33
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
34
  sequence = dataset.tokenize_sequence(sequence_input)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Sample from model using provided sequence and nucleus image
37
+ _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
38
  text=sequence.to(device),
39
  condition=nucleus_image.to(device),
40
+ timesteps=1,
 
41
  temperature=1,
42
  progress=False,
43
  )
 
 
44
 
45
+ # Move predicted_threshold and predicted_heatmap to CPU and select first element of batch
46
+ predicted_threshold = predicted_threshold.cpu()[0, 0]
47
+ predicted_heatmap = predicted_heatmap.cpu()[0, 0]
48
+
49
+ return predicted_threshold, predicted_heatmap