saicharan2804 committed on
Commit
3c59b49
·
1 Parent(s): 614c0d4

Adding synthetic_complexity_score

Browse files
Files changed (1) hide show
  1. molgenevalmetric.py +129 -1
molgenevalmetric.py CHANGED
@@ -38,7 +38,135 @@ import pandas as pd
38
  from fcd_torch import FCD
39
  # from syba.syba import SybaClassifier
40
 
41
- from SCScore import SCScorer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  def get_mol(smiles_or_mol):
 
38
  from fcd_torch import FCD
39
  # from syba.syba import SybaClassifier
40
 
41
+ # from SCScore import SCScorer
42
+
43
+ import math, sys, random, os
44
+ import numpy as np
45
+ import time
46
+ import rdkit.Chem as Chem
47
+ import rdkit.Chem.AllChem as AllChem
48
+ import json
49
+ import gzip
50
+ import six
51
+
52
# Default upper bound of the reported score: SCScorer.apply maps the network
# output through a sigmoid into the range [1, score_scale]. Also used as the
# default value of SCScorer.__init__'s ``score_scale`` argument.
score_scale = 5.0
# Not referenced by any code visible in this file.
# NOTE(review): presumably a training-time margin carried over from the
# original SCScore code -- confirm before removing.
min_separation = 0.25

# Default fingerprint length (number of entries in the folded vector) used by
# SCScorer.restore.
FP_len = 1024
# Default Morgan fingerprint radius used by SCScorer.restore.
FP_rad = 2
57
+
58
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of a single value.

    The computation is numerically stable: large negative inputs return a
    value close to 0 instead of raising ``OverflowError`` from ``math.exp``.

    Args:
        x: A Python number, NumPy scalar, or NumPy array holding exactly one
            element (for example the shape ``(1,)`` output of a network).

    Returns:
        float: The sigmoid of ``x``, in the closed interval [0, 1].

    Raises:
        TypeError: If ``x`` holds more than one element.
    """
    value = np.asarray(x)
    if value.size != 1:
        raise TypeError('only size-1 arrays can be converted to Python scalars')
    # Extract a plain Python scalar explicitly; implicit conversion of an
    # array with ndim > 0 inside math.exp is deprecated by NumPy.
    value = value.item()

    if value >= 0:
        # exp(-value) <= 1 here, so it can underflow to 0 but never overflow.
        return 1 / (1 + math.exp(-value))
    # For negative inputs use the algebraically equivalent form whose
    # exponent is negative, which again cannot overflow.
    exp_value = math.exp(value)
    return exp_value / (1 + exp_value)
60
+
61
class SCScorer():
    """Standalone NumPy scorer that rates molecules from their SMILES string.

    A molecule is converted to a Morgan fingerprint, passed through a small
    feed-forward network (ReLU hidden layers) whose weights are loaded from
    disk, and the network output is squashed into ``[1, score_scale]``.

    Call :meth:`restore` before scoring; it loads the weights and selects the
    fingerprint function.
    """

    def __init__(self, score_scale=score_scale):
        """Create an unrestored scorer.

        Args:
            score_scale: Upper bound of the scores produced by :meth:`apply`.
        """
        # Flat list alternating weight matrices and bias vectors.
        self.vars = []
        self.score_scale = score_scale
        self._restored = False

    def restore(self, weight_path=os.path.join('model.ckpt-10654.as_numpy.json.gz'), FP_rad=FP_rad, FP_len=FP_len):
        """Load model weights and configure fingerprint generation.

        Args:
            weight_path: Path to the weights file (``*pickle`` or
                ``*json.gz``). If the path contains ``'uint8'`` or
                ``'counts'`` a folded count fingerprint is used, otherwise a
                bit-vector fingerprint.
            FP_rad: Morgan fingerprint radius.
            FP_len: Length of the fingerprint vector.

        Returns:
            SCScorer: ``self``, so calls can be chained.

        Raises:
            ValueError: If ``weight_path`` has an unsupported extension.
        """
        self.FP_len = FP_len
        self.FP_rad = FP_rad
        self._load_vars(weight_path)

        # NOTE: the selected function is stored as a plain instance attribute,
        # so it is not bound; callers must pass the scorer explicitly, i.e.
        # ``scorer.mol_to_fp(scorer, mol)``.
        if 'uint8' in weight_path or 'counts' in weight_path:
            def mol_to_fp(self, mol):
                if mol is None:
                    # Unparsable molecule: all-zero fingerprint of full length.
                    return np.zeros((self.FP_len,), dtype=np.uint8)
                fp = AllChem.GetMorganFingerprint(mol, self.FP_rad, useChirality=True)  # UIntSparseIntVect
                fp_folded = np.zeros((self.FP_len,), dtype=np.uint8)
                # Fold the sparse counts into a fixed-length vector.
                for k, v in fp.GetNonzeroElements().items():
                    fp_folded[k % self.FP_len] += v
                return np.array(fp_folded)
        else:
            def mol_to_fp(self, mol):
                if mol is None:
                    return np.zeros((self.FP_len,), dtype=np.float32)
                return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, self.FP_rad, nBits=self.FP_len,
                                                                      useChirality=True), dtype=np.bool_)
        self.mol_to_fp = mol_to_fp

        self._restored = True
        return self

    def smi_to_fp(self, smi):
        """Return the fingerprint of a SMILES string.

        Args:
            smi: SMILES string; an empty value yields an all-zero vector.

        Returns:
            numpy.ndarray: Fingerprint vector of length ``self.FP_len``.
        """
        if not smi:
            return np.zeros((self.FP_len,), dtype=np.float32)
        return self.mol_to_fp(self, Chem.MolFromSmiles(smi))

    def apply(self, x):
        """Run the network on one fingerprint and return its score.

        Args:
            x: Fingerprint vector for a single molecule.

        Returns:
            The score ``1 + (self.score_scale - 1) * sigmoid(output)``.

        Raises:
            ValueError: If :meth:`restore` has not been called.
        """
        if not self._restored:
            raise ValueError('Must restore model weights!')
        # Each pair of vars is a weight and bias term
        n_vars = len(self.vars)
        for i in range(0, n_vars, 2):
            W = self.vars[i]
            b = self.vars[i + 1]
            x = np.matmul(x, W) + b
            if i != n_vars - 2:  # no activation after the last layer
                x = x * (x > 0)  # ReLU
        # Hand sigmoid a plain scalar when the output is a single value, so
        # no implicit (deprecated) array-to-scalar conversion takes place.
        x = np.asarray(x)
        if x.size == 1:
            x = x.item()
        # Use the instance's scale so the constructor argument is honoured.
        return 1 + (self.score_scale - 1) * sigmoid(x)

    def get_score_from_smi(self, smi='', v=False):
        """Score a single SMILES string.

        Args:
            smi: SMILES string to score.
            v: If true, print diagnostic messages.

        Returns:
            tuple: ``(smiles, score)``. ``smiles`` is the kekulized isomeric
            SMILES written by RDKit, or ``''`` when the input is empty or
            cannot be parsed. ``score`` is ``0.`` when no non-zero
            fingerprint could be computed.
        """
        if not smi:
            return ('', 0.)
        # Parse once and reuse the molecule for both the fingerprint and the
        # output SMILES.
        mol = Chem.MolFromSmiles(smi)
        fp = np.array(self.mol_to_fp(self, mol), dtype=np.float32)
        if not fp.any():
            if v:
                print('Could not get fingerprint?')
            cur_score = 0.
        else:
            # Run
            cur_score = self.apply(fp)
            if v:
                print('Score: {}'.format(cur_score))
        if mol:
            smi = Chem.MolToSmiles(mol, isomericSmiles=True, kekuleSmiles=True)
        else:
            smi = ''
        return (smi, cur_score)

    def get_avg_score(self, smis):
        """
        Compute the average score for a list of SMILES strings.

        Entries whose score is not positive (empty or unscorable SMILES) are
        left out of the average.

        Args:
            smis (list of str): A list of SMILES strings.

        Returns:
            float: The average score of the scorable SMILES strings, or 0.0
            when there are none.
        """
        if not smis:  # Check if the list is empty
            return 0.0

        total_score = 0.0
        valid_smiles_count = 0

        for smi in smis:
            _, score = self.get_score_from_smi(smi)
            if score > 0:  # Only positive scores are treated as valid
                total_score += score
                valid_smiles_count += 1

        # Avoid division by zero
        if valid_smiles_count == 0:
            return 0.0
        return total_score / valid_smiles_count

    def _load_vars(self, weight_path):
        """Load the network variables from ``weight_path`` into ``self.vars``.

        Args:
            weight_path: Path ending in ``pickle`` or ``json.gz``.

        Raises:
            ValueError: If the path has neither supported extension.
        """
        if weight_path.endswith('pickle'):
            import pickle
            # SECURITY: unpickling executes arbitrary code from the file;
            # only load weight files from a trusted source.
            with open(weight_path, 'rb') as fid:
                self.vars = pickle.load(fid)
            self.vars = [x.tolist() for x in self.vars]
        elif weight_path.endswith('json.gz'):
            with gzip.GzipFile(weight_path, 'r') as fin:  # gzip -> UTF-8 bytes
                json_str = fin.read().decode('utf-8')     # bytes -> JSON text
            self.vars = [np.array(x) for x in json.loads(json_str)]
        else:
            # Fail loudly instead of leaving an empty, unusable model behind.
            raise ValueError(
                'Unsupported weight file (expected *pickle or *json.gz): {}'.format(weight_path))
169
+
170
 
171
 
172
  def get_mol(smiles_or_mol):