jannisborn commited on
Commit
f1e36b5
·
unverified ·
1 Parent(s): 75a6229

feat: moler updates

Browse files
Files changed (5) hide show
  1. app.py +15 -4
  2. model_cards/article.md +8 -1
  3. model_cards/examples.csv +7 -5
  4. requirements.txt +1 -1
  5. utils.py +4 -2
app.py CHANGED
@@ -17,7 +17,9 @@ TITLE = "MoLeR"
17
  def run_inference(
18
  algorithm_version: str,
19
  scaffolds: str,
 
20
  beam_size: int,
 
21
  number_of_samples: int,
22
  seed: int,
23
  ):
@@ -25,15 +27,18 @@ def run_inference(
25
  algorithm_version=algorithm_version,
26
  scaffolds=scaffolds,
27
  beam_size=beam_size,
28
- num_samples=4,
29
  seed=seed,
30
  num_workers=1,
 
 
31
  )
32
  model = MoLeR(configuration=config)
33
  samples = list(model.sample(number_of_samples))
34
 
35
- seed_mols = [] if scaffolds == "" else scaffolds.split(".")
36
- return draw_grid_generate(seed_mols, samples)
 
37
 
38
 
39
  if __name__ == "__main__":
@@ -67,7 +72,13 @@ if __name__ == "__main__":
67
  placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
68
  lines=1,
69
  ),
70
- gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beam_size"),
 
 
 
 
 
 
71
  gr.Slider(
72
  minimum=1, maximum=50, value=10, label="Number of samples", step=1
73
  ),
 
17
  def run_inference(
18
  algorithm_version: str,
19
  scaffolds: str,
20
+ seed_smiles: str,
21
  beam_size: int,
22
+ sigma: float,
23
  number_of_samples: int,
24
  seed: int,
25
  ):
 
27
  algorithm_version=algorithm_version,
28
  scaffolds=scaffolds,
29
  beam_size=beam_size,
30
+ num_samples=32,
31
  seed=seed,
32
  num_workers=1,
33
+ seed_smiles=seed_smiles,
34
+ sigma=sigma,
35
  )
36
  model = MoLeR(configuration=config)
37
  samples = list(model.sample(number_of_samples))
38
 
39
+ scaffold_list = [] if scaffolds == "" else scaffolds.split(".")
40
+ seed_list = [] if seed_smiles == "" else seed_smiles.split(".")
41
+ return draw_grid_generate(seed_list, scaffold_list, samples)
42
 
43
 
44
  if __name__ == "__main__":
 
72
  placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
73
  lines=1,
74
  ),
75
+ gr.Textbox(
76
+ label="Seed SMILES",
77
+ placeholder="O=C1C2=CC=C(C3=CC=CC=C3)C=C=C2OC2=CC=CC=C12",
78
+ lines=1,
79
+ ),
80
+ gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beams"),
81
+ gr.Slider(minimum=0.0, maximum=3.0, value=0.01, label="Sigma"),
82
  gr.Slider(
83
  minimum=1, maximum=50, value=10, label="Number of samples", step=1
84
  ),
model_cards/article.md CHANGED
@@ -2,12 +2,19 @@
2
 
3
  **Algorithm Version**: Which model checkpoint to use (trained on different datasets).
4
 
5
- **Scaffolds**: One or multiple scaffolds (or seed molecules), provided as '.'-separated SMILES. If empty, no scaffolds are used.
 
 
 
 
 
6
 
7
  **Number of samples**: How many samples should be generated (between 1 and 50).
8
 
9
  **Beam size**: Beam size used in beam search decoding (the higher the slower but better).
10
 
 
 
11
  **Seed**: The random seed used for initialization.
12
 
13
 
 
2
 
3
  **Algorithm Version**: Which model checkpoint to use (trained on different datasets).
4
 
5
+ **Scaffolds**: One or multiple scaffolds, provided as '.'-separated SMILES. If empty, no scaffolds are used. Note that this is a hard-constraint,
6
+ i.e., the scaffold will certainly be present in the generated molecule. If multiple scaffolds are given, they are paired with the seed SMILES
7
+ (if applicable) and every molecule will be guaranteed to contain exactly one scaffold.
8
+
9
+ **Seed SMILES**: One or multiple seed molecules, provided as '.'-separated SMILES. If empty, no scaffolds are used.
10
+ There's no guarantee for a seed SMILES (or a substructure of it) to be present in the generated molecule as it's merely used for decoder initialization.
11
 
12
  **Number of samples**: How many samples should be generated (between 1 and 50).
13
 
14
  **Beam size**: Beam size used in beam search decoding (the higher the slower but better).
15
 
16
+ **Sigma**: Variance of the Gaussian noise that is added to the latent code (before passing to the decoder).
17
+
18
  **Seed**: The random seed used for initialization.
19
 
20
 
model_cards/examples.csv CHANGED
@@ -1,5 +1,7 @@
1
- v0,,1,4,0
2
- v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,1,10,0
3
- v0,C12C=CC=NN1C(C#CC1=C(C)C=CC3C(NC4=CC(C(F)(F)F)=CC=C4)=NOC1=3)=CN=2.CCO,3,5,5
4
-
5
-
 
 
 
1
+ v0,,,1,0.0,4,0
2
+ v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,,1,0.0,10,1
3
+ v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,,1,0.3,10,2
4
+ v0,,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,1,0.2,10,3
5
+ v0,,C12C=CC=NN1C(C#CC1=C(C)C=CC3C(NC4=CC(C(F)(F)F)=CC=C4)=NOC1=3)=CN=2.CCO,3,0.2,5,5
6
+ v0,,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,1,0.5,10,9
7
+ v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,c1ccccc1,1,0.2,10,10
requirements.txt CHANGED
@@ -8,7 +8,7 @@ torch-sparse
8
  torch-geometric
9
  torchvision==0.13.1
10
  torchaudio==0.12.1
11
- gt4sd>=1.0.0
12
  molgx>=0.22.0a1
13
  diffusers==0.6.0
14
  molecule_generation
 
8
  torch-geometric
9
  torchvision==0.13.1
10
  torchaudio==0.12.1
11
+ gt4sd>=1.1.12
12
  molgx>=0.22.0a1
13
  diffusers==0.6.0
14
  molecule_generation
utils.py CHANGED
@@ -15,8 +15,9 @@ logger.addHandler(logging.NullHandler())
15
 
16
  def draw_grid_generate(
17
  seeds: List[str],
 
18
  samples: List[str],
19
- n_cols: int = 3,
20
  size=(140, 200),
21
  ) -> str:
22
  """
@@ -34,8 +35,9 @@ def draw_grid_generate(
34
  result = defaultdict(list)
35
  result.update(
36
  {
37
- "SMILES": seeds + samples,
38
  "Name": [f"Seed_{i}" for i in range(len(seeds))]
 
39
  + [f"Generated_{i}" for i in range(len(samples))],
40
  },
41
  )
 
15
 
16
  def draw_grid_generate(
17
  seeds: List[str],
18
+ scaffolds: List[str],
19
  samples: List[str],
20
+ n_cols: int = 5,
21
  size=(140, 200),
22
  ) -> str:
23
  """
 
35
  result = defaultdict(list)
36
  result.update(
37
  {
38
+ "SMILES": seeds + scaffolds + samples,
39
  "Name": [f"Seed_{i}" for i in range(len(seeds))]
40
+ + [f"Scaffold_{i}" for i in range(len(scaffolds))]
41
  + [f"Generated_{i}" for i in range(len(samples))],
42
  },
43
  )