doammii committed on
Commit
55d9b0c
·
verified ·
1 Parent(s): f9b9173

Add LlaMol codes

Browse files
.gitignore ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.out
2
+ debug
3
+ debug-gpu
4
+ outputs
5
+ chemiscope_gen.json
6
+ gen_smiles.txt
7
+ __pycache__
8
+ *.png
9
+ *.csv
10
+ *.json
11
+ # Byte-compiled / optimized / DLL files
12
+ __pycache__/
13
+ *.py[cod]
14
+ *$py.class
15
+ data/opv/download
16
+ data/opv/opv.parquet
17
+ data/qm9_zinc250k_cep/zinc_properties.csv
18
+ data/qm9_zinc250k_cep/qm9_zinc250k_cep.parquet
19
+ data/zinc/zinc_complete/*/*.txt
20
+ !data/zinc/zinc_complete/download_zinc.sh
21
+ !data/zinc/zinc_complete/run_download.py
22
+ data/zinc/zinc_processed
23
+ data/zinc/zinc_processed.parquet
24
+ data/zinc/zinc_full.parquet
25
+ data/OrganiX13.parquet
26
+ .cache
27
+ out/plots
28
+ # C extensions
29
+ *.so
30
+
31
+ # Distribution / packaging
32
+ .Python
33
+ build/
34
+ develop-eggs/
35
+ dist/
36
+ downloads/
37
+ eggs/
38
+ .eggs/
39
+ lib/
40
+ lib64/
41
+ parts/
42
+ sdist/
43
+ var/
44
+ wheels/
45
+ share/python-wheels/
46
+ *.egg-info/
47
+ .installed.cfg
48
+ *.egg
49
+ MANIFEST
50
+
51
+ # PyInstaller
52
+ # Usually these files are written by a python script from a template
53
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
54
+ *.manifest
55
+ *.spec
56
+
57
+ # Installer logs
58
+ pip-log.txt
59
+ pip-delete-this-directory.txt
60
+
61
+ # Unit test / coverage reports
62
+ htmlcov/
63
+ .tox/
64
+ .nox/
65
+ .coverage
66
+ .coverage.*
67
+ .cache
68
+ nosetests.xml
69
+ coverage.xml
70
+ *.cover
71
+ *.py,cover
72
+ .hypothesis/
73
+ .pytest_cache/
74
+ cover/
75
+
76
+ # Translations
77
+ *.mo
78
+ *.pot
79
+
80
+ # Django stuff:
81
+ *.log
82
+ local_settings.py
83
+ db.sqlite3
84
+ db.sqlite3-journal
85
+
86
+ # Flask stuff:
87
+ instance/
88
+ .webassets-cache
89
+
90
+ # Scrapy stuff:
91
+ .scrapy
92
+
93
+ # Sphinx documentation
94
+ docs/_build/
95
+
96
+ # PyBuilder
97
+ .pybuilder/
98
+ target/
99
+
100
+ # Jupyter Notebook
101
+ .ipynb_checkpoints
102
+
103
+ # IPython
104
+ profile_default/
105
+ ipython_config.py
106
+
107
+ # pyenv
108
+ # For a library or package, you might want to ignore these files since the code is
109
+ # intended to run in multiple environments; otherwise, check them in:
110
+ # .python-version
111
+
112
+ # pipenv
113
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
114
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
115
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
116
+ # install all needed dependencies.
117
+ #Pipfile.lock
118
+
119
+ # poetry
120
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
121
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
122
+ # commonly ignored for libraries.
123
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
124
+ #poetry.lock
125
+
126
+ # pdm
127
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
128
+ #pdm.lock
129
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
130
+ # in version control.
131
+ # https://pdm.fming.dev/#use-with-ide
132
+ .pdm.toml
133
+
134
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
135
+ __pypackages__/
136
+
137
+ # Celery stuff
138
+ celerybeat-schedule
139
+ celerybeat.pid
140
+
141
+ # SageMath parsed files
142
+ *.sage.py
143
+
144
+ # Environments
145
+ .env
146
+ .venv
147
+ env/
148
+ venv/
149
+ ENV/
150
+ env.bak/
151
+ venv.bak/
152
+
153
+ # Spyder project settings
154
+ .spyderproject
155
+ .spyproject
156
+
157
+ # Rope project settings
158
+ .ropeproject
159
+
160
+ # mkdocs documentation
161
+ /site
162
+
163
+ # mypy
164
+ .mypy_cache/
165
+ .dmypy.json
166
+ dmypy.json
167
+
168
+ # Pyre type checker
169
+ .pyre/
170
+
171
+ # pytype static type analyzer
172
+ .pytype/
173
+
174
+ # Cython debug symbols
175
+ cython_debug/
176
+
177
+ # PyCharm
178
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
179
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
180
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
181
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
182
+ #.idea/
183
+ !assets/*.png
LICENSE ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,5 +1,162 @@
1
- ---
2
- license: other
3
- license_name: attribution-noncommercial-share-alike4.0international
4
- license_link: https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Llamol
2
+
3
+ <p align="center">
4
+ <img src="assets/llamol.png" width="300" height="300" alt="LLamol">
5
+ </p>
6
+
7
+ This is the official repository for the paper ["LLamol: A Dynamic Multi-Conditional Generative Transformer for De Novo Molecular Design"](https://arxiv.org/abs/2311.14407).
8
+ In this repository are the weights for LLamol (`out/llama2-M-Full-RSS.pt`) and the dataset OrganiX13.
9
+
10
+ Image made with [Hotpot.ai](https://hotpot.ai/art-generator)
11
+ ## Installation
12
+ Install using Mamba to be fast: https://mamba.readthedocs.io/en/latest/micromamba-installation.html
13
+
14
+
15
+ ```bash
16
+ $ "${SHELL}" <(curl -L micro.mamba.pm/install.sh)
17
+ $ micromamba env create -f torch2-env.yaml
18
+ $ micromamba activate torch2-llamol
19
+ $ python sample.py
20
+ ```
21
+ # Download and preprocess the OrganiX13 dataset:
22
+ If you want to train with the full 13 Million dataset do the following steps. These are *not* necessary if you just want to use the model for inference:
23
+ 1. Download and preprocess the OPV dataset by running `/data/opv/prepare_opv.py`
24
+ 2. Download and preprocess the ZINC dataset by running `/data/zinc/zinc_complete/run_download.py` followed by `/data/zinc/convert_to_parquet.py`
25
+ (we recommend at least 16GB RAM for this)
26
+ 3. Download and preprocess the QM9, ZINC250k and CEP datasets by running `/data/qm9_zinc250k_cep/convert_to_parquet.py`
27
+
28
+ 4. Run `data/combine_all.py` to combine the dataset to `data/OrganiX13.parquet` (this can take a while, especially on the zinc dataset. In total it took ~2 hours when using my Laptop, which has 16 GB ram and an Intel i7 10th Gen)
29
+ 5. Run `preprocess_dataset.py` which should create the file `.cache/processed_dataset_None.pkl`
30
+
31
+ Now you can use that in the training of the model by specifying the file under the `processed_dataset_ckpt` field of the training .yaml files.
32
+
33
+
34
+
35
+ # Interactive Demo
36
+
37
+ After installation you can play around with the model using the `demonstrator.ipynb` file. Just run all and scroll down to the last cell.
38
+ After a short time there should be a UI where you can play around with the model.
39
+
40
+
41
+ ## Training
42
+
43
+ First the env needs to be activated so:
44
+ ```bash
45
+ $ conda activate torch2-llamol # When installed with conda instead of micromamba
46
+ OR
47
+ $ micromamba activate torch2-llamol
48
+ ```
49
+
50
+ To train locally you can run:
51
+ ```bash
52
+ # To set the config that you want to train with
53
+ $ python train.py train=llama2-M-Full-RSS
54
+ ```
55
+
56
+ Parameters can also be overridden on the command line, for example:
57
+ ```bash
58
+ $ python train.py train=llama2-M-Full-RSS train.model.dim=1024
59
+ ```
60
+ For more information look at [Hydra](https://hydra.cc/docs/1.3/intro/)
61
+
62
+ To start a job on a SLURM cluster use the following script:
63
+ ```bash
64
+ $ sbatch trainLLamaMol.sh
65
+ ```
66
+
67
+ ## Training Multi-GPU on 1 Node with multiple GPUS (nproc_per_node)
68
+ ```bash
69
+ torchrun --standalone --max_restarts=3 --nnodes=1 --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint="localhost:12345" train.py train=llama2-M-Full-RSS > "train_runs/run_MultiGPU.out"
70
+ ```
71
+ ## Training Multi-GPU on 1 Node with multiple GPUS on a Cluster
72
+ Currently there is only one script to train with DDP. To change the number of GPUS in that script you have to change the bash script itself.
73
+ TODO: Make it more dynamic, with allowing console commands to change the number of GPUS etc.
74
+ ```bash
75
+ sbatch trainLLamaMolDDPSingleNode.sh
76
+ ```
77
+
78
+ ## Sampling
79
+ Sampling can be changed by the OPTIONAL parameters as shown below.
80
+ ```bash
81
+ $ python sample.py --help
82
+
83
+ $ python sample.py --num_samples 2000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --seed 4312 --context_cols logp sascore mol_weight --temperature 0.8
84
+ ```
85
+
86
+
87
+ ## Using own dataset
88
+
89
+ Use the `preprocess_dataset.py` file to tokenize the dataset. The dataset should be either in the parquet or csv format.
90
+ The SMILES used for training should be in the `smiles` column in the dataset. All conditions should be given to the pretokenize function.
91
+ After the preprocessing is done a file should be stored in the .cache directory with the name `processed_dataset_{limit}.pkl`.
92
+ You could also rename this file to not overwrite it every time you run the preprocessing.
93
+
94
+ The `.cache/processed_dataset_{limit}.pkl` file can then be set in the `processed_dataset_ckpt` field of the `config/train/llama2-M-Full-RSS.yaml` file to train with the new dataset.
95
+
96
+ # Training methods
97
+
98
+ The training method we used and described in the paper is called RSS here, for "Random Smiles Sampling". It is the method described in the paper under "Stochastic Context Learning": during training, a random subsequence of the current SMILES is taken and fed into the model as a token sequence condition. The model we used in the paper is therefore `out/llama2-M-Full-RSS.pt`.
99
+
100
+ We also tried other approaches for including the token sequence.
101
+ One was using murcko scaffolds as they were used in the MolGPT paper, but this approach did not yield great results for our purposes.
102
+ The other was using BRICS decomposition (named `bricks` in the configs), which also did not yield very good results.
103
+
104
+ The different methods are implemented in the `fragment_creator.py` file.
105
+ Each of the models was trained with its respective configuration in the `config/train` folder.
106
+
107
+ # Thanks
108
+
109
+
110
+ - [Karpathy](https://github.com/karpathy/llama2.c) for the implementation of the Llama 2 architecture and training code
111
+
112
+ - [DeepChem](https://github.com/deepchem/deepchem) for the SmilesTokenizer
113
+
114
+ - [TorchDrug](https://github.com/DeepGraphLearning/torchdrug/) for the download scripts for the OPV and CEP datasets
115
+
116
+ - Zinc 15 dataset (Teague Sterling and John J. Irwin. ZINC 15 – ligand discovery for everyone. Journal of Chemical Information
117
+ and Modeling, 55(11):2324–2337, November 2015.)
118
+
119
+ - QM9 dataset (
120
+ Raghunathan Ramakrishnan, Pavlo O. Dral, Matthias Rupp, and O. Anatole von Lilienfeld. Quantum chemistry
121
+ structures and properties of 134 kilo molecules. Scientific Data, 1(1), aug 2014.)
122
+
123
+ - PC9 dataset (Marta Glavatskikh, Jules Leguy, Gilles Hunault, Thomas Cauchy, and Benoit Da Mota. Dataset’s chemical
124
+ diversity limits the generalizability of machine learning predictions. Journal of Cheminformatics, 11(1), nov 2019)
125
+
126
+ - ZINC 250k (Rafael Gómez-Bombarelli, Jennifer N. Wei, David Duvenaud, José Miguel Hernández-Lobato, Benjamín
127
+ Sánchez-Lengeling, Dennis Sheberla, Jorge Aguilera-Iparraguirre, Timothy D. Hirzel, Ryan P. Adams, and Alán
128
+ Aspuru-Guzik. Automatic chemical design using a data-driven continuous representation of molecules. ACS
129
+ Central Science, 4(2):268–276, jan 2018.)
130
+
131
+ - RedDB (Elif Sorkun, Qi Zhang, Abhishek Khetan, Murat Cihan Sorkun, and Süleyman Er. RedDB, a computational
132
+ database of electroactive molecules for aqueous redox flow batteries. Scientific Data, 9(1), nov 2022.)
133
+
134
+ - OPV (Peter C. St. John, Caleb Phillips, Travis W. Kemper, A. Nolan Wilson, Yanfei Guan, Michael F. Crowley, Mark R.
135
+ Nimlos, and Ross E. Larsen. Message-passing neural networks for high-throughput polymer screening. The
136
+ Journal of Chemical Physics, 150(23):234111, jun 2019.)
137
+
138
+ - PubchemQC 2020 (Maho Nakata, Tomomi Shimazaki, Masatomo Hashimoto, and Toshiyuki Maeda. PubChemQC PM6: Data sets
139
+ of 221 million molecules with optimized molecular geometries and electronic properties. Journal of Chemical
140
+ Information and Modeling, 60(12):5891–5899, oct 2020.)
141
+
142
+ - PubchemQC 2017 (Maho Nakata and Tomomi Shimazaki. PubChemQC project: A large-scale first-principles electronic structure
143
+ database for data-driven chemistry. Journal of Chemical Information and Modeling, 57(6):1300–1308, may 2017.)
144
+
145
+ - CEP (Johannes Hachmann, Roberto Olivares-Amaya, Sule Atahan-Evrenk, Carlos Amador-Bedolla, Roel S. Sánchez-
146
+ Carrera, Aryeh Gold-Parker, Leslie Vogt, Anna M. Brockway, and Alán Aspuru-Guzik. The Harvard clean energy
147
+ project: Large-scale computational screening and design of organic photovoltaics on the world community grid.
148
+ The Journal of Physical Chemistry Letters, 2(17):2241–2251, aug 2011.) subset ( David Duvenaud, Dougal Maclaurin, Jorge Aguilera-Iparraguirre, Rafael Gómez-Bombarelli, Timothy Hirzel,
149
+ Alán Aspuru-Guzik, and Ryan P. Adams. Convolutional networks on graphs for learning molecular fingerprints,
150
+ 2015.)
151
+ - ChEMBL (James Blackshaw, Anna Gaulton, A. Patrícia Bento, Marleen De Veij, David Mendez Lopez, Nicolas Bosc, Juan
152
+ Felipe Mosquera Morales, María Paula Margariños, Andrew Leach, Emma Manners, Barbara Zdrazil, Harris
153
+ Ioannidis, Fiona Hunter, Eloy Félix, and Ricardo Arcila Toro. CHEMBL database release 31, September 2009.)
154
+
155
+ # Funding disclaimer
156
+
157
+ This project has received funding from the European Union’s Horizon 2020 research and innovation programme under Grant Agreement no. 875489.
158
+
159
+ This website reflects only the author’s view. The funding agency is not responsible for any use made of the information it contains.
160
+
161
+ # License
162
+ <p xmlns:cc="http://creativecommons.org/ns#" xmlns:dct="http://purl.org/dc/terms/"><span property="dct:title">LLamol is licensed under <a href="http://creativecommons.org/licenses/by-nc-sa/4.0/?ref=chooser-v1" target="_blank" rel="license noopener noreferrer" style="display:inline-block;">CC BY-NC-SA 4.0<img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/cc.svg?ref=chooser-v1"><img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/by.svg?ref=chooser-v1"><img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/nc.svg?ref=chooser-v1"><img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/sa.svg?ref=chooser-v1"></a></p>
assets/llamol.png ADDED
config/config.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ defaults:
2
+ - train: "llama2-Debug"
config/train/llama2-Debug.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ io:
2
+ # I/O
3
+ out_dir : "debug"
4
+ eval_interval : 10
5
+ log_interval : 10
6
+ eval_iters : 5
7
+ eval_only : false # if True, script exits right after the first eval
8
+ always_save_checkpoint : true # if True, always save a checkpoint after each eval
9
+ init_from : "scratch" # 'scratch' or 'resume'
10
+ resume_when_snapshot_available: false
11
+
12
+ loader:
13
+ batch_size : 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
14
+ max_seq_len : 768
15
+ dataset : "smiles"
16
+ processed_dataset_ckpt : "processed_dataset_500000.pkl"
17
+ fragment_creator : "rss"
18
+
19
+ model:
20
+ dim : 32
21
+ n_layers : 1
22
+ n_heads : 1
23
+ multiple_of : 16
24
+ dropout : 0.1
25
+
26
+ context:
27
+ context_keys: ["logp", "sascore", "mol_weight"]
28
+ context_dims : [1,1,1]
29
+
30
+ optimizer:
31
+ gradient_accumulation_steps : 4 # used to simulate larger batch sizes
32
+ learning_rate : 1e-4 # max learning rate
33
+ max_iters : 20 # total number of training iterations
34
+ weight_decay : 1e-1
35
+ beta1 : 0.9
36
+ beta2 : 0.95
37
+ grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
38
+ # learning rate decay settings
39
+ decay_lr : true # whether to decay the learning rate
40
+ warmup_iters : 10 # how many steps to warm up for
41
+ lr_decay_iters : 100 # should be ~= max_iters per Chinchilla
42
+ min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
43
+
44
+ dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
45
+ compile : false # Use torch.compile, but in my test this is really slow
46
+ label : "llama2-Debug"
47
+ profile : false
config/train/llama2-DebugGPU.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ io:
2
+ # I/O
3
+ out_dir : "debug-gpu"
4
+ eval_interval : 10
5
+ log_interval : 10
6
+ eval_iters : 5
7
+ eval_only : false # if True, script exits right after the first eval
8
+ always_save_checkpoint : true # if True, always save a checkpoint after each eval
9
+ init_from : "scratch" # 'scratch' or 'resume'
10
+ resume_when_snapshot_available: false
11
+
12
+ loader:
13
+ batch_size : 256 # if gradient_accumulation_steps > 1, this is the micro-batch size
14
+ max_seq_len : 256
15
+ dataset : "smiles"
16
+ processed_dataset_ckpt : "processed_dataset_500000.pkl"
17
+
18
+ model:
19
+ dim : 256
20
+ n_layers : 8
21
+ n_heads : 8
22
+ multiple_of : 128
23
+ dropout : 0.1
24
+
25
+ context:
26
+ context_keys: ["logp", "sascore", "mol_weight"]
27
+ context_dims : [1,1,1]
28
+
29
+ optimizer:
30
+ gradient_accumulation_steps : 4 # used to simulate larger batch sizes
31
+ learning_rate : 1e-4 # max learning rate
32
+ max_iters : 25 # total number of training iterations
33
+ weight_decay : 1e-1
34
+ beta1 : 0.9
35
+ beta2 : 0.95
36
+ grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
37
+ # learning rate decay settings
38
+ decay_lr : true # whether to decay the learning rate
39
+ warmup_iters : 10 # how many steps to warm up for
40
+ lr_decay_iters : 100 # should be ~= max_iters per Chinchilla
41
+ min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
42
+
43
+ dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
44
+ compile : false # Use torch.compile, but in my test this is really slow
45
+ label : "llama2-Debug"
46
+ profile: true # Profile the run
config/train/llama2-M-Full-BRICKS.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ io:
2
+ # I/O
3
+ out_dir : "out"
4
+ eval_interval : 500
5
+ log_interval : 10
6
+ eval_iters : 10
7
+ eval_only : false # if True, script exits right after the first eval
8
+ always_save_checkpoint : false # if True, always save a checkpoint after each eval
9
+ init_from : "scratch" # 'scratch' or 'resume'
10
+ resume_when_snapshot_available: true
11
+
12
+ loader:
13
+ batch_size : 384 # if gradient_accumulation_steps > 1, this is the micro-batch size
14
+ max_seq_len : 768
15
+ dataset : "smiles"
16
+ processed_dataset_ckpt : "processed_dataset_None.pkl"
17
+ fragment_creator : "bricks"
18
+
19
+ model:
20
+ dim : 256
21
+ n_layers : 8
22
+ n_heads : 8
23
+ multiple_of : 128
24
+ dropout : 0.1
25
+
26
+ context:
27
+ context_keys: ["logp", "sascore", "mol_weight"]
28
+ context_dims : [1,1,1]
29
+
30
+ optimizer:
31
+ gradient_accumulation_steps : 4 # used to simulate larger batch sizes
32
+ learning_rate : 1e-4 # max learning rate
33
+ max_iters : 100000 # total number of training iterations
34
+ weight_decay : 1e-1
35
+ beta1 : 0.9
36
+ beta2 : 0.95
37
+ grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
38
+ # learning rate decay settings
39
+ decay_lr : true # whether to decay the learning rate
40
+ warmup_iters : 1000 # how many steps to warm up for
41
+ lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
42
+ min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
43
+
44
+ dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
45
+ compile : false # Use torch.compile, but in my test this is really slow
46
+ label : "llama2-M-Full-BRICKS"
config/train/llama2-M-Full-RSS.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ io:
2
+ # I/O
3
+ out_dir : "out"
4
+ eval_interval : 500
5
+ log_interval : 10
6
+ eval_iters : 10
7
+ eval_only : false # if True, script exits right after the first eval
8
+ always_save_checkpoint : false # if True, always save a checkpoint after each eval
9
+ init_from : "scratch" # 'scratch' or 'resume'
10
+ resume_when_snapshot_available: true # resume the training always, when the `snapshot_` is available in the out/ folder
11
+
12
+ loader:
13
+ batch_size : 256 # if gradient_accumulation_steps > 1, this is the micro-batch size
14
+ max_seq_len : 256 # the maximum sequence length we want to use in the training data.
15
+ dataset : "smiles"
16
+ processed_dataset_ckpt : "processed_dataset_None.pkl"
17
+ fragment_creator : "rss" # the method we want to use to train with the token_sequence
18
+
19
+ model:
20
+ dim : 384
21
+ n_layers : 8
22
+ n_heads : 8
23
+ multiple_of : 128
24
+ dropout : 0.1
25
+
26
+ context:
27
+ context_keys: ["logp", "sascore", "mol_weight"]
28
+ context_dims : [1,1,1]
29
+
30
+ optimizer:
31
+ gradient_accumulation_steps : 4 # used to simulate larger batch sizes
32
+ learning_rate : 1e-4 # max learning rate
33
+ max_iters : 100000 # total number of training iterations
34
+ weight_decay : 1e-1
35
+ beta1 : 0.9
36
+ beta2 : 0.95
37
+ grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
38
+ # learning rate decay settings
39
+ decay_lr : true # whether to decay the learning rate
40
+ warmup_iters : 1000 # how many steps to warm up for
41
+ lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
42
+ min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
43
+
44
+ dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
45
+ compile : false # Use torch.compile, but in my test this is really slow
46
+ label : "llama2-M-Full-RSS" # the name of the output file / model
config/train/llama2-M-Full.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ io:
2
+ # I/O
3
+ out_dir : "out"
4
+ eval_interval : 500
5
+ log_interval : 10
6
+ eval_iters : 10
7
+ eval_only : false # if True, script exits right after the first eval
8
+ always_save_checkpoint : false # if True, always save a checkpoint after each eval
9
+ init_from : "scratch" # 'scratch' or 'resume'
10
+ resume_when_snapshot_available: true
11
+
12
+ loader:
13
+ batch_size : 384 # if gradient_accumulation_steps > 1, this is the micro-batch size
14
+ max_seq_len : 768
15
+ dataset : "smiles"
16
+ processed_dataset_ckpt : "processed_dataset_None.pkl"
17
+ fragment_creator : null
18
+
19
+ model:
20
+ dim : 256
21
+ n_layers : 8
22
+ n_heads : 8
23
+ multiple_of : 128
24
+ dropout : 0.1
25
+
26
+ context:
27
+ context_keys: ["logp", "sascore", "mol_weight"]
28
+ context_dims : [1,1,1]
29
+
30
+ optimizer:
31
+ gradient_accumulation_steps : 4 # used to simulate larger batch sizes
32
+ learning_rate : 1e-4 # max learning rate
33
+ max_iters : 100000 # total number of training iterations
34
+ weight_decay : 1e-1
35
+ beta1 : 0.9
36
+ beta2 : 0.95
37
+ grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
38
+ # learning rate decay settings
39
+ decay_lr : true # whether to decay the learning rate
40
+ warmup_iters : 1000 # how many steps to warm up for
41
+ lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
42
+ min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
43
+
44
+ dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
45
+ compile : false # Use torch.compile, but in my test this is really slow
46
+ label : "llama2-M-Full"
data/Full_PC9_GAP.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e1c1932284e5987ff997675b3f8ad2a8763c4dc864315e78a774841fb6b6791
3
+ size 38893336
data/RedDB_Full.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:543e98ba1b622a2a949a3818d047daa478658d2d91923a291907a2d9c8c886bd
3
+ size 1024066
data/chembl_log_sascore.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30d04f6f1f01caec6164d85b23ba1282dfe63ec1b245e4c358aa216831c32ee8
3
+ size 99582099
data/combine_all.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import os
4
+ from rdkit import Chem
5
+ from rdkit.Chem import Descriptors
6
+ import multiprocessing
7
+
8
+ from rdkit import Chem
9
+ from rdkit.Chem import RDConfig
10
+ import os
11
+ import sys
12
+ sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
13
+ # now you can import sascore!
14
+ import sascorer
15
+
16
+ np.random.seed(42)
17
+
18
def calcLogPIfMol(smi):
    """Return ``Descriptors.MolLogP`` for a SMILES string, or None if it does not parse."""
    mol = Chem.MolFromSmiles(smi)
    # A None molecule means the SMILES could not be parsed; callers filter on this None.
    return None if mol is None else Descriptors.MolLogP(mol)
24
+
25
def calcMol(smi):
    """Parse a SMILES string and return the result of ``Chem.MolFromSmiles`` unchanged."""
    mol = Chem.MolFromSmiles(smi)
    return mol
27
+
28
def calcMolWeight(smi):
    """Return ``Descriptors.ExactMolWt`` for a SMILES string.

    Returns None when the SMILES cannot be parsed (same convention as
    ``calcLogPIfMol``) instead of failing inside RDKit on a None molecule.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None
    return Descriptors.ExactMolWt(mol)
31
+
32
def calcSascore(smi):
    """Return ``sascorer.calculateScore`` for a SMILES string.

    Returns None when the SMILES cannot be parsed (same convention as
    ``calcLogPIfMol``) instead of failing inside the scorer on a None molecule.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None
    return sascorer.calculateScore(mol)
36
+
37
def calculateValues(smi: pd.Series):
    """Compute logP, molecular weight and SA score for every parsable SMILES.

    The work is spread over a pool of 8 worker processes. SMILES for which
    ``calcLogPIfMol`` returns None are dropped before the remaining properties
    are computed.

    Returns:
        tuple: ``(smi, logps, mol_weights, sascores)`` where ``smi`` and
        ``logps`` are Series re-indexed from 0 and ``mol_weights`` /
        ``sascores`` are lists in the same order.
    """
    with multiprocessing.Pool(8) as workers:
        print("Starting logps")
        raw_logps = workers.map(calcLogPIfMol, smi)
        print("Done logps")

        # Boolean mask of the entries whose logP could be computed.
        keep = ~pd.isna(raw_logps)
        logps = pd.Series(raw_logps)[keep].reset_index(drop=True)
        smi = pd.Series(smi)[keep].reset_index(drop=True)

        print("Starting mol weights")
        mol_weights = workers.map(calcMolWeight, smi)
        print("Done mol weights")

        print("Starting sascores")
        sascores = workers.map(calcSascore, smi)
        print("Done sascores")

    return smi, logps, mol_weights, sascores
57
+
58
def calculateProperties(df):
    """Build a frame with ``smiles``, ``logp``, ``mol_weight`` and ``sascore`` columns.

    Only the ``smiles`` column of ``df`` is read; the properties come from
    ``calculateValues``.
    """
    valid_smiles, logp_values, weight_values, sa_values = calculateValues(df["smiles"])
    columns = {
        "smiles": valid_smiles,
        "logp": logp_values,
        "mol_weight": weight_values,
        "sascore": sa_values,
    }
    return pd.DataFrame(columns)
64
+
65
if __name__ == "__main__":
    # Every input file and the combined output are resolved against the
    # directory that contains this script.
    cwd = os.path.dirname(__file__)

    print("df_pc9")
    df_pc9 = pd.read_parquet(os.path.join(cwd, "Full_PC9_GAP.parquet"))
    df_pc9 = calculateProperties(df_pc9)

    print("df_zinc_full")
    df_zinc_full = pd.read_parquet(
        os.path.join(cwd, "zinc", "zinc_processed.parquet")
    )
    # Cap the sample size at the number of available rows: DataFrame.sample
    # raises ValueError when n exceeds the population (sampling without
    # replacement), e.g. with a partially downloaded ZINC set.
    df_zinc_full = df_zinc_full.sample(n=min(5_000_000, len(df_zinc_full)))
    df_zinc_full = calculateProperties(df_zinc_full)

    print("df_zinc_qm9")
    df_zinc_qm9 = pd.read_parquet(
        os.path.join(cwd, "qm9_zinc250k_cep", "qm9_zinc250_cep.parquet")
    )
    df_zinc_qm9 = calculateProperties(df_zinc_qm9)

    print("df_opv")
    df_opv = pd.read_parquet(os.path.join(cwd, "opv", "opv.parquet"))
    df_opv = calculateProperties(df_opv)

    print("df_reddb")
    # Source: https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/F3QFSQ
    df_reddb = pd.read_parquet(os.path.join(cwd, "RedDB_Full.parquet"))
    df_reddb = calculateProperties(df_reddb)

    print("df_chembl")
    df_chembl = pd.read_parquet(os.path.join(cwd, "chembl_log_sascore.parquet"))
    df_chembl = calculateProperties(df_chembl)

    print("df_pubchemqc_2017")
    df_pubchemqc_2017 = pd.read_parquet(
        os.path.join(cwd, "pubchemqc_energy.parquet")
    )
    df_pubchemqc_2017 = calculateProperties(df_pubchemqc_2017)

    print("df_pubchemqc_2020")
    df_pubchemqc_2020 = pd.read_parquet(
        os.path.join(cwd, "pubchemqc2020_energy.parquet")
    )
    df_pubchemqc_2020 = calculateProperties(df_pubchemqc_2020)

    # Order in which the per-dataset frames are stacked in the combined output.
    df_list = [
        df_zinc_qm9,
        df_opv,
        df_pubchemqc_2017,
        df_pubchemqc_2020,
        df_zinc_full,
        df_reddb,
        df_pc9,
        df_chembl,
    ]

    print(f"ZINC QM9 {len(df_zinc_qm9)}")
    print(f"df_opv {len(df_opv)}")
    print(f"df_pubchemqc_2017 {len(df_pubchemqc_2017)}")
    print(f"df_pubchemqc_2020 {len(df_pubchemqc_2020)}")
    print(f"df_zinc_full {len(df_zinc_full)}")
    print(f"df_reddb {len(df_reddb)}")
    print(f"df_pc9 {len(df_pc9)}")
    print(f"df_chembl {len(df_chembl)}")

    # Columns kept in the combined dataset, in output order.
    all_columns = [
        "smiles",
        "logp",
        "sascore",
        "mol_weight",
    ]
    print("concatenating")
    df = pd.concat(df_list, axis=0, ignore_index=True)
    df = df[all_columns]
    df.reset_index(drop=True, inplace=True)
    # The molecular weight is stored divided by 100.
    df["mol_weight"] = df["mol_weight"] / 100.0

    print(df.head())
    print("saving")
    print("Combined len:", len(df))
    df.to_parquet(os.path.join(cwd, "OrganiX13.parquet"))
data/opv/prepare_opv.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import struct
4
+ import logging
5
+ from tqdm import tqdm
6
+ import csv
7
+ from collections import defaultdict
8
+ import pandas as pd
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Taken from here https://torchdrug.ai/docs/_modules/torchdrug/utils/file.html#download
13
def download(url, path, save_file=None, md5=None):
    """
    Download a file from the specified url.
    Skip the downloading step if there exists a file satisfying the given MD5.

    Note that when ``md5`` is None an existing file never matches, so the file
    is downloaded again on every call.

    Parameters:
        url (str): URL to download
        path (str): path to store the downloaded file
        save_file (str, optional): name of save file. If not specified, infer the file name from the URL.
        md5 (str, optional): MD5 of the file

    Returns:
        str: path of the (possibly already existing) local file
    """
    # The standard-library urlretrieve is used directly; no `six` shim is needed.
    from urllib.request import urlretrieve

    if save_file is None:
        save_file = os.path.basename(url)
        # Drop a query string from the inferred file name.
        if "?" in save_file:
            save_file = save_file[:save_file.find("?")]
    save_file = os.path.join(path, save_file)

    if not os.path.exists(save_file) or compute_md5(save_file) != md5:
        logger.info("Downloading %s to %s", url, save_file)
        urlretrieve(url, save_file)
        # Best-effort integrity report: warn, but keep returning the path as before.
        if md5 is not None and compute_md5(save_file) != md5:
            logger.warning("MD5 of downloaded file %s does not match %s", save_file, md5)
    return save_file
36
+
37
+
38
+
39
def smart_open(file_name, mode="rb"):
    """
    Open a plain file, or a ``.bz2`` / ``.gz`` compressed one, chosen by extension.

    This function can be used as drop-in replacement of the builtin function `open()`.

    Parameters:
        file_name (str): file name
        mode (str, optional): open mode for the file stream
    """
    import bz2
    import gzip

    # Map the final extension to the matching opener; anything else is a regular file.
    openers = {".bz2": bz2.BZ2File, ".gz": gzip.GzipFile}
    suffix = os.path.splitext(file_name)[1]
    opener = openers.get(suffix, open)
    return opener(file_name, mode)
59
+
60
+
61
def extract(zip_file, member=None):
    """
    Extract files from a zip file. Currently, ``zip``, ``gz``, ``tar.gz``, ``tar`` file types are supported.

    Files are written next to ``zip_file``. A member is skipped when a file of
    the expected size already exists at its destination.

    Parameters:
        zip_file (str): file name
        member (str, optional): extract specific member from the zip file.
            If not specified, extract all members.

    Returns:
        str: the extracted file when exactly one file was handled, otherwise
        the directory the archive was extracted into.

    Raises:
        ValueError: if the file extension is not supported.
    """
    import gzip
    import shutil
    import zipfile
    import tarfile

    zip_name, extension = os.path.splitext(zip_file)
    # Treat "name.tar.<ext>" as a single compound extension.
    if zip_name.endswith(".tar"):
        extension = ".tar" + extension
        zip_name = zip_name[:-4]
    save_path = os.path.dirname(zip_file)

    # NOTE(review): member names from tar/zip archives are joined to save_path
    # without sanitising; only use this on trusted archives.
    if extension == ".gz":
        # A plain .gz holds a single stream, named after the archive itself.
        member = os.path.basename(zip_name)
        save_file = os.path.join(save_path, member)
        save_files = [save_file]
        # The last four bytes of the file are read as the little-endian
        # expected size of the decompressed output.
        with open(zip_file, "rb") as fin:
            fin.seek(-4, 2)
            file_size = struct.unpack("<I", fin.read())[0]
        if not os.path.exists(save_file) or file_size != os.path.getsize(save_file):
            logger.info("Extracting %s to %s" % (zip_file, save_file))
            with gzip.open(zip_file, "rb") as fin, open(save_file, "wb") as fout:
                shutil.copyfileobj(fin, fout)
    elif extension in [".tar.gz", ".tgz", ".tar"]:
        # The context manager closes the archive handle, which was previously leaked.
        with tarfile.open(zip_file, "r") as tar:
            if member is not None:
                members = [member]
                save_files = [os.path.join(save_path, os.path.basename(member))]
                logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
            else:
                members = tar.getnames()
                save_files = [os.path.join(save_path, _member) for _member in members]
                logger.info("Extracting %s to %s" % (zip_file, save_path))
            for _member, save_file in zip(members, save_files):
                info = tar.getmember(_member)
                if info.isdir():
                    os.makedirs(save_file, exist_ok=True)
                    continue
                os.makedirs(os.path.dirname(save_file), exist_ok=True)
                if not os.path.exists(save_file) or info.size != os.path.getsize(save_file):
                    with tar.extractfile(_member) as fin, open(save_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
    elif extension == ".zip":
        # The context manager closes the archive handle, which was previously leaked.
        with zipfile.ZipFile(zip_file) as zipped:
            if member is not None:
                members = [member]
                save_files = [os.path.join(save_path, os.path.basename(member))]
                logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
            else:
                members = zipped.namelist()
                save_files = [os.path.join(save_path, _member) for _member in members]
                logger.info("Extracting %s to %s" % (zip_file, save_path))
            for _member, save_file in zip(members, save_files):
                info = zipped.getinfo(_member)
                if info.is_dir():
                    os.makedirs(save_file, exist_ok=True)
                    continue
                os.makedirs(os.path.dirname(save_file), exist_ok=True)
                if not os.path.exists(save_file) or info.file_size != os.path.getsize(save_file):
                    with zipped.open(_member, "r") as fin, open(save_file, "wb") as fout:
                        shutil.copyfileobj(fin, fout)
    else:
        raise ValueError("Unknown file extension `%s`" % extension)

    if len(save_files) == 1:
        return save_files[0]
    else:
        return save_path
137
+
138
+
139
+
140
def compute_md5(file_name, chunk_size=65536):
    """
    Return the hexadecimal MD5 digest of a file, read in chunks.

    Parameters:
        file_name (str): file name
        chunk_size (int, optional): chunk size for reading large files
    """
    import hashlib

    digest = hashlib.md5()
    with open(file_name, "rb") as stream:
        # iter() with a sentinel stops at the first empty read.
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
157
+
158
+
159
+
160
def get_line_count(file_name, chunk_size=8192*1024):
    """
    Count the newline bytes (``\\n``) in a file, reading it in chunks.

    Parameters:
        file_name (str): file name
        chunk_size (int, optional): chunk size for reading large files
    """
    with open(file_name, "rb") as stream:
        pieces = iter(lambda: stream.read(chunk_size), b"")
        total = sum(piece.count(b"\n") for piece in pieces)
    return total
175
+
176
+
177
class OPV:
    """
    Quantum mechanical calculations on organic photovoltaic candidate molecules.

    In this script only the SMILES strings are kept: constructing the class
    downloads and extracts the three splits and writes all SMILES to
    ``opv.parquet`` next to this file.

    Statistics:
        - #Molecule: 94,576
        - #Regression task: 8

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    train_url = (
        "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/"
        "b69cf9a5-e7e0-405b-88cb-40df8007242e"
    )
    valid_url = (
        "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/"
        "1c8e7379-3071-4360-ba8e-0c6481c33d2c"
    )
    test_url = (
        "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/"
        "4ef40592-0080-4f00-9bb7-34b25f94962a"
    )
    train_md5 = "16e439b7411ea0a8d3a56ba4802b61b1"
    valid_md5 = "3aa2ac62015932ca84661feb5d29adda"
    test_md5 = "bad072224f0755478f0729476ca99a33"
    target_fields = [
        "gap",
        "homo",
        "lumo",
        "spectral_overlap",
        "gap_extrapolated",
        "homo_extrapolated",
        "lumo_extrapolated",
        "optical_lumo_extrapolated",
    ]

    def read_csv(self, csv_file, smiles_field="smiles", target_fields=None, verbose=0):
        """
        Read the SMILES column of a CSV file.

        Parameters:
            csv_file (str): file name
            smiles_field (str, optional): header of the SMILES column.
                If None, an empty string is recorded for every non-empty row.
            target_fields (list of str, optional): accepted for compatibility;
                target values are not parsed here.
            verbose (int, optional): show a progress bar when truthy

        Returns:
            tuple: ``(smiles, targets)`` where ``targets`` is always an empty
            ``defaultdict(list)``.
        """
        if target_fields is not None:
            target_fields = set(target_fields)

        smiles = []
        # Target parsing is disabled in this script, so this mapping stays empty.
        targets = defaultdict(list)
        with open(csv_file, "r") as fin:
            rows = csv.reader(fin)
            if verbose:
                rows = iter(tqdm(rows, "Loading %s" % csv_file, get_line_count(csv_file)))
            header = next(rows)
            for row in rows:
                # Skip rows in which every cell is empty.
                if not any(row):
                    continue
                if smiles_field is None:
                    smiles.append("")
                for name, cell in zip(header, row):
                    if name == smiles_field:
                        smiles.append(cell)

        return smiles, targets

    def __init__(self, path, verbose=1, **kwargs):
        path = os.path.expanduser(path)
        if not os.path.exists(path):
            os.makedirs(path)
        self.path = path

        # (url, md5, local archive name) for the train / valid / test splits, in that order.
        splits = [
            (self.train_url, self.train_md5, "mol_train.csv.gz"),
            (self.valid_url, self.valid_md5, "mol_valid.csv.gz"),
            (self.test_url, self.test_md5, "mol_test.csv.gz"),
        ]
        # All downloads happen first, then all extractions, then all reads.
        archives = [download(url, path, save_file=name, md5=md5) for url, md5, name in splits]
        csv_files = [extract(archive) for archive in archives]
        parsed = [
            self.read_csv(csv_file, smiles_field="smile", target_fields=self.target_fields)
            for csv_file in csv_files
        ]
        (train_smiles, train_targets), (valid_smiles, valid_targets), (test_smiles, test_targets) = parsed

        self.num_train = len(train_smiles)
        self.num_valid = len(valid_smiles)
        self.num_test = len(test_smiles)

        smiles = train_smiles + valid_smiles + test_smiles
        targets = {k: train_targets[k] + valid_targets[k] + test_targets[k] for k in train_targets}

        print(smiles[:10])
        out_path = os.path.join(os.path.dirname(__file__), "opv.parquet")
        df_out = pd.DataFrame({"smiles": smiles})
        df_out.to_parquet(out_path)
258
+
259
+
260
if __name__ == "__main__":
    # INFO level makes the download / extraction log messages visible.
    logging.basicConfig(level=logging.INFO)
    script_dir = os.path.dirname(__file__)
    # Raw archives are stored in a "download" folder next to this script.
    cwd = os.path.join(script_dir, "download")
    os.makedirs(cwd, exist_ok=True)
    d = OPV(cwd)
265
+
data/pubchemqc2020_energy.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d5ef9f419a48be52b1fe6332eb08d77df0b6ff7ec34f8c99c06e63fa232abf1
3
+ size 39165769
data/pubchemqc_energy.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5ca78b6f81f04ddcc2ed6e031d86f0a2f1e38d6c4001bfd93a28005b7168cf8
3
+ size 89749991
data/qm9_zinc250k_cep/convert_to_parquet.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import requests
3
+ import hashlib
4
+ import os
5
# Source URL, expected MD5 digest and local file name of the ZINC property CSV.
zinc_url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"
zinc_md5 = "b59078b2b04c6e9431280e3dc42048d5"
zinc_filename = "zinc_properties.csv"
output_filename = "qm9_zinc250_cep.parquet"

# Every input and output is resolved relative to this script, so the result
# does not depend on the directory the script is launched from.
cwd = os.path.dirname(os.path.abspath(__file__))


def md5_matches(data, expected_md5):
    """Return True when the MD5 hex digest of *data* equals *expected_md5*.

    Args:
        data: Raw bytes to hash.
        expected_md5: Hexadecimal digest to compare against (case-insensitive).
    """
    return hashlib.md5(data).hexdigest() == expected_md5.lower()


def download_zinc(dest_path, url=zinc_url, expected_md5=zinc_md5, timeout=60):
    """Download the ZINC property CSV and store it at *dest_path*.

    The payload is only written to disk after its checksum has been verified.

    Args:
        dest_path: File path the verified payload is written to.
        url: Location to download from.
        expected_md5: MD5 hex digest the payload must have.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        ValueError: If the payload's MD5 digest differs from *expected_md5*.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors right away; otherwise an error page would only show
    # up as a misleading checksum mismatch below.
    response.raise_for_status()
    downloaded_data = response.content

    if not md5_matches(downloaded_data, expected_md5):
        raise ValueError("MD5 checksum does not match")

    with open(dest_path, "wb") as f:
        f.write(downloaded_data)
    print(f"File '{dest_path}' downloaded and saved.")


def load_smiles(path, **read_csv_kwargs):
    """Read the CSV at *path* and return a dataframe with only its "smiles" column.

    Args:
        path: CSV file to read.
        **read_csv_kwargs: Extra keyword arguments forwarded to ``pd.read_csv``
            (for example ``sep="|"``).
    """
    return pd.read_csv(path, **read_csv_kwargs)[["smiles"]]


def combine_smiles(frames):
    """Stack the given dataframes row-wise into one dataframe."""
    return pd.concat(frames, axis=0)


def main():
    """Download ZINC, merge it with the local QM9 and CEP files, write parquet."""
    zinc_path = os.path.join(cwd, zinc_filename)
    download_zinc(zinc_path)
    zinc_df = load_smiles(zinc_path)

    # QM9IsoFull.csv is pipe-separated; cep-processed.csv uses the default comma.
    qm9_df = load_smiles(os.path.join(cwd, "QM9IsoFull.csv"), sep="|")
    cep_df = load_smiles(os.path.join(cwd, "cep-processed.csv"))

    # Combine the dataframes into one large dataframe
    combined_df = combine_smiles([zinc_df, qm9_df, cep_df])

    # Save the combined dataframe to a Parquet file
    output_path = os.path.join(cwd, output_filename)
    combined_df.to_parquet(output_path, index=False)
    print(f"Combined dataframe saved to '{output_path}' as Parquet file.")


if __name__ == "__main__":
    main()
data/qm9_zinc250k_cep/qm9_zinc250_cep.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3003c48cff3793646f07692b85745786d4d9b103323b3b59b3ae5b23af071d3a
3
+ size 7580076
data/vocab.txt ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [unused1]
3
+ [unused2]
4
+ [unused3]
5
+ [unused4]
6
+ [unused5]
7
+ [unused6]
8
+ [unused7]
9
+ [unused8]
10
+ [unused9]
11
+ [unused10]
12
+ [UNK]
13
+ [CLS]
14
+ [SEP]
15
+ [MASK]
16
+ c
17
+ C
18
+ (
19
+ )
20
+ O
21
+ 1
22
+ 2
23
+ =
24
+ N
25
+ .
26
+ n
27
+ 3
28
+ F
29
+ Cl
30
+ >>
31
+ ~
32
+ -
33
+ 4
34
+ [C@H]
35
+ S
36
+ [C@@H]
37
+ [O-]
38
+ Br
39
+ #
40
+ /
41
+ [nH]
42
+ [N+]
43
+ s
44
+ 5
45
+ o
46
+ P
47
+ [Na+]
48
+ [Si]
49
+ I
50
+ [Na]
51
+ [Pd]
52
+ [K+]
53
+ [K]
54
+ [P]
55
+ B
56
+ [C@]
57
+ [C@@]
58
+ [Cl-]
59
+ 6
60
+ [OH-]
61
+ \
62
+ [N-]
63
+ [Li]
64
+ [H]
65
+ [2H]
66
+ [NH4+]
67
+ [c-]
68
+ [P-]
69
+ [Cs+]
70
+ [Li+]
71
+ [Cs]
72
+ [NaH]
73
+ [H-]
74
+ [O+]
75
+ [BH4-]
76
+ [Cu]
77
+ 7
78
+ [Mg]
79
+ [Fe+2]
80
+ [n+]
81
+ [Sn]
82
+ [BH-]
83
+ [Pd+2]
84
+ [CH]
85
+ [I-]
86
+ [Br-]
87
+ [C-]
88
+ [Zn]
89
+ [B-]
90
+ [F-]
91
+ [Al]
92
+ [P+]
93
+ [BH3-]
94
+ [Fe]
95
+ [C]
96
+ [AlH4]
97
+ [Ni]
98
+ [SiH]
99
+ 8
100
+ [Cu+2]
101
+ [Mn]
102
+ [AlH]
103
+ [nH+]
104
+ [AlH4-]
105
+ [O-2]
106
+ [Cr]
107
+ [Mg+2]
108
+ [NH3+]
109
+ [S@]
110
+ [Pt]
111
+ [Al+3]
112
+ [S@@]
113
+ [S-]
114
+ [Ti]
115
+ [Zn+2]
116
+ [PH]
117
+ [NH2+]
118
+ [Ru]
119
+ [Ag+]
120
+ [S+]
121
+ [I+3]
122
+ [NH+]
123
+ [Ca+2]
124
+ [Ag]
125
+ 9
126
+ [Os]
127
+ [Se]
128
+ [SiH2]
129
+ [Ca]
130
+ [Ti+4]
131
+ [Ac]
132
+ [Cu+]
133
+ [S]
134
+ [Rh]
135
+ [Cl+3]
136
+ [cH-]
137
+ [Zn+]
138
+ [O]
139
+ [Cl+]
140
+ [SH]
141
+ [H+]
142
+ [Pd+]
143
+ [se]
144
+ [PH+]
145
+ [I]
146
+ [Pt+2]
147
+ [C+]
148
+ [Mg+]
149
+ [Hg]
150
+ [W]
151
+ [SnH]
152
+ [SiH3]
153
+ [Fe+3]
154
+ [NH]
155
+ [Mo]
156
+ [CH2+]
157
+ %10
158
+ [CH2-]
159
+ [CH2]
160
+ [n-]
161
+ [Ce+4]
162
+ [NH-]
163
+ [Co]
164
+ [I+]
165
+ [PH2]
166
+ [Pt+4]
167
+ [Ce]
168
+ [B]
169
+ [Sn+2]
170
+ [Ba+2]
171
+ %11
172
+ [Fe-3]
173
+ [18F]
174
+ [SH-]
175
+ [Pb+2]
176
+ [Os-2]
177
+ [Zr+4]
178
+ [N]
179
+ [Ir]
180
+ [Bi]
181
+ [Ni+2]
182
+ [P@]
183
+ [Co+2]
184
+ [s+]
185
+ [As]
186
+ [P+3]
187
+ [Hg+2]
188
+ [Yb+3]
189
+ [CH-]
190
+ [Zr+2]
191
+ [Mn+2]
192
+ [CH+]
193
+ [In]
194
+ [KH]
195
+ [Ce+3]
196
+ [Zr]
197
+ [AlH2-]
198
+ [OH2+]
199
+ [Ti+3]
200
+ [Rh+2]
201
+ [Sb]
202
+ [S-2]
203
+ %12
204
+ [P@@]
205
+ [Si@H]
206
+ [Mn+4]
207
+ p
208
+ [Ba]
209
+ [NH2-]
210
+ [Ge]
211
+ [Pb+4]
212
+ [Cr+3]
213
+ [Au]
214
+ [LiH]
215
+ [Sc+3]
216
+ [o+]
217
+ [Rh-3]
218
+ %13
219
+ [Br]
220
+ [Sb-]
221
+ [S@+]
222
+ [I+2]
223
+ [Ar]
224
+ [V]
225
+ [Cu-]
226
+ [Al-]
227
+ [Te]
228
+ [13c]
229
+ [13C]
230
+ [Cl]
231
+ [PH4+]
232
+ [SiH4]
233
+ [te]
234
+ [CH3-]
235
+ [S@@+]
236
+ [Rh+3]
237
+ [SH+]
238
+ [Bi+3]
239
+ [Br+2]
240
+ [La]
241
+ [La+3]
242
+ [Pt-2]
243
+ [N@@]
244
+ [PH3+]
245
+ [N@]
246
+ [Si+4]
247
+ [Sr+2]
248
+ [Al+]
249
+ [Pb]
250
+ [SeH]
251
+ [Si-]
252
+ [V+5]
253
+ [Y+3]
254
+ [Re]
255
+ [Ru+]
256
+ [Sm]
257
+ *
258
+ [3H]
259
+ [NH2]
260
+ [Ag-]
261
+ [13CH3]
262
+ [OH+]
263
+ [Ru+3]
264
+ [OH]
265
+ [Gd+3]
266
+ [13CH2]
267
+ [In+3]
268
+ [Si@@]
269
+ [Si@]
270
+ [Ti+2]
271
+ [Sn+]
272
+ [Cl+2]
273
+ [AlH-]
274
+ [Pd-2]
275
+ [SnH3]
276
+ [B+3]
277
+ [Cu-2]
278
+ [Nd+3]
279
+ [Pb+3]
280
+ [13cH]
281
+ [Fe-4]
282
+ [Ga]
283
+ [Sn+4]
284
+ [Hg+]
285
+ [11CH3]
286
+ [Hf]
287
+ [Pr]
288
+ [Y]
289
+ [S+2]
290
+ [Cd]
291
+ [Cr+6]
292
+ [Zr+3]
293
+ [Rh+]
294
+ [CH3]
295
+ [N-3]
296
+ [Hf+2]
297
+ [Th]
298
+ [Sb+3]
299
+ %14
300
+ [Cr+2]
301
+ [Ru+2]
302
+ [Hf+4]
303
+ [14C]
304
+ [Ta]
305
+ [Tl+]
306
+ [B+]
307
+ [Os+4]
308
+ [PdH2]
309
+ [Pd-]
310
+ [Cd+2]
311
+ [Co+3]
312
+ [S+4]
313
+ [Nb+5]
314
+ [123I]
315
+ [c+]
316
+ [Rb+]
317
+ [V+2]
318
+ [CH3+]
319
+ [Ag+2]
320
+ [cH+]
321
+ [Mn+3]
322
+ [Se-]
323
+ [As-]
324
+ [Eu+3]
325
+ [SH2]
326
+ [Sm+3]
327
+ [IH+]
328
+ %15
329
+ [OH3+]
330
+ [PH3]
331
+ [IH2+]
332
+ [SH2+]
333
+ [Ir+3]
334
+ [AlH3]
335
+ [Sc]
336
+ [Yb]
337
+ [15NH2]
338
+ [Lu]
339
+ [sH+]
340
+ [Gd]
341
+ [18F-]
342
+ [SH3+]
343
+ [SnH4]
344
+ [TeH]
345
+ [Si@@H]
346
+ [Ga+3]
347
+ [CaH2]
348
+ [Tl]
349
+ [Ta+5]
350
+ [GeH]
351
+ [Br+]
352
+ [Sr]
353
+ [Tl+3]
354
+ [Sm+2]
355
+ [PH5]
356
+ %16
357
+ [N@@+]
358
+ [Au+3]
359
+ [C-4]
360
+ [Nd]
361
+ [Ti+]
362
+ [IH]
363
+ [N@+]
364
+ [125I]
365
+ [Eu]
366
+ [Sn+3]
367
+ [Nb]
368
+ [Er+3]
369
+ [123I-]
370
+ [14c]
371
+ %17
372
+ [SnH2]
373
+ [YH]
374
+ [Sb+5]
375
+ [Pr+3]
376
+ [Ir+]
377
+ [N+3]
378
+ [AlH2]
379
+ [19F]
380
+ %18
381
+ [Tb]
382
+ [14CH]
383
+ [Mo+4]
384
+ [Si+]
385
+ [BH]
386
+ [Be]
387
+ [Rb]
388
+ [pH]
389
+ %19
390
+ %20
391
+ [Xe]
392
+ [Ir-]
393
+ [Be+2]
394
+ [C+4]
395
+ [RuH2]
396
+ [15NH]
397
+ [U+2]
398
+ [Au-]
399
+ %21
400
+ %22
401
+ [Au+]
402
+ [15n]
403
+ [Al+2]
404
+ [Tb+3]
405
+ [15N]
406
+ [V+3]
407
+ [W+6]
408
+ [14CH3]
409
+ [Cr+4]
410
+ [ClH+]
411
+ b
412
+ [Ti+6]
413
+ [Nd+]
414
+ [Zr+]
415
+ [PH2+]
416
+ [Fm]
417
+ [N@H+]
418
+ [RuH]
419
+ [Dy+3]
420
+ %23
421
+ [Hf+3]
422
+ [W+4]
423
+ [11C]
424
+ [13CH]
425
+ [Er]
426
+ [124I]
427
+ [LaH]
428
+ [F]
429
+ [siH]
430
+ [Ga+]
431
+ [Cm]
432
+ [GeH3]
433
+ [IH-]
434
+ [U+6]
435
+ [SeH+]
436
+ [32P]
437
+ [SeH-]
438
+ [Pt-]
439
+ [Ir+2]
440
+ [se+]
441
+ [U]
442
+ [F+]
443
+ [BH2]
444
+ [As+]
445
+ [Cf]
446
+ [ClH2+]
447
+ [Ni+]
448
+ [TeH3]
449
+ [SbH2]
450
+ [Ag+3]
451
+ %24
452
+ [18O]
453
+ [PH4]
454
+ [Os+2]
455
+ [Na-]
456
+ [Sb+2]
457
+ [V+4]
458
+ [Ho+3]
459
+ [68Ga]
460
+ [PH-]
461
+ [Bi+2]
462
+ [Ce+2]
463
+ [Pd+3]
464
+ [99Tc]
465
+ [13C@@H]
466
+ [Fe+6]
467
+ [c]
468
+ [GeH2]
469
+ [10B]
470
+ [Cu+3]
471
+ [Mo+2]
472
+ [Cr+]
473
+ [Pd+4]
474
+ [Dy]
475
+ [AsH]
476
+ [Ba+]
477
+ [SeH2]
478
+ [In+]
479
+ [TeH2]
480
+ [BrH+]
481
+ [14cH]
482
+ [W+]
483
+ [13C@H]
484
+ [AsH2]
485
+ [In+2]
486
+ [N+2]
487
+ [N@@H+]
488
+ [SbH]
489
+ [60Co]
490
+ [AsH4+]
491
+ [AsH3]
492
+ [18OH]
493
+ [Ru-2]
494
+ [Na-2]
495
+ [CuH2]
496
+ [31P]
497
+ [Ti+5]
498
+ [35S]
499
+ [P@@H]
500
+ [ArH]
501
+ [Co+]
502
+ [Zr-2]
503
+ [BH2-]
504
+ [131I]
505
+ [SH5]
506
+ [VH]
507
+ [B+2]
508
+ [Yb+2]
509
+ [14C@H]
510
+ [211At]
511
+ [NH3+2]
512
+ [IrH]
513
+ [IrH2]
514
+ [Rh-]
515
+ [Cr-]
516
+ [Sb+]
517
+ [Ni+3]
518
+ [TaH3]
519
+ [Tl+2]
520
+ [64Cu]
521
+ [Tc]
522
+ [Cd+]
523
+ [1H]
524
+ [15nH]
525
+ [AlH2+]
526
+ [FH+2]
527
+ [BiH3]
528
+ [Ru-]
529
+ [Mo+6]
530
+ [AsH+]
531
+ [BaH2]
532
+ [BaH]
533
+ [Fe+4]
534
+ [229Th]
535
+ [Th+4]
536
+ [As+3]
537
+ [NH+3]
538
+ [P@H]
539
+ [Li-]
540
+ [7NaH]
541
+ [Bi+]
542
+ [PtH+2]
543
+ [p-]
544
+ [Re+5]
545
+ [NiH]
546
+ [Ni-]
547
+ [Xe+]
548
+ [Ca+]
549
+ [11c]
550
+ [Rh+4]
551
+ [AcH]
552
+ [HeH]
553
+ [Sc+2]
554
+ [Mn+]
555
+ [UH]
556
+ [14CH2]
557
+ [SiH4+]
558
+ [18OH2]
559
+ [Ac-]
560
+ [Re+4]
561
+ [118Sn]
562
+ [153Sm]
563
+ [P+2]
564
+ [9CH]
565
+ [9CH3]
566
+ [Y-]
567
+ [NiH2]
568
+ [Si+2]
569
+ [Mn+6]
570
+ [ZrH2]
571
+ [C-2]
572
+ [Bi+5]
573
+ [24NaH]
574
+ [Fr]
575
+ [15CH]
576
+ [Se+]
577
+ [At]
578
+ [P-3]
579
+ [124I-]
580
+ [CuH2-]
581
+ [Nb+4]
582
+ [Nb+3]
583
+ [MgH]
584
+ [Ir+4]
585
+ [67Ga+3]
586
+ [67Ga]
587
+ [13N]
588
+ [15OH2]
589
+ [2NH]
590
+ [Ho]
591
+ [Cn]
592
+ [0*]
593
+ [1*]
594
+ [2*]
595
+ [3*]
596
+ [4*]
597
+ [5*]
598
+ [6*]
599
+ [7*]
600
+ [8*]
601
+ [9*]
602
+ [10*]
603
+ [11*]
604
+ [12*]
605
+ [13*]
606
+ [14*]
607
+ [15*]
608
+ [16*]
609
+ [17*]
610
+ [18*]
611
+ [19*]
612
+ [20*]
data/zinc/convert_to_parquet.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import os.path as osp
3
+ import os
4
+ from tqdm import tqdm
5
+ import dask.dataframe as dd
6
+ import pandas as pd
7
+ import pyarrow as pa
8
+ import shutil
9
# Directory of this script and the folder holding the downloaded ZINC files.
cwd = osp.abspath(osp.dirname(__file__))
zinc_path = os.path.join(cwd, "zinc_complete")


def list_data_dirs(root):
    """Return the sub-directories of *root* that contain at least one .txt file.

    The result is sorted so that the order in which the files are read (and
    therefore the row order of the output) is reproducible; ``os.listdir``
    itself gives no ordering guarantee. Directories without any ``.txt`` file
    are skipped because there is nothing to read from them.

    Args:
        root: Directory whose immediate sub-directories are inspected.

    Returns:
        A list of full directory paths.
    """
    dirs = []
    for name in sorted(os.listdir(root)):
        path = osp.join(root, name)
        if osp.isdir(path) and any(f.endswith(".txt") for f in os.listdir(path)):
            dirs.append(path)
    return dirs


def main():
    """Merge the "smiles" column of all downloaded .txt files into one parquet file.

    Raises:
        FileNotFoundError: If no sub-directory with .txt files exists.
    """
    alls_dirs = list_data_dirs(zinc_path)
    print("Number of dirs: ", len(alls_dirs))
    if not alls_dirs:
        raise FileNotFoundError(
            f"No directories with .txt files found in {zinc_path}"
        )

    all_dfs = []
    for d in alls_dirs:
        print(f"Read: {d}")
        # `d` is already a full path, so the glob is built from it directly.
        df = dd.read_csv(
            osp.join(d, "*.txt"),
            sep="\t",
            usecols=["smiles"],
        )
        all_dfs.append(df)

    concatenated_df = dd.concat(all_dfs)

    print("Writing")
    # A single partition yields a single parquet part file, which is copied
    # to its final name below.
    concatenated_df = concatenated_df.repartition(npartitions=1)
    concatenated_df = concatenated_df.reset_index(drop=True)
    out_dir = os.path.join(cwd, "zinc_processed")
    concatenated_df.to_parquet(out_dir)
    print("Done Writing")
    # NOTE(review): len() on a dask dataframe evaluates the graph again, which
    # presumably re-reads all CSV files; consider counting rows from the
    # written parquet file instead.
    print(len(concatenated_df))
    shutil.copy(
        os.path.join(out_dir, "part.0.parquet"),
        os.path.join(cwd, "zinc_processed.parquet"),
    )


if __name__ == "__main__":
    main()
47
+
48
+ # df = None
49
+ # for d in tqdm(alls_dirs):
50
+ # if df is not None:
51
+ # print(len(df))
52
+ # files = [osp.join(d,f) for f in os.listdir(d)]
53
+ # for f in files:
54
+ # try:
55
+ # df_extra = pd.read_csv(f,sep="\t")
56
+ # except Exception as e:
57
+ # print(f"Got error {f}: {e}")
58
+ # continue
59
+ # # print(df)
60
+ # if df is None:
61
+ # df = df_extra
62
+
63
+ # else:
64
+ # df = df.append(df_extra)
65
+
66
+
67
+ # df.to_parquet(osp.join(cwd, "zinc_combined.parquet"))
data/zinc/zinc_complete/download_zinc.sh ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAA.txt -O AA/AAAA.txt
2
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAB.txt -O AA/AAAB.txt
3
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAC.txt -O AA/AAAC.txt
4
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAD.txt -O AA/AAAD.txt
5
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AABA.txt -O AA/AABA.txt
6
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AABB.txt -O AA/AABB.txt
7
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AABC.txt -O AA/AABC.txt
8
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AABD.txt -O AA/AABD.txt
9
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AACA.txt -O AA/AACA.txt
10
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AACB.txt -O AA/AACB.txt
11
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AACC.txt -O AA/AACC.txt
12
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AACD.txt -O AA/AACD.txt
13
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAEA.txt -O AA/AAEA.txt
14
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAEB.txt -O AA/AAEB.txt
15
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAEC.txt -O AA/AAEC.txt
16
+ mkdir -pv AA && wget http://files.docking.org/2D/AA/AAED.txt -O AA/AAED.txt
17
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAA.txt -O BA/BAAA.txt
18
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAB.txt -O BA/BAAB.txt
19
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAC.txt -O BA/BAAC.txt
20
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAD.txt -O BA/BAAD.txt
21
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BABA.txt -O BA/BABA.txt
22
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BABB.txt -O BA/BABB.txt
23
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BABC.txt -O BA/BABC.txt
24
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BABD.txt -O BA/BABD.txt
25
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BACA.txt -O BA/BACA.txt
26
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BACB.txt -O BA/BACB.txt
27
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BACC.txt -O BA/BACC.txt
28
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BACD.txt -O BA/BACD.txt
29
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAEA.txt -O BA/BAEA.txt
30
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAEB.txt -O BA/BAEB.txt
31
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAEC.txt -O BA/BAEC.txt
32
+ mkdir -pv BA && wget http://files.docking.org/2D/BA/BAED.txt -O BA/BAED.txt
33
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAA.txt -O CA/CAAA.txt
34
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAB.txt -O CA/CAAB.txt
35
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAC.txt -O CA/CAAC.txt
36
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAD.txt -O CA/CAAD.txt
37
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CABA.txt -O CA/CABA.txt
38
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CABB.txt -O CA/CABB.txt
39
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CABC.txt -O CA/CABC.txt
40
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CABD.txt -O CA/CABD.txt
41
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CACA.txt -O CA/CACA.txt
42
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CACB.txt -O CA/CACB.txt
43
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CACC.txt -O CA/CACC.txt
44
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CACD.txt -O CA/CACD.txt
45
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAEA.txt -O CA/CAEA.txt
46
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAEB.txt -O CA/CAEB.txt
47
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAEC.txt -O CA/CAEC.txt
48
+ mkdir -pv CA && wget http://files.docking.org/2D/CA/CAED.txt -O CA/CAED.txt
49
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAA.txt -O DA/DAAA.txt
50
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAB.txt -O DA/DAAB.txt
51
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAC.txt -O DA/DAAC.txt
52
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAD.txt -O DA/DAAD.txt
53
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DABA.txt -O DA/DABA.txt
54
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DABB.txt -O DA/DABB.txt
55
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DABC.txt -O DA/DABC.txt
56
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DABD.txt -O DA/DABD.txt
57
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DACA.txt -O DA/DACA.txt
58
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DACB.txt -O DA/DACB.txt
59
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DACC.txt -O DA/DACC.txt
60
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DACD.txt -O DA/DACD.txt
61
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAEA.txt -O DA/DAEA.txt
62
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAEB.txt -O DA/DAEB.txt
63
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAEC.txt -O DA/DAEC.txt
64
+ mkdir -pv DA && wget http://files.docking.org/2D/DA/DAED.txt -O DA/DAED.txt
65
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAA.txt -O EA/EAAA.txt
66
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAB.txt -O EA/EAAB.txt
67
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAC.txt -O EA/EAAC.txt
68
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAD.txt -O EA/EAAD.txt
69
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EABA.txt -O EA/EABA.txt
70
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EABB.txt -O EA/EABB.txt
71
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EABC.txt -O EA/EABC.txt
72
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EABD.txt -O EA/EABD.txt
73
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EACA.txt -O EA/EACA.txt
74
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EACB.txt -O EA/EACB.txt
75
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EACC.txt -O EA/EACC.txt
76
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EACD.txt -O EA/EACD.txt
77
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAEA.txt -O EA/EAEA.txt
78
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAEB.txt -O EA/EAEB.txt
79
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAEC.txt -O EA/EAEC.txt
80
+ mkdir -pv EA && wget http://files.docking.org/2D/EA/EAED.txt -O EA/EAED.txt
81
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAA.txt -O FA/FAAA.txt
82
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAB.txt -O FA/FAAB.txt
83
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAC.txt -O FA/FAAC.txt
84
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAD.txt -O FA/FAAD.txt
85
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FABA.txt -O FA/FABA.txt
86
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FABB.txt -O FA/FABB.txt
87
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FABC.txt -O FA/FABC.txt
88
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FABD.txt -O FA/FABD.txt
89
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FACA.txt -O FA/FACA.txt
90
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FACB.txt -O FA/FACB.txt
91
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FACC.txt -O FA/FACC.txt
92
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FACD.txt -O FA/FACD.txt
93
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAEA.txt -O FA/FAEA.txt
94
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAEB.txt -O FA/FAEB.txt
95
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAEC.txt -O FA/FAEC.txt
96
+ mkdir -pv FA && wget http://files.docking.org/2D/FA/FAED.txt -O FA/FAED.txt
97
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAA.txt -O GA/GAAA.txt
98
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAB.txt -O GA/GAAB.txt
99
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAC.txt -O GA/GAAC.txt
100
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAD.txt -O GA/GAAD.txt
101
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAA.txt -O AB/ABAA.txt
102
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAB.txt -O AB/ABAB.txt
103
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAC.txt -O AB/ABAC.txt
104
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAD.txt -O AB/ABAD.txt
105
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBA.txt -O AB/ABBA.txt
106
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBB.txt -O AB/ABBB.txt
107
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBC.txt -O AB/ABBC.txt
108
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBD.txt -O AB/ABBD.txt
109
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCA.txt -O AB/ABCA.txt
110
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCB.txt -O AB/ABCB.txt
111
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCC.txt -O AB/ABCC.txt
112
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCD.txt -O AB/ABCD.txt
113
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABEA.txt -O AB/ABEA.txt
114
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABEB.txt -O AB/ABEB.txt
115
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABEC.txt -O AB/ABEC.txt
116
+ mkdir -pv AB && wget http://files.docking.org/2D/AB/ABED.txt -O AB/ABED.txt
117
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAA.txt -O BB/BBAA.txt
118
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAB.txt -O BB/BBAB.txt
119
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAC.txt -O BB/BBAC.txt
120
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAD.txt -O BB/BBAD.txt
121
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBA.txt -O BB/BBBA.txt
122
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBB.txt -O BB/BBBB.txt
123
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBC.txt -O BB/BBBC.txt
124
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBD.txt -O BB/BBBD.txt
125
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GABA.txt -O GA/GABA.txt
126
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GABB.txt -O GA/GABB.txt
127
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GABC.txt -O GA/GABC.txt
128
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GABD.txt -O GA/GABD.txt
129
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GACA.txt -O GA/GACA.txt
130
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GACB.txt -O GA/GACB.txt
131
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GACC.txt -O GA/GACC.txt
132
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GACD.txt -O GA/GACD.txt
133
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAEA.txt -O GA/GAEA.txt
134
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAEB.txt -O GA/GAEB.txt
135
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAEC.txt -O GA/GAEC.txt
136
+ mkdir -pv GA && wget http://files.docking.org/2D/GA/GAED.txt -O GA/GAED.txt
137
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAA.txt -O HA/HAAA.txt
138
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAB.txt -O HA/HAAB.txt
139
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAC.txt -O HA/HAAC.txt
140
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAD.txt -O HA/HAAD.txt
141
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HABA.txt -O HA/HABA.txt
142
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HABB.txt -O HA/HABB.txt
143
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HABC.txt -O HA/HABC.txt
144
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HABD.txt -O HA/HABD.txt
145
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HACA.txt -O HA/HACA.txt
146
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HACB.txt -O HA/HACB.txt
147
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HACC.txt -O HA/HACC.txt
148
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HACD.txt -O HA/HACD.txt
149
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAEA.txt -O HA/HAEA.txt
150
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAEB.txt -O HA/HAEB.txt
151
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAEC.txt -O HA/HAEC.txt
152
+ mkdir -pv HA && wget http://files.docking.org/2D/HA/HAED.txt -O HA/HAED.txt
153
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAA.txt -O IA/IAAA.txt
154
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAB.txt -O IA/IAAB.txt
155
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAC.txt -O IA/IAAC.txt
156
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAD.txt -O IA/IAAD.txt
157
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IABA.txt -O IA/IABA.txt
158
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IABB.txt -O IA/IABB.txt
159
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IABC.txt -O IA/IABC.txt
160
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IABD.txt -O IA/IABD.txt
161
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IACA.txt -O IA/IACA.txt
162
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IACB.txt -O IA/IACB.txt
163
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IACC.txt -O IA/IACC.txt
164
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IACD.txt -O IA/IACD.txt
165
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAEA.txt -O IA/IAEA.txt
166
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAEB.txt -O IA/IAEB.txt
167
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAEC.txt -O IA/IAEC.txt
168
+ mkdir -pv IA && wget http://files.docking.org/2D/IA/IAED.txt -O IA/IAED.txt
169
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAA.txt -O JA/JAAA.txt
170
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAB.txt -O JA/JAAB.txt
171
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAC.txt -O JA/JAAC.txt
172
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAD.txt -O JA/JAAD.txt
173
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JABA.txt -O JA/JABA.txt
174
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JABB.txt -O JA/JABB.txt
175
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JABC.txt -O JA/JABC.txt
176
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JABD.txt -O JA/JABD.txt
177
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JACA.txt -O JA/JACA.txt
178
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JACB.txt -O JA/JACB.txt
179
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JACC.txt -O JA/JACC.txt
180
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JACD.txt -O JA/JACD.txt
181
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAEA.txt -O JA/JAEA.txt
182
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAEB.txt -O JA/JAEB.txt
183
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAEC.txt -O JA/JAEC.txt
184
+ mkdir -pv JA && wget http://files.docking.org/2D/JA/JAED.txt -O JA/JAED.txt
185
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAA.txt -O KA/KAAA.txt
186
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAB.txt -O KA/KAAB.txt
187
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAC.txt -O KA/KAAC.txt
188
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAD.txt -O KA/KAAD.txt
189
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KABA.txt -O KA/KABA.txt
190
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KABB.txt -O KA/KABB.txt
191
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KABC.txt -O KA/KABC.txt
192
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KABD.txt -O KA/KABD.txt
193
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KACA.txt -O KA/KACA.txt
194
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KACB.txt -O KA/KACB.txt
195
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KACC.txt -O KA/KACC.txt
196
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KACD.txt -O KA/KACD.txt
197
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAEA.txt -O KA/KAEA.txt
198
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAEB.txt -O KA/KAEB.txt
199
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAEC.txt -O KA/KAEC.txt
200
+ mkdir -pv KA && wget http://files.docking.org/2D/KA/KAED.txt -O KA/KAED.txt
201
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCA.txt -O BB/BBCA.txt
202
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCB.txt -O BB/BBCB.txt
203
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCC.txt -O BB/BBCC.txt
204
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCD.txt -O BB/BBCD.txt
205
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBEA.txt -O BB/BBEA.txt
206
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBEB.txt -O BB/BBEB.txt
207
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBEC.txt -O BB/BBEC.txt
208
+ mkdir -pv BB && wget http://files.docking.org/2D/BB/BBED.txt -O BB/BBED.txt
209
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAA.txt -O CB/CBAA.txt
210
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAB.txt -O CB/CBAB.txt
211
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAC.txt -O CB/CBAC.txt
212
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAD.txt -O CB/CBAD.txt
213
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBA.txt -O CB/CBBA.txt
214
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBB.txt -O CB/CBBB.txt
215
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBC.txt -O CB/CBBC.txt
216
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBD.txt -O CB/CBBD.txt
217
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCA.txt -O CB/CBCA.txt
218
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCB.txt -O CB/CBCB.txt
219
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCC.txt -O CB/CBCC.txt
220
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCD.txt -O CB/CBCD.txt
221
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBEA.txt -O CB/CBEA.txt
222
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBEB.txt -O CB/CBEB.txt
223
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBEC.txt -O CB/CBEC.txt
224
+ mkdir -pv CB && wget http://files.docking.org/2D/CB/CBED.txt -O CB/CBED.txt
225
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAA.txt -O DB/DBAA.txt
226
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAB.txt -O DB/DBAB.txt
227
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAC.txt -O DB/DBAC.txt
228
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAD.txt -O DB/DBAD.txt
229
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBA.txt -O DB/DBBA.txt
230
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBB.txt -O DB/DBBB.txt
231
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBC.txt -O DB/DBBC.txt
232
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBD.txt -O DB/DBBD.txt
233
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCA.txt -O DB/DBCA.txt
234
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCB.txt -O DB/DBCB.txt
235
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCC.txt -O DB/DBCC.txt
236
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCD.txt -O DB/DBCD.txt
237
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBEA.txt -O DB/DBEA.txt
238
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBEB.txt -O DB/DBEB.txt
239
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBEC.txt -O DB/DBEC.txt
240
+ mkdir -pv DB && wget http://files.docking.org/2D/DB/DBED.txt -O DB/DBED.txt
241
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAA.txt -O EB/EBAA.txt
242
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAB.txt -O EB/EBAB.txt
243
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAC.txt -O EB/EBAC.txt
244
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAD.txt -O EB/EBAD.txt
245
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBA.txt -O EB/EBBA.txt
246
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBB.txt -O EB/EBBB.txt
247
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBC.txt -O EB/EBBC.txt
248
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBD.txt -O EB/EBBD.txt
249
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCA.txt -O EB/EBCA.txt
250
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCB.txt -O EB/EBCB.txt
251
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCC.txt -O EB/EBCC.txt
252
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCD.txt -O EB/EBCD.txt
253
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBEA.txt -O EB/EBEA.txt
254
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBEB.txt -O EB/EBEB.txt
255
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBEC.txt -O EB/EBEC.txt
256
+ mkdir -pv EB && wget http://files.docking.org/2D/EB/EBED.txt -O EB/EBED.txt
257
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAA.txt -O FB/FBAA.txt
258
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAB.txt -O FB/FBAB.txt
259
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAC.txt -O FB/FBAC.txt
260
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAD.txt -O FB/FBAD.txt
261
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBA.txt -O FB/FBBA.txt
262
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBB.txt -O FB/FBBB.txt
263
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBC.txt -O FB/FBBC.txt
264
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBD.txt -O FB/FBBD.txt
265
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCA.txt -O FB/FBCA.txt
266
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCB.txt -O FB/FBCB.txt
267
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCC.txt -O FB/FBCC.txt
268
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCD.txt -O FB/FBCD.txt
269
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBEA.txt -O FB/FBEA.txt
270
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBEB.txt -O FB/FBEB.txt
271
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBEC.txt -O FB/FBEC.txt
272
+ mkdir -pv FB && wget http://files.docking.org/2D/FB/FBED.txt -O FB/FBED.txt
273
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAA.txt -O GB/GBAA.txt
274
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAB.txt -O GB/GBAB.txt
275
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAC.txt -O GB/GBAC.txt
276
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAD.txt -O GB/GBAD.txt
277
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBA.txt -O GB/GBBA.txt
278
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBB.txt -O GB/GBBB.txt
279
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBC.txt -O GB/GBBC.txt
280
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBD.txt -O GB/GBBD.txt
281
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCA.txt -O GB/GBCA.txt
282
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCB.txt -O GB/GBCB.txt
283
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCC.txt -O GB/GBCC.txt
284
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCD.txt -O GB/GBCD.txt
285
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBEA.txt -O GB/GBEA.txt
286
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBEB.txt -O GB/GBEB.txt
287
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBEC.txt -O GB/GBEC.txt
288
+ mkdir -pv GB && wget http://files.docking.org/2D/GB/GBED.txt -O GB/GBED.txt
289
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAA.txt -O HB/HBAA.txt
290
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAB.txt -O HB/HBAB.txt
291
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAC.txt -O HB/HBAC.txt
292
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAD.txt -O HB/HBAD.txt
293
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBA.txt -O HB/HBBA.txt
294
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBB.txt -O HB/HBBB.txt
295
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBC.txt -O HB/HBBC.txt
296
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBD.txt -O HB/HBBD.txt
297
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCA.txt -O HB/HBCA.txt
298
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCB.txt -O HB/HBCB.txt
299
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCC.txt -O HB/HBCC.txt
300
+ mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCD.txt -O HB/HBCD.txt
data/zinc/zinc_complete/run_download.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import subprocess
3
+
4
# Shell script that lists one "mkdir ... && wget ..." command per line.
shell_file = "download_zinc.sh"
# Upper bound on the number of download commands running at the same time.
num_parallel = 8


def execute_command(command):
    """Run one download command through the shell and return its exit code.

    The commands chain ``mkdir`` and ``wget`` with ``&&``, so they need a
    shell. They come from the local ``download_zinc.sh`` that ships next to
    this script, not from user input.
    """
    print("Running: ", command)
    completed = subprocess.run(command, shell=True)
    if completed.returncode != 0:
        # Report the failure instead of dropping it silently.
        print("Failed with exit code", completed.returncode, ":", command)
    return completed.returncode


# Collect every download command from the shell script, skipping other lines.
commands = []
with open(shell_file, "r") as file:
    for line in file:
        line = line.strip()
        if line.startswith("mkdir") and "wget" in line:
            commands.append(line)

# ``chunksize`` has no effect on a ThreadPoolExecutor; the parallelism limit
# has to be passed as ``max_workers``.
with concurrent.futures.ThreadPoolExecutor(max_workers=num_parallel) as executor:
    # Consume the iterator so that exceptions raised in workers surface here.
    exit_codes = list(executor.map(execute_command, commands))

failed_count = sum(1 for code in exit_codes if code != 0)
if failed_count:
    print(failed_count, "of", len(commands), "download commands failed")

print("Downloads completed")
demonstrator.ipynb ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Demonstrator"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "### Load the model"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 5,
20
+ "metadata": {},
21
+ "outputs": [
22
+ {
23
+ "name": "stderr",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "INFO:sample:Compiling the model...\n"
27
+ ]
28
+ }
29
+ ],
30
+ "source": [
31
+ "import rdkit\n",
32
+ "from rdkit import Chem\n",
33
+ "import rdkit.rdBase as rkrb\n",
34
+ "import rdkit.RDLogger as rkl\n",
35
+ "import os\n",
36
+ "import torch \n",
37
+ "import logging\n",
38
+ "import numpy as np\n",
39
+ "from plot_utils import check_metrics\n",
40
+ "from sample import Sampler\n",
41
+ "import pandas as pd\n",
42
+ "\n",
43
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
44
+ "\n",
45
+ "if \"cuda\" in device:\n",
46
+ " # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'\n",
47
+ " dtype = \"float16\" if torch.cuda.is_available() else \"float32\"\n",
48
+ "else:\n",
49
+ " dtype = \"float32\"\n",
50
+ "\n",
51
+ "logger = rkl.logger()\n",
52
+ "logger.setLevel(rkl.ERROR)\n",
53
+ "rkrb.DisableLog(\"rdApp.error\")\n",
54
+ "\n",
55
+ "torch.set_num_threads(8)\n",
56
+ "logging.basicConfig(level=logging.INFO)\n",
57
+ "logger = logging.getLogger(__name__)\n",
58
+ "\n",
59
+ "sampler = Sampler(\n",
60
+ " load_path=os.path.join(\n",
61
+ " os.getcwd(), \"out\", \"llama2-M-Full-RSS.pt\"\n",
62
+ " ),\n",
63
+ " device=device,\n",
64
+ " seed=1234,\n",
65
+ " dtype=dtype,\n",
66
+ " compile=True,\n",
67
+ ")\n",
68
+ "\n",
69
+ " \n",
70
+ "num_samples = 100\n",
71
+ "df_comp = pd.read_parquet(os.path.join(os.getcwd(),\"data\",\"OrganiX13.parquet\"))\n",
72
+ "df_comp = df_comp.sample(n=2_500_000)\n",
73
+ "comp_context_dict = {c: df_comp[c].to_numpy() for c in [\"logp\", \"sascore\", \"mol_weight\"]} \n",
74
+ "comp_smiles = df_comp[\"smiles\"]\n",
75
+ "\n"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 6,
81
+ "metadata": {},
82
+ "outputs": [
83
+ {
84
+ "name": "stderr",
85
+ "output_type": "stream",
86
+ "text": [
87
+ "INFO:root:Wrote file /home/ndobberstein/Projekte/llama2-molgen/chemiscope_gen.json\n"
88
+ ]
89
+ }
90
+ ],
91
+ "source": [
92
+ "from typing import List, Dict\n",
93
+ "import json\n",
94
+ "from rdkit.Chem import AllChem\n",
95
+ "\n",
96
+ "@torch.no_grad()\n",
97
+ "def convert_to_chemiscope(smiles_list : List[str], context_dict : Dict[str, List[float]]):\n",
98
+ " # For more details on the file format: https://chemiscope.org/docs/tutorial/input-reference.html\n",
99
+ "\n",
100
+ " structures = []\n",
101
+ " remove_list = []\n",
102
+ " for i,smi in enumerate(smiles_list):\n",
103
+ " mol = Chem.MolFromSmiles(smi)\n",
104
+ " if mol is None:\n",
105
+ " logging.info(f\"Mol invalid: {smi} ! Skipping...\")\n",
106
+ " remove_list.append(i)\n",
107
+ " continue\n",
108
+ "\n",
109
+ " res = AllChem.EmbedMolecule(mol,randomSeed=0xf00d, maxAttempts=20)\n",
110
+ " # res = AllChem.Compute2DCoords(mol)\n",
111
+ "\n",
112
+ " if res != 0:\n",
113
+ " logging.info(f\"Could not calculate coordinates for {smi}! Skipping..\")\n",
114
+ " remove_list.append(i)\n",
115
+ " continue\n",
116
+ " \n",
117
+ "\n",
118
+ " conf = list(mol.GetConformers())[0]\n",
119
+ " x,y,z = [],[],[]\n",
120
+ " symbols = []\n",
121
+ " for atom, coords in zip(mol.GetAtoms(), conf.GetPositions()):\n",
122
+ " symbols.append(atom.GetSymbol())\n",
123
+ " x.append(coords[0])\n",
124
+ " y.append(coords[1])\n",
125
+ " z.append(coords[2])\n",
126
+ " \n",
127
+ " structures.append({\n",
128
+ " \"size\": len(x),\n",
129
+ " \"names\": symbols,\n",
130
+ " \"x\": x,\n",
131
+ " \"y\": y,\n",
132
+ " \"z\" : z\n",
133
+ " })\n",
134
+ "\n",
135
+ "\n",
136
+ "\n",
137
+ " properties = {}\n",
138
+ " \n",
139
+ " for c in context_dict:\n",
140
+ " properties[c] = {\n",
141
+ " \"target\": \"structure\",\n",
142
+ " \"values\": [v for i, v in enumerate(context_dict[c]) if i not in remove_list]\n",
143
+ " }\n",
144
+ " \n",
145
+ "\n",
146
+ "\n",
147
+ " \n",
148
+ " data = {\n",
149
+ " \"meta\": {\n",
150
+ " # // the name of the dataset\n",
151
+ " \"name\": \"Test Dataset\",\n",
152
+ " # // description of the dataset, OPTIONAL\n",
153
+ " \"description\": \"This contains data from generated molecules\",\n",
154
+ " # // authors of the dataset, OPTIONAL\n",
155
+ " \"authors\": [\"Niklas Dobberstein, [email protected]\"],\n",
156
+ " # // references for the dataset, OPTIONAL\n",
157
+ " \"references\": [\n",
158
+ " \"\",\n",
159
+ " ],\n",
160
+ " \n",
161
+ " },\n",
162
+ " \"properties\": properties,\n",
163
+ " \"structures\": structures\n",
164
+ " }\n",
165
+ " \n",
166
+ " out_path = os.path.join(os.getcwd(), \"chemiscope_gen.json\")\n",
167
+ " with open(out_path, \"w\") as f:\n",
168
+ " json.dump(data, f)\n",
169
+ "\n",
170
+ " logging.info(f\"Wrote file {out_path}\")\n",
171
+ "\n",
172
+ "convert_to_chemiscope([\n",
173
+ " \"CC=O\",\n",
174
+ " \"s1ccnc1\"\n",
175
+ "], {\"logp\": [1.0,2.0], \"sascore\": [1.5,-2.0]})"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": 7,
181
+ "metadata": {},
182
+ "outputs": [
183
+ {
184
+ "data": {
185
+ "application/vnd.jupyter.widget-view+json": {
186
+ "model_id": "8b28a4e692de4bb48fde10a88d9727ba",
187
+ "version_major": 2,
188
+ "version_minor": 0
189
+ },
190
+ "text/plain": [
191
+ "HBox(children=(Checkbox(value=False, description='logp'), Checkbox(value=False, description='sascore'), Checkb…"
192
+ ]
193
+ },
194
+ "metadata": {},
195
+ "output_type": "display_data"
196
+ },
197
+ {
198
+ "data": {
199
+ "application/vnd.jupyter.widget-view+json": {
200
+ "model_id": "62331a62f2bf4d08a3a202ad277c6d92",
201
+ "version_major": 2,
202
+ "version_minor": 0
203
+ },
204
+ "text/plain": [
205
+ "HBox(children=(FloatSlider(value=0.0, description='logp:', max=7.0, min=-4.0, step=0.5), FloatSlider(value=2.0…"
206
+ ]
207
+ },
208
+ "metadata": {},
209
+ "output_type": "display_data"
210
+ },
211
+ {
212
+ "data": {
213
+ "application/vnd.jupyter.widget-view+json": {
214
+ "model_id": "2d498af39f4046b0a5bb92080361dfec",
215
+ "version_major": 2,
216
+ "version_minor": 0
217
+ },
218
+ "text/plain": [
219
+ "Text(value='', description='Context SMI:')"
220
+ ]
221
+ },
222
+ "metadata": {},
223
+ "output_type": "display_data"
224
+ },
225
+ {
226
+ "data": {
227
+ "application/vnd.jupyter.widget-view+json": {
228
+ "model_id": "ed8a755253444e9c83dc27c5f830588b",
229
+ "version_major": 2,
230
+ "version_minor": 0
231
+ },
232
+ "text/plain": [
233
+ "FloatSlider(value=0.8, description='Temperature:', max=2.0)"
234
+ ]
235
+ },
236
+ "metadata": {},
237
+ "output_type": "display_data"
238
+ },
239
+ {
240
+ "data": {
241
+ "application/vnd.jupyter.widget-view+json": {
242
+ "model_id": "139e7d1e40984101800e2cbb740280b0",
243
+ "version_major": 2,
244
+ "version_minor": 0
245
+ },
246
+ "text/plain": [
247
+ "Button(description='Generate', style=ButtonStyle())"
248
+ ]
249
+ },
250
+ "metadata": {},
251
+ "output_type": "display_data"
252
+ },
253
+ {
254
+ "data": {
255
+ "application/vnd.jupyter.widget-view+json": {
256
+ "model_id": "4d119a3b477243ac916478a6ec2a55c7",
257
+ "version_major": 2,
258
+ "version_minor": 0
259
+ },
260
+ "text/plain": [
261
+ "Output()"
262
+ ]
263
+ },
264
+ "metadata": {},
265
+ "output_type": "display_data"
266
+ },
267
+ {
268
+ "data": {
269
+ "application/vnd.jupyter.widget-view+json": {
270
+ "model_id": "dfce28d4f6a3414c838e6542ffb43fc6",
271
+ "version_major": 2,
272
+ "version_minor": 0
273
+ },
274
+ "text/plain": [
275
+ "Output()"
276
+ ]
277
+ },
278
+ "metadata": {},
279
+ "output_type": "display_data"
280
+ }
281
+ ],
282
+ "source": [
283
+ "import ipywidgets as widgets\n",
284
+ "from IPython.display import display, clear_output, HTML\n",
285
+ "import numpy as np\n",
286
+ "import torch\n",
287
+ "import matplotlib.pyplot as plt\n",
288
+ "from rdkit import Chem\n",
289
+ "from rdkit.Chem import Draw\n",
290
+ "import logging\n",
291
+ "from plot_utils import calc_context_from_smiles\n",
292
+ "\n",
293
+ "# Define the context_cols options and create checkboxes for them\n",
294
+ "context_cols_options = [\"logp\", \"sascore\", \"mol_weight\"]\n",
295
+ "context_cols_checkboxes = [widgets.Checkbox(description=col, value=False) for col in context_cols_options]\n",
296
+ "\n",
297
+ "# Create a text input for context_smi\n",
298
+ "context_smi_input = widgets.Text(description=\"Context SMI:\", value=\"\")\n",
299
+ "\n",
300
+ "# Create sliders for temperature and context_cols values\n",
301
+ "temperature_slider = widgets.FloatSlider(description=\"Temperature:\", min=0, max=2.0, step=0.1, value=0.8)\n",
302
+ "\n",
303
+ "logp_slider = widgets.FloatSlider(description=\"logp:\", min=-4, max=7, step=0.5, value=0.0)\n",
304
+ "sascore_slider = widgets.FloatSlider(description=\"sascore:\", min=1, max=10, step=0.5, value=2.0)\n",
305
+ "mol_weight_slider = widgets.FloatSlider(description=\"mol_weight:\", min=0.5, max=10, step=0.5, value=3.0)\n",
306
+ "\n",
307
+ "# Create a button to generate the code and display SMILES\n",
308
+ "generate_button = widgets.Button(description=\"Generate\")\n",
309
+ "\n",
310
+ "# Create an output widget for displaying generated information\n",
311
+ "output = widgets.Output()\n",
312
+ "\n",
313
+ "# Create an output widget for displaying the RDKit molecules\n",
314
+ "molecule_output = widgets.Output()\n",
315
+ "\n",
316
+ "@torch.no_grad()\n",
317
+ "def generate_code(_):\n",
318
+ " with output:\n",
319
+ " clear_output(wait=False)\n",
320
+ " # logging.info(\"Parameters used in generation:\")\n",
321
+ " \n",
322
+ " # Get the selected context_cols\n",
323
+ " selected_context_cols = [col for col, checkbox in zip(context_cols_options, context_cols_checkboxes) if checkbox.value]\n",
324
+ " # logging.info(f\"Context Cols: {selected_context_cols}\")\n",
325
+ " \n",
326
+ " # Get the values of context_smi and temperature from the sliders\n",
327
+ " context_smi = context_smi_input.value.strip()\n",
328
+ " temperature = temperature_slider.value\n",
329
+ " # logging.info(f\"Context Smiles: {context_smi}\")\n",
330
+ " # logging.info(f\"Temperature: {temperature}\")\n",
331
+ " \n",
332
+ " # Get the values of logp, sascore, and mol_weight from the sliders\n",
333
+ " context_dict = {} if len(selected_context_cols) != 0 else None\n",
334
+ " for c in selected_context_cols:\n",
335
+ " if c == \"logp\":\n",
336
+ " val = logp_slider.value\n",
337
+ " elif c == \"sascore\":\n",
338
+ " val = sascore_slider.value\n",
339
+ " else:\n",
340
+ " val = mol_weight_slider.value\n",
341
+ " val = round(val, 2)\n",
342
+ " context_dict[c] = val*torch.ones((num_samples,),device=device,dtype=torch.float)\n",
343
+ " # logging.info(f\"{c}: {val}\")\n",
344
+ " \n",
345
+ " # Generate SMILES using the provided context\n",
346
+ " smiles, context = sampler.generate(\n",
347
+ " context_cols=context_dict,\n",
348
+ " context_smi=context_smi,\n",
349
+ " start_smiles=None,\n",
350
+ " num_samples=num_samples,\n",
351
+ " max_new_tokens=256,\n",
352
+ " temperature=temperature,\n",
353
+ " top_k=25,\n",
354
+ " total_gen_steps=int(np.ceil(num_samples / 1000)),\n",
355
+ " return_context=True\n",
356
+ " )\n",
357
+ " \n",
358
+ " with open(os.path.join(os.getcwd(), \"gen_smiles.txt\"), \"w\") as f:\n",
359
+ " for s in smiles:\n",
360
+ " f.write(f\"{s}\\n\")\n",
361
+ " # Display SMILES as RDKit molecules\n",
362
+ " display_molecules(smiles, context)\n",
363
+ "\n",
364
+ "\n",
365
+ "\n",
366
+ "def display_molecules(smiles_list, context_dict):\n",
367
+ " with molecule_output:\n",
368
+ " clear_output(wait=False)\n",
369
+ " molecules = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]\n",
370
+ " \n",
371
+ " # Convert RDKit molecules to images and store them in a list\n",
372
+ " images = [Draw.MolToImage(mol) for mol in molecules]\n",
373
+ " \n",
374
+ " # Create a subplot grid to display the images\n",
375
+ " num_images = len(images)\n",
376
+ " num_cols = 5 # Number of columns in the grid\n",
377
+ " num_rows = (num_images + num_cols - 1) // num_cols # Calculate the number of rows\n",
378
+ " \n",
379
+ " fig, axes = plt.subplots(num_rows, num_cols, figsize=(25, 25))\n",
380
+ " fig.subplots_adjust(hspace=0.5)\n",
381
+ " calculated_context = {c:[] for c in context_dict}\n",
382
+ " for i, ax in enumerate(axes.flat):\n",
383
+ " if i < num_images:\n",
384
+ " ax.imshow(images[i])\n",
385
+ " for j, c in enumerate(context_dict):\n",
386
+ " smiles = smiles_list[i]\n",
387
+ " smi_con = round(calc_context_from_smiles([smiles], c)[0],2)\n",
388
+ " calculated_context[c].append(smi_con)\n",
389
+ " ax.text(0.5, -0.1 * j , f\"{c}: {context_dict[c][i]} vs {smi_con}\", transform=ax.transAxes, fontsize=10, ha='center')\n",
390
+ " \n",
391
+ " ax.axis('off')\n",
392
+ " else:\n",
393
+ " fig.delaxes(ax) # Remove empty subplots if there are more rows than images\n",
394
+ " \n",
395
+ "\n",
396
+ " if len(context_dict) >= 2:\n",
397
+ " convert_to_chemiscope(smiles_list, calculated_context)\n",
398
+ "\n",
399
+ " plt.savefig(\"gen_mols.png\")\n",
400
+ " plt.show()\n",
401
+ "\n",
402
+ "# Attach the generate_code function to the button's click event\n",
403
+ "generate_button.on_click(generate_code)\n",
404
+ "\n",
405
+ "# Display the widgets\n",
406
+ "display(widgets.HBox(context_cols_checkboxes))\n",
407
+ "display(widgets.HBox((logp_slider, sascore_slider, mol_weight_slider)))\n",
408
+ "\n",
409
+ "display(context_smi_input)\n",
410
+ "display(temperature_slider)\n",
411
+ "display(generate_button)\n",
412
+ "display(output)\n",
413
+ "display(molecule_output)"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": null,
419
+ "metadata": {},
420
+ "outputs": [
421
+ {
422
+ "data": {
423
+ "application/vnd.jupyter.widget-view+json": {
424
+ "model_id": "ea96e00e0ea8448d97906ec965f04788",
425
+ "version_major": 2,
426
+ "version_minor": 0
427
+ },
428
+ "text/plain": [
429
+ "Batch: 0%| | 0/1 [00:00<?, ?it/s]"
430
+ ]
431
+ },
432
+ "metadata": {},
433
+ "output_type": "display_data"
434
+ },
435
+ {
436
+ "data": {
437
+ "application/vnd.jupyter.widget-view+json": {
438
+ "model_id": "77ba2d72172846e18572c94bc5b3bd6f",
439
+ "version_major": 2,
440
+ "version_minor": 0
441
+ },
442
+ "text/plain": [
443
+ "Generation: 0%| | 0/256 [00:00<?, ?it/s]"
444
+ ]
445
+ },
446
+ "metadata": {},
447
+ "output_type": "display_data"
448
+ },
449
+ {
450
+ "name": "stderr",
451
+ "output_type": "stream",
452
+ "text": [
453
+ "INFO:sample:Number valid generated: 68.0 %\n",
454
+ "INFO:sample:---------------\n"
455
+ ]
456
+ }
457
+ ],
458
+ "source": [
459
+ "selected_context_cols = [\"logp\", \"sascore\", \"mol_weight\"]\n",
460
+ "num_samples = 25\n",
461
+ "context_dict = {} if len(selected_context_cols) != 0 else None\n",
462
+ "for c in selected_context_cols:\n",
463
+ " if c == \"logp\":\n",
464
+ " v = 0.5 * torch.randint(\n",
465
+ " -8, 14, (num_samples,), device=device, dtype=torch.float\n",
466
+ " )\n",
467
+ " context_dict[c] = v.sort()[0]\n",
468
+ " elif c == \"sascore\":\n",
469
+ " v = 0.5 * torch.randint(\n",
470
+ " 1, 20, (num_samples,), device=device, dtype=torch.float\n",
471
+ " )\n",
472
+ " context_dict[c] = v.sort()[0]\n",
473
+ " else:\n",
474
+ " v = 0.5 * torch.randint(\n",
475
+ " 1, 20, (num_samples,), device=device, dtype=torch.float\n",
476
+ " )\n",
477
+ " \n",
478
+ " context_dict[c] = v.sort()[0]\n",
479
+ " # logging.info(f\"{c}: {val}\")\n",
480
+ "\n",
481
+ "# Generate SMILES using the provided context\n",
482
+ "smiles, context = sampler.generate(\n",
483
+ " context_cols=context_dict,\n",
484
+ " context_smi=None,\n",
485
+ " start_smiles=None,\n",
486
+ " num_samples=num_samples,\n",
487
+ " max_new_tokens=256,\n",
488
+ " temperature=0.8,\n",
489
+ " top_k=25,\n",
490
+ " total_gen_steps=int(np.ceil(num_samples / 1000)),\n",
491
+ " return_context=True\n",
492
+ ")\n",
493
+ "\n",
494
+ "# Display SMILES as RDKit molecules\n",
495
+ "display_molecules(smiles, context)\n"
496
+ ]
497
+ }
498
+ ],
499
+ "metadata": {
500
+ "kernelspec": {
501
+ "display_name": "torch2-bachelor",
502
+ "language": "python",
503
+ "name": "python3"
504
+ },
505
+ "language_info": {
506
+ "codemirror_mode": {
507
+ "name": "ipython",
508
+ "version": 3
509
+ },
510
+ "file_extension": ".py",
511
+ "mimetype": "text/x-python",
512
+ "name": "python",
513
+ "nbconvert_exporter": "python",
514
+ "pygments_lexer": "ipython3",
515
+ "version": "3.8.18"
516
+ },
517
+ "orig_nbformat": 4
518
+ },
519
+ "nbformat": 4,
520
+ "nbformat_minor": 2
521
+ }
fragment_creator.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+ from dataclasses import dataclass
3
+ from typing import List, Union
4
+ import numpy as np
5
+ from rdkit import Chem
6
+ from rdkit.Chem.BRICS import BRICSDecompose
7
+ from rdkit.Chem.Recap import RecapDecompose
8
+
9
+ import random
10
+
11
+
12
@dataclass
class Fragment:
    """Container for the two optional representations of a fragment.

    Both fields must be passed to the constructor; either may be ``None``.
    """

    # SMILES string of the fragment, or None.
    smiles: Union[str, None]
    # Integer tokens of the fragment, or None.
    tokens: Union[List[int], None]
16
+
17
+
18
class BaseFragmentCreator(ABC):
    """
    Common parent class of all fragment creators.

    Its own ``create_fragment`` ignores the input and yields an empty string.
    """

    def __init__(self) -> None:
        return None

    def create_fragment(self, frag: Fragment) -> Fragment:
        # NOTE(review): annotated as returning a Fragment, but the value
        # handed back is an empty string. Confirm what callers expect.
        no_fragment = ""
        return no_fragment
28
+
29
+
30
+ # This is the method used in the paper
31
class RandomSubsliceFragmentCreator(BaseFragmentCreator):
    """Fragment creator that cuts a random contiguous slice out of the tokens."""

    def __init__(self, max_fragment_size=50) -> None:
        """
        Args:
            max_fragment_size: Exclusive upper bound on the slice length.
                It is added to the start index to form the exclusive
                ``high`` of the end-index draw, so it must be at least 2.
        """
        super().__init__()
        self.max_fragment_size = max_fragment_size

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Creates the random sub slice fragments from the tokens

        The start index is drawn from ``[0, len(tokens) - 2]`` and the end
        index from ``[start + 1, min(len(tokens), start + max_fragment_size) - 1]``.
        The slice therefore holds between 1 and ``max_fragment_size - 1``
        tokens and never contains the last token.

        Args:
            frag: Fragment whose ``tokens`` are sliced; ``smiles`` is ignored.

        Returns:
            A new Fragment with ``smiles=None`` and the sliced tokens. If
            ``tokens`` is ``None`` or has fewer than two entries, no slice
            can be drawn and the tokens are returned unchanged.
        """
        tokens = frag.tokens

        # np.random.randint requires low < high, so an input with fewer than
        # two tokens has no valid start index and used to raise ValueError.
        if tokens is None or len(tokens) < 2:
            return Fragment(smiles=None, tokens=tokens)

        startIdx = np.random.randint(0, len(tokens) - 1)

        endIdx = np.random.randint(
            startIdx + 1, min(len(tokens), startIdx + self.max_fragment_size)
        )
        return Fragment(smiles=None, tokens=tokens[startIdx:endIdx])
48
+
49
+
50
class BricksFragmentCreator(BaseFragmentCreator):
    """Fragment creator built on RDKit's ``BRICSDecompose``."""

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Decomposes ``frag.smiles`` with BRICS and picks one result at random.

        Returns an empty string when the SMILES cannot be parsed.
        NOTE(review): the returned value is an element of the BRICSDecompose
        output (presumably a SMILES string), not a Fragment as annotated.
        """
        mol = Chem.MolFromSmiles(frag.smiles)
        if mol is not None:
            candidates = list(BRICSDecompose(mol, minFragmentSize=3))
            return random.choice(candidates)
        return ""
66
+
67
+
68
class RecapFragmentCreator(BaseFragmentCreator):
    """Fragment creator built on RDKit's ``RecapDecompose``."""

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Creates the Recap fragments and takes one randomly

        Args:
            frag: Fragment whose ``smiles`` is decomposed; ``tokens`` is ignored.

        Returns:
            The SMILES key of one randomly chosen node of the Recap
            hierarchy. An empty string is returned when the SMILES cannot be
            parsed or the decomposition has no children.
            NOTE(review): the annotation says Fragment, but like the sibling
            creators this returns a plain string.
        """
        smiles = frag.smiles
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""

        # GetAllChildren() returns a dict keyed by fragment SMILES.
        # random.choice indexes its argument with an integer, so it needs the
        # keys as a list; passing the dict itself raises KeyError.
        children = list(RecapDecompose(m, minFragmentSize=3).GetAllChildren())
        if not children:
            # No children to pick from: use the same "no fragment" value as
            # for an invalid SMILES.
            return ""
        return random.choice(children)
84
+
85
+
86
class MolFragsFragmentCreator(BaseFragmentCreator):
    """Fragment creator built on RDKit's ``GetMolFrags``."""

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Splits ``frag.smiles`` with ``GetMolFrags`` and picks one part at random.

        Returns the SMILES of the chosen part, or an empty string when the
        input SMILES cannot be parsed.
        """
        mol = Chem.MolFromSmiles(frag.smiles)
        if mol is None:
            return ""

        part_smiles = []
        for part in Chem.rdmolops.GetMolFrags(mol, asMols=True):
            part_smiles.append(Chem.MolToSmiles(part))
        return random.choice(part_smiles)
103
+
104
+
105
def fragment_creator_factory(key: Union[str, None]):
    """Build the fragment creator registered under ``key``.

    Args:
        key: One of ``"mol_frags"``, ``"recap"``, ``"bricks"``, ``"rss"``,
            or ``None``.

    Returns:
        A new creator instance, or ``None`` when ``key`` is ``None``.

    Raises:
        ValueError: If ``key`` matches none of the known names.
    """
    if key is None:
        return None

    # Checked in order with ``==``; only the matching class is instantiated.
    known_creators = (
        ("mol_frags", MolFragsFragmentCreator),
        ("recap", RecapFragmentCreator),
        ("bricks", BricksFragmentCreator),
        ("rss", RandomSubsliceFragmentCreator),
    )
    for name, creator_cls in known_creators:
        if key == name:
            return creator_cls()

    raise ValueError(f"Do not have factory for the given key: {key}")
119
+
120
+
121
if __name__ == "__main__":
    # Small manual demo: fragment one molecule and print its tokenisation.
    from tokenizer import SmilesTokenizer

    tokenizer = SmilesTokenizer()

    creator = BricksFragmentCreator()
    # creator = MolFragsFragmentCreator()

    # creator = RecapFragmentCreator()

    # create_fragment reads ``frag.smiles``, so the SMILES must be wrapped in
    # a Fragment; passing the bare string raises AttributeError.
    frag = creator.create_fragment(
        Fragment(smiles="CC(=O)NC1=CC=C(C=C1)O", tokens=None)
    )

    print(frag)
    tokens = tokenizer.encode(frag)
    print(tokens)
    print([tokenizer._convert_id_to_token(t) for t in tokens])
generate_paper_graphs.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Runs sample.py once for each pairwise combination of the context columns.

# TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
conda activate FULL_PATH_TO_CONDA/torch2-llamol

array=( logp sascore mol_weight )
# python sample.py --num_samples 20000 --num_samples_per_step 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet"
# for i in "${array[@]}"
# do
# python sample.py --num_samples 10000 --num_samples_per_step 500 --kv_caching --ckpt_path "out/llama2-M-Full-RSS.pt" --context_cols "$i" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet"
# done

# Arguments shared by every enabled run. They are split around --context_cols
# so that the argument order passed to sample.py stays the same.
leading_args=(--num_samples 1000 --seed 4321 --kv_caching --ckpt_path "out/llama2-M-Full-RSS.pt")
trailing_args=(--max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet")

# 2 Combinations
for pair in "logp sascore" "logp mol_weight" "sascore mol_weight"; do
    # Split the pair into two separate command-line arguments.
    read -r -a cols <<< "$pair"
    python sample.py "${leading_args[@]}" --context_cols "${cols[@]}" "${trailing_args[@]}"
done

# # # All 3
# python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --context_cols logp sascore mol_weight --kv_caching --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --seed 4312
get_fragment_table.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Runs sample.py once per context SMILES, conditioned on the fragment plus
# logp, sascore and mol_weight.

# TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
conda activate FULL_PATH_TO_CONDA/torch2-llamol


# context_smiles=("c1ccccc1" "s1cccc1" "C1=CSC=C1" "CC1=CSC=C1" "C1=CC=C2C(=C1)C3=CC=CC=C3S2" "CCO" "CC=O" "CC(=O)OC1=CC=CC=C1C(=O)O" "CC(=O)NC1=CC=C(C=C1)O" "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O" "OC(=O)C(C)c1ccc(cc1)CC(C)C" "C1C(=O)NC(=O)NC1=O" "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" "CN1CCC23C4C1CC5=C2C(=C(C=C5)O)OC3C(C=C4)O" "CN1CCC23C4C1CC5=C2C(=C(C=C5)OC)OC3C(=O)CC4")
# context_smiles=("CN1CCC23C4C1CC5=C2C(=C(C=C5)O)OC3C(C=C4)O" "CN1CC[C@]23[C@@H]4[C@H]1CC5=C2C(=C(C=C5)O)O[C@H]3[C@H](C=C4)O" "CN1CCC23C4C1CC5=C2C(=C(C=C5)OC)OC3C(=O)CC4" "CN1CC[C@]23[C@@H]4[C@H]1CC5=C2C(=C(C=C5)OC)O[C@H]3C(=O)CC4" )
# context_smiles=("C1=CSC=C1" )
context_smiles=("C1=CSC=C1" "CC=O" "CC(=O)NC1=CC=C(C=C1)O" "CN1C=NC2=C1C(=O)N(C(=O)N2C)C")

# Arguments common to every variant of the sampling call.
base_args=(--num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet")

for smi in "${context_smiles[@]}"; do
    # Disabled variants of the call below differ only in what follows
    # --context_smi "$smi":
    #   Only fragment generation:     (no --context_cols)
    #   Fragment and LogP:            --context_cols "logp"
    #   Fragment and Sascore:         --context_cols "sascore"
    #   Fragment and Mol weight:      --context_cols "mol_weight"
    # Multi Fragment Condition
    #   Logp + Sascore:               --context_cols "logp" "sascore"
    #   Logp + Mol Weight:            --context_cols "logp" "mol_weight"
    #   Sascore + Mol Weight:         --context_cols "sascore" "mol_weight"

    # Logp + Sascore + Mol Weight
    # NOTE(review): stdout of sample.py is captured in $output and never
    # printed. Confirm that discarding it is intended.
    output=$(python sample.py "${base_args[@]}" --context_smi "$smi" --context_cols "logp" "sascore" "mol_weight")


    echo "SMI: $smi"
    echo "----------------------"
done
model.py ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import pickle
5
+ import struct
6
+ import inspect
7
+ from dataclasses import dataclass, field
8
+ from typing import Any, Dict, Optional, Tuple, List, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from tqdm.auto import tqdm
15
+
16
+ from tokenizer import SmilesTokenizer
17
+
18
+
19
@dataclass
class ModelArgs:
    """Architecture hyper-parameters consumed by the transformer classes below."""

    # Width of the token embeddings / residual stream.
    dim: int = 4096
    # Number of stacked transformer blocks.
    n_layers: int = 32
    # Number of query heads per attention layer.
    n_heads: int = 32
    # Number of key/value heads; ``None`` makes Attention fall back to n_heads.
    n_kv_heads: Optional[int] = None
    # Placeholder until the tokenizer's vocabulary size is known.
    vocab_size: int = -1  # defined later by tokenizer
    # The SwiGLU hidden size is rounded up to a multiple of this value.
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    # Epsilon added inside RMSNorm before the reciprocal square root.
    norm_eps: float = 1e-5
    # Number of positions for which rotary frequencies (and the causal mask) are precomputed.
    max_seq_len: int = 2048
    # Dropout probability shared by embedding, attention and feed-forward dropout.
    dropout: float = 0.0
30
+
31
+
32
@dataclass
class ContextArgs:
    """Describes the conditioning inputs the model accepts.

    ``context_keys`` and ``context_dims`` are zipped together by the
    Transformer constructor: entry *i* of ``context_dims`` is the input width
    of the linear embedding created for entry *i* of ``context_keys``.
    """

    # Names of the conditioning values (each gets its own embedding layer).
    context_keys: List[str] = field(default_factory=list)
    # Input feature width for each key, in the same order as ``context_keys``.
    context_dims: List[int] = field(default_factory=list)
36
+
37
+
38
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float):
        """Create the layer.

        Args:
            dim: Size of the last dimension of the inputs; length of the scale vector.
            eps: Constant added to the mean square before taking the reciprocal root.
        """
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the root of its mean square along the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to the dtype of ``x``, then apply the gain."""
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight
50
+
51
+
52
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the cosine and sine tables used by the rotary position embedding.

    Args:
        dim: Per-head feature size; one frequency is produced per pair of features.
        end: Number of positions to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape ``(end, dim // 2)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    # Outer product gives the rotation angle for every (position, frequency) pair.
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    return torch.cos(angles), torch.sin(angles)
59
+
60
+
61
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
62
+ ndim = x.ndim
63
+ assert 0 <= 1 < ndim
64
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
65
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
66
+ return freqs_cis.view(shape)
67
+
68
+
69
def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the position-dependent angles (RoPE).

    Adjacent feature pairs of the last dimension are treated as the real and
    imaginary part of a complex number and rotated with plain real
    arithmetic. The same cos/sin tables are applied to both tensors.

    Args:
        xq: Query tensor; its last dimension is split into pairs.
        xk: Key tensor; rotated with the same tables as ``xq``.
        freqs_cos: Cosine table, shape ``(xq.shape[1], xq.shape[-1] // 2)``.
        freqs_sin: Sine table with the same shape as ``freqs_cos``.

    Returns:
        The rotated ``(xq, xk)``, each cast back to the dtype of its input.
    """

    def _split_pairs(t: torch.Tensor):
        # (..., d) -> two tensors of shape (..., d // 2): even and odd features.
        return t.float().reshape(t.shape[:-1] + (-1, 2)).unbind(-1)

    q_re, q_im = _split_pairs(xq)
    k_re, k_im = _split_pairs(xk)

    # Both tables are shaped against the query halves, as in the original code.
    cos = reshape_for_broadcast(freqs_cos, q_re)
    sin = reshape_for_broadcast(freqs_sin, q_re)

    def _rotate(re: torch.Tensor, im: torch.Tensor) -> torch.Tensor:
        rotated_re = re * cos - im * sin
        rotated_im = re * sin + im * cos
        # Interleave real/imaginary parts again and merge the pair dimension.
        return torch.stack([rotated_re, rotated_im], dim=-1).flatten(3)

    return _rotate(q_re, q_im).type_as(xq), _rotate(k_re, k_im).type_as(xk)
91
+
92
+
93
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head dimension.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for a
    tensor of shape ``(batch, seq_len, n_kv_heads, head_dim)``. With
    ``n_rep == 1`` the input tensor itself is returned.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a singleton axis after the head axis, expand it, then fold it in.
    expanded = x.unsqueeze(3).expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
103
+
104
+
105
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings and grouped KV heads.

    Two entry points are provided:

    * ``forward`` - ordinary causal attention over the full sequence.
    * ``forward_with_kvcache`` - incremental decoding. The first call for a
      given ``cache_id`` processes the whole sequence causally and stores the
      un-rotated key/value projections; later calls with the same
      ``cache_id`` only project the last position and append to the cache.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # How many query heads share one key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Identifier of the generation run whose keys/values are currently cached.
        self.cache_hash = None

        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )
            # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def _scaled_attention(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        xv: torch.Tensor,
        causal: bool,
    ) -> torch.Tensor:
        """Attention core shared by both forward variants.

        All inputs are laid out as ``(batch, heads, length, head_dim)``.
        ``causal`` must only be True when queries and keys cover the same
        positions (square score matrix).
        """
        if self.flash:
            return torch.nn.functional.scaled_dot_product_attention(
                xq,
                xk,
                xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=causal,
            )

        # manual implementation
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        if causal:
            assert hasattr(self, "mask")
            q_len = xq.size(2)
            scores = scores + self.mask[:, :, :q_len, :q_len]
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        return torch.matmul(scores, xv)  # (bs, n_local_heads, q_len, head_dim)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Causal self-attention over the whole sequence.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            freqs_cos: Rotary cosine table for the ``seq_len`` positions.
            freqs_sin: Rotary sine table for the ``seq_len`` positions.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        output = self._scaled_attention(xq, xk, xv, causal=True)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

    def forward_with_kvcache(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        cache_id: int = 1,
    ):
        """Self-attention with a key/value cache for incremental decoding.

        If ``cache_id`` differs from the id stored by the previous call, the
        full ``x`` is processed causally (prefill) and its key/value
        projections are cached. Otherwise only the last position of ``x`` is
        projected, appended to the cache, and attended against all cached
        positions.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``. On cached calls only
                ``x[:, -1]`` is used.
            freqs_cos: Rotary cosine table; its first dimension must equal the
                number of cached positions after this call.
            freqs_sin: Rotary sine table, same shape as ``freqs_cos``.
            cache_id: Identifier of the generation run the cache belongs to.

        Returns:
            ``(batch, seq_len, dim)`` on prefill, ``(batch, 1, dim)`` on cached
            calls.
        """
        bsz, seqlen, _ = x.shape

        use_cache = self.cache_hash == cache_id
        if use_cache:
            x = x[:, -1, :].unsqueeze(1)  # only need the last new token

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        if use_cache:
            # Keys are cached *before* rotation, so the rotation for every
            # cached position is re-applied below on each step.
            self.k_cache = torch.concat([self.k_cache, xk.clone()], dim=1)
            self.v_cache = torch.concat([self.v_cache, xv.clone()], dim=1)

            seqlen = self.k_cache.size(1)
            xk = self.k_cache
            xv = self.v_cache
            self.cache_hash = cache_id
            xq = xq.view(bsz, 1, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
            xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

            # RoPE relative positional embeddings
            # reshape xq and xk to match the complex representation
            xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
            xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

            # The single query sits at the last position; keys span all positions.
            q_freq_cos = freqs_cos[-1, :].unsqueeze(0)
            q_freq_sin = freqs_sin[-1, :].unsqueeze(0)
            freqs_cos_q = reshape_for_broadcast(q_freq_cos, xq_r)
            freqs_sin_q = reshape_for_broadcast(q_freq_sin, xq_r)

            freqs_cos_k = reshape_for_broadcast(freqs_cos, xk_r)
            freqs_sin_k = reshape_for_broadcast(freqs_sin, xk_r)

            # apply rotation using real numbers
            xq_out_r = xq_r * freqs_cos_q - xq_i * freqs_sin_q
            xq_out_i = xq_r * freqs_sin_q + xq_i * freqs_cos_q
            xk_out_r = xk_r * freqs_cos_k - xk_i * freqs_sin_k
            xk_out_i = xk_r * freqs_sin_k + xk_i * freqs_cos_k

            # flatten last two dimensions
            xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
            xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

            xq, xk = xq_out.type_as(xq), xk_out.type_as(xk)
        else:
            self.k_cache = xk
            self.v_cache = xv
            self.old_x = x

            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
            xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

            self.cache_hash = cache_id

            # RoPE relative positional embeddings
            xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, q_len, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # NOTE: on cached steps the single query must see *every* cached key,
        # so no causal mask may be applied there (is_causal=True would align
        # the mask top-left and hide all keys but the first). The prefill pass
        # has a square score matrix and must be causal so that it produces the
        # same hidden states as ``forward``.
        output = self._scaled_attention(xq, xk, xv, causal=not use_cache)

        # x holds the full sequence on prefill and one token on cached steps.
        output = output.transpose(1, 2).contiguous().view(bsz, x.size(1), -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
318
+
319
+
320
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` followed by dropout."""

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        """Create the three bias-free projections.

        Args:
            dim: Input and output feature size.
            hidden_dim: Nominal hidden size; it is scaled by 2/3 and then
                rounded up to the next multiple of ``multiple_of``.
            multiple_of: Granularity of the effective hidden size.
            dropout: Dropout probability applied to the block output.
        """
        super().__init__()
        scaled = int(2 * hidden_dim / 3)
        n_chunks = (scaled + multiple_of - 1) // multiple_of  # ceil division
        hidden_dim = multiple_of * n_chunks
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the gated projection and dropout to ``x``."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.dropout(self.w2(gate * up))
332
+
333
+
334
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention and feed-forward, each with a residual."""

    def __init__(self, layer_id: int, args: ModelArgs):
        """Build the sub-modules.

        Args:
            layer_id: Index of this block within the stack (stored only).
            args: Shared architecture hyper-parameters.
        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        # The nominal hidden size is 4 * dim; FeedForward rescales it internally.
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        """Run the block on the full sequence ``x`` with causal attention."""
        attn_out = self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        h = x + attn_out
        ffn_out = self.feed_forward.forward(self.ffn_norm(h))
        return h + ffn_out

    def forward_with_kvcache(self, x, freqs_cos, freqs_sin, cache_id=1):
        """Run the block using the attention layer's key/value cache.

        ``cache_id`` is forwarded unchanged to
        ``Attention.forward_with_kvcache``.
        """
        attn_out = self.attention.forward_with_kvcache(
            self.attention_norm(x), freqs_cos, freqs_sin, cache_id=cache_id
        )
        h = x + attn_out
        ffn_out = self.feed_forward.forward(self.ffn_norm(h))
        return h + ffn_out
362
+
363
+
364
class Transformer(nn.Module):
    """Decoder-only transformer with optional fragment and property conditioning.

    The token sequence can be prefixed with (a) an embedded fragment token
    sequence and (b) one embedded vector per conditioning value. The prefix
    positions take part in attention but are cut off again before the output
    projection, so logits are only produced for the real tokens.
    """

    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs, context_params: ContextArgs):
        """Build the model.

        Args:
            params: Architecture hyper-parameters.
            context_params: Names and input widths of the conditioning values.
        """
        super().__init__()
        self.params = params
        self.context_params = context_params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.frag_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.frag_type_embedding = nn.Embedding(1, params.dim)

        # Maps each context key to its row in ``conditions_type_embeddings``.
        self.context_lookup = {k: i for i, k in enumerate(context_params.context_keys)}
        self.conditions_type_embeddings = nn.Embedding(
            len(context_params.context_keys), params.dim
        )
        # One linear value-embedding per context key.
        self.conditions_embeddings_lookup = nn.ModuleDict(
            {
                k: nn.Sequential(
                    nn.Linear(dim, params.dim, bias=True),
                )
                for k, dim in zip(
                    context_params.context_keys, context_params.context_dims
                )
            }
        )

        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = (
            self.output.weight
        )  # https://paperswithcode.com/method/weight-tying

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        # NOTE(review): in FeedForward it is ``w2`` that projects back to
        # ``dim``; confirm that matching ``w3`` here is intended.
        for pn, p in self.named_parameters():
            if pn.endswith("w3.weight") or pn.endswith("wo.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers)
                )

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None

    def _init_weights(self, module):
        """Normal(0, 0.02) init for linear and embedding weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _logits_and_loss(
        self, h: torch.Tensor, targets: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Project hidden states to logits and update ``self.last_loss``.

        With ``targets`` the logits for every position are returned and the
        cross-entropy loss is stored; without, only the last position is
        projected and ``last_loss`` is reset to ``None``.
        """
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(h)
            tmp_last_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,  # Ignore Pad Tokens
            )

            # NOTE: This essentially does nothing for the computation,
            # because we are multiplying the weights by zero.
            # This *needs* to be done so that we can train with DDP:
            # depending on which conditioning inputs a batch carries, some
            # weights are not used in the forward pass, which the c10 backend
            # does not accept and the training errors out.
            # Maybe there is a better fix in the future, see:
            # https://github.com/pytorch/pytorch/issues/43259
            ddp_fix = sum(p.sum() for p in self.parameters())
            zero_sum = ddp_fix * 0.0

            self.last_loss = tmp_last_loss + zero_sum
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(
                h[:, [-1], :]
            )  # note: using list [-1] to preserve the time dim
            self.last_loss = None

        return logits

    def forward(
        self,
        tokens: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        context: Optional[Dict[str, torch.Tensor]] = None,
        fragment: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the model on a full sequence.

        Args:
            tokens: Token ids of shape ``(batch, seq_len)``.
            targets: Optional target ids; when given, the loss is stored in
                ``self.last_loss`` (index 0 is ignored as padding).
            context: Optional mapping from context key to a value tensor.
            fragment: Optional fragment token ids, prepended to the sequence.

        Returns:
            Logits for all token positions when ``targets`` is given,
            otherwise logits for the last position only.
        """
        bsz, seqlen = tokens.shape
        device = tokens.device

        h = self._add_context_to_seq(tokens, context, fragment, bsz, device)

        # Number of prefix positions (fragment + context) in front of the tokens.
        context_seq_len = h.shape[1] - seqlen

        bsz, seqlen, _ = h.shape

        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        h = h[:, context_seq_len:]
        return self._logits_and_loss(h, targets)

    def forward_with_kvcache(
        self,
        tokens: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        context: Optional[Dict[str, torch.Tensor]] = None,
        fragment: Optional[torch.Tensor] = None,
        cache_id: int = 1,
        pos_seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Like ``forward`` but routed through the per-layer key/value caches.

        Args:
            tokens: Token ids; the whole prompt on the first call of a run,
                afterwards only the newest token.
            targets: Optional target ids, handled as in ``forward``.
            context: Optional conditioning values, as in ``forward``.
            fragment: Optional fragment token ids, as in ``forward``.
            cache_id: Identifier of the generation run; a new id resets the
                caches of the attention layers.
            pos_seq_len: Number of token positions produced so far, excluding
                the fragment/context prefix (its length is added here). When
                ``None`` the length of the current input is used.

        Returns:
            Logits, as in ``forward``.
        """
        bsz, seqlen = tokens.shape
        device = tokens.device

        h = self._add_context_to_seq(tokens, context, fragment, bsz, device)

        context_seq_len = h.shape[1] - seqlen

        bsz, seqlen, _ = h.shape
        if pos_seq_len is None:
            pos_seq_len = seqlen
        else:
            pos_seq_len = max(seqlen, pos_seq_len + context_seq_len)

        freqs_cos = self.freqs_cos[:pos_seq_len]
        freqs_sin = self.freqs_sin[:pos_seq_len]

        for layer in self.layers:
            h = layer.forward_with_kvcache(h, freqs_cos, freqs_sin, cache_id=cache_id)
        h = self.norm(h)

        h = h[:, context_seq_len:]
        return self._logits_and_loss(h, targets)

    def _add_context_to_seq(self, tokens, context, fragment, bsz, device):
        """Embed ``tokens`` and prepend the fragment and context embeddings.

        The resulting order along the sequence axis is
        ``[context values..., fragment tokens..., tokens...]``.
        """
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)

        if fragment is not None:
            # All fragment positions share the single fragment-type embedding (id 0).
            fragment_type_enc = torch.zeros_like(
                fragment, dtype=torch.long, device=device
            )

            h = torch.concat(
                (
                    self.tok_embeddings(fragment)
                    + self.frag_embeddings(fragment)
                    + self.frag_type_embedding(fragment_type_enc),
                    h,
                ),
                dim=1,
            )

        if context is not None and len(context) != 0:
            # context is a dictionary with key : context_tensor of shape (batch_size, context_dim)
            type_ids = []
            context_vals = []

            for emb_key, context_val in context.items():
                emb_context_val = self.conditions_embeddings_lookup[emb_key](
                    context_val.unsqueeze(1).to(device)
                ).unsqueeze(1)

                context_vals.append(emb_context_val)
                type_ids_tensor = torch.tensor(
                    [self.context_lookup[emb_key]], device=device, dtype=torch.long
                )
                type_ids.append(type_ids_tensor)

            # (n_context,) -> (batch, n_context): same type ids for every sample.
            context_types = (
                torch.concat(type_ids, dim=0).reshape(-1, 1).expand(-1, bsz).T
            )
            context_types = self.conditions_type_embeddings(context_types)

            context_vals = torch.concat(context_vals, dim=1).to(device)

            h = torch.concat([context_vals + context_types, h], dim=1)
        return h

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Create an AdamW optimizer with weight decay on >=2-D parameters only.

        The fused AdamW implementation is used when the installed PyTorch
        offers it and ``device_type`` is ``"cuda"``.
        """
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(
            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
        )
        print(
            f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
        )
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas, **extra_args
        )
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = sum(p.numel() for p in self.parameters())
        cfg = self.params
        L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim // cfg.n_heads, cfg.max_seq_len
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0 / dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.inference_mode()
    def generate(
        self,
        tokenizer: SmilesTokenizer,
        context: Optional[Dict[str, torch.Tensor]] = None,
        fragments: Union[torch.Tensor, None] = None,
        max_length: int = 50,
        num_gen: int = 200,
        start_smiles: Union[str, None] = None,
        temperature: float = 1.0,
        top_k: Union[int, None] = None,
        device: torch.device = torch.device("cpu"),
        cache_kv: bool = False,
    ) -> List[str]:
        """Sample ``num_gen`` sequences autoregressively.

        Args:
            tokenizer: Tokenizer used to encode ``start_smiles`` and decode
                the generated ids.
            context: Optional conditioning values, passed to the forward call.
            fragments: Optional fragment token ids, passed to the forward call.
            max_length: Upper bound on the sequence length, start tokens included.
            num_gen: Number of sequences generated in parallel.
            start_smiles: Optional prefix every sequence starts with; without
                it each sequence starts with the tokenizer's cls token.
            temperature: Softmax temperature; ``0.0`` selects the most likely
                token deterministically.
            top_k: If set, sampling is restricted to the ``top_k`` most likely
                tokens.
            device: Device on which the start tokens are created.
            cache_kv: Use ``forward_with_kvcache`` instead of re-running the
                whole sequence on every step.

        Returns:
            One decoded string per sequence, without the start token and cut
            before the first sep token (or uncut if none was produced).
        """
        batch_size = num_gen
        if start_smiles is not None:
            start_ids = tokenizer.encode(start_smiles)[:-1]  # remove <eos> token
            start_ids = torch.tensor(start_ids, device=device, dtype=torch.long).view(
                -1, 1
            )
            outputs = start_ids.repeat(1, batch_size).T
        else:
            outputs = (
                torch.LongTensor([[tokenizer.cls_token_id] * batch_size]).to(device)
            ).T  # (batch_size, 1)
        self.eval()

        start_len = outputs.shape[1]
        # Per sequence: position at which the sep token was sampled (0 = not yet).
        has_end_idx = np.array([0] * batch_size)
        # Fresh id so the attention layers start a new KV cache for this run.
        # int64 is requested explicitly: the default integer type is 32-bit on
        # some platforms and cannot represent the upper bound.
        cache_id = np.random.randint(0, int(1e10), 1, dtype=np.int64).item()

        # Gradients are already disabled by the inference_mode decorator.
        with tqdm(total=max_length, desc="Generation") as pbar:
            for i in range(start_len, max_length):
                if not cache_kv:
                    logits = self(outputs, context=context, fragment=fragments)
                else:
                    if i == start_len:
                        # When starting pass the whole input, so that "start_smiles" works, then only the newly generated token, because of the cache
                        func_input = outputs
                    else:
                        func_input = outputs[:, -1].unsqueeze(-1)
                    logits = self.forward_with_kvcache(
                        func_input,
                        context=context,
                        fragment=fragments,
                        cache_id=cache_id,
                        pos_seq_len=outputs.size(-1),
                    )

                logits = logits[:, -1, :]  # crop to just the final time step
                if temperature == 0.0:
                    # "sample" the single most likely index
                    _, idx_next = torch.topk(logits, k=1, dim=-1)
                else:
                    # pluck the logits at the final step and scale by desired temperature
                    logits = logits / temperature
                    # optionally crop the logits to only the top k options
                    if top_k is not None:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = -float("Inf")

                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)

                ended_sentences = idx_next == tokenizer.sep_token_id
                if torch.count_nonzero(ended_sentences) != 0:
                    indicies = torch.nonzero(ended_sentences)
                    indicies = indicies.cpu().numpy()
                    for end_idx in indicies[:, 0]:
                        # Only the first sep token of a sequence counts.
                        if has_end_idx[end_idx] == 0:
                            has_end_idx[end_idx] = i

                if all([idx != 0 for idx in has_end_idx]):
                    break

                outputs = torch.cat((outputs, idx_next), dim=1)
                pbar.update(1)

        decoded = []
        for output, end_idx in zip(outputs.cpu().numpy(), has_end_idx):
            # In case generation stopped at max_length without a sep token.
            if end_idx == 0:
                pieces = [tokenizer._convert_id_to_token(idx) for idx in output[:]]
            else:
                pieces = [
                    tokenizer._convert_id_to_token(idx) for idx in output[:end_idx]
                ]
            # remove start token
            decoded.append("".join(pieces[1:]))

        return decoded

    @staticmethod
    def load(path, device: torch.device = torch.device("cpu")) -> Transformer:
        """Rebuild a model from a checkpoint written by ``save``.

        The checkpoint must contain the keys ``model_params``,
        ``context_params`` and ``state_dict``.
        """
        # SECURITY: torch.load unpickles arbitrary Python objects; only load
        # checkpoints from trusted sources.
        data = torch.load(path, map_location=device)

        newinstace = Transformer(data["model_params"], data["context_params"])
        newinstace.load_state_dict(data["state_dict"])
        return newinstace.to(device)

    def save(self, filepath):
        """Write the state dict and both parameter objects to ``filepath``."""
        torch.save(
            {
                "state_dict": self.state_dict(),
                **dict(model_params=self.params, context_params=self.context_params),
            },
            filepath,
        )

    def getNumberTrainableParams(self) -> int:
        """Return the number of parameter elements that require gradients."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def getNumberParams(self) -> int:
        """Return the total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())
770
+
771
+
772
if __name__ == "__main__":
    # Smoke test: build a small model and run one training-style forward pass
    # with a fragment and two of the three configured context values.
    batch, seq_len, frag_len = 128, 50, 10

    model_args = ModelArgs(
        dim=128, n_layers=8, n_heads=8, vocab_size=512, max_seq_len=1024
    )
    context_args = ContextArgs(
        context_keys=["logp", "sascore", "mol_weight"], context_dims=[1, 1, 1]
    )
    m = Transformer(model_args, context_params=context_args)

    seq = torch.ones((batch, seq_len), dtype=torch.long)
    frag = torch.ones((batch, frag_len), dtype=torch.long)
    # "sascore" is deliberately left out to exercise a partial context.
    context = {
        "logp": torch.ones((batch,), dtype=torch.float32),
        "mol_weight": torch.ones((batch,), dtype=torch.float32),
    }

    print(m.forward(seq, targets=seq, context=context, fragment=frag))
out/llama2-M-Full-RSS.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83571f8f8936a4eac8ac4541282ff99a3e942c07ee4aaef82abdc2f52e1731ae
3
+ size 58587134
plot_utils.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Union
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from rdkit.Chem import AllChem, Descriptors, RDConfig
9
+
10
+ import sys
11
+
12
+ sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
13
+ # now you can import sascore!
14
+ import sascorer
15
+ from rdkit import Chem
16
+ import logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+ # plt.rcParams.update({'font.size': 13.1})
20
+ plt.rcParams.update({"font.size": 12.5})
21
+
22
# Human-readable axis/title labels for every context column the plotting
# helpers accept. "energy" is included because the property dispatch in this
# module handles it; without the entry the label lookup raises KeyError.
COL_TO_DISPLAY_NAME = {
    "logp": "LogP",
    "sascore": "SAScore",
    "mol_weight": "Molecular Weight",
    "energy": "Energy",
}
27
+
28
+
29
def calcContextSAScore(smiles: List[str]):
    """Return ``sascorer.calculateScore`` for each SMILES string as a NumPy array.

    Each string is parsed with ``Chem.MolFromSmiles``; the result order
    follows the input order.
    """
    return np.array(
        [sascorer.calculateScore(Chem.MolFromSmiles(smi)) for smi in smiles]
    )
37
+
38
+
39
def calcContextLogP(smiles: List[str]):
    """Return ``Descriptors.MolLogP`` for each SMILES string as a NumPy array.

    Each string is parsed with ``Chem.MolFromSmiles``; the result order
    follows the input order.
    """
    return np.array(
        [Descriptors.MolLogP(Chem.MolFromSmiles(smi)) for smi in smiles]
    )
47
+
48
+
49
def calcContextEnergy(smiles, num_confs=5, num_threads=48):
    """Return the mean MMFF conformer energy for each SMILES string.

    For every molecule, hydrogens are added, ``num_confs`` conformers are
    embedded and MMFF-optimised, and the optimised energies are averaged and
    multiplied by 0.0016 (the original comments describe this as converting
    kcal/mol to hartree).

    Args:
        smiles: Iterable of SMILES strings.
        num_confs: Number of conformers requested per molecule.
        num_threads: Value forwarded as ``numThreads`` to both RDKit calls.
            Defaults to 48, the value that was previously hard-coded.

    Returns:
        NumPy array with one scaled mean energy per input string.
    """
    # Scale factor kept exactly as before so results stay comparable.
    energy_scale = 0.0016

    contexts = []
    for smi in smiles:
        mol = Chem.AddHs(Chem.MolFromSmiles(smi))
        AllChem.EmbedMultipleConfs(mol, num_confs, numThreads=num_threads)
        opt_results = AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=num_threads)

        energies = []
        for not_converged, energy in opt_results:
            # A non-zero first element is reported but the energy is still used.
            if not_converged != 0:
                print("Not converged!", smi)
            energies.append(energy)

        mean_en = np.mean(energies)
        mean_en = mean_en * energy_scale
        contexts.append(mean_en)

    return np.array(contexts)
70
+
71
+
72
def calcContextMolWeight(smiles: List[str]):
    """Return ``Descriptors.ExactMolWt / 100`` for each SMILES string.

    Each string is parsed with ``Chem.MolFromSmiles``; the result is a NumPy
    array in input order.
    """
    return np.array(
        [Descriptors.ExactMolWt(Chem.MolFromSmiles(smi)) / 100 for smi in smiles]
    )
80
+
81
+
82
def plot_1D_condition(
    context_col,
    save_path,
    new_context,
    generated_smiles,
    temperature,
    context_dict,
    context_scaler=None,
):
    """Plot requested vs. recomputed property values for each context column.

    For every column in ``context_col`` the property is recomputed from
    ``generated_smiles``, MSE/MAD against the requested values are logged,
    and a scatter plot (with a histogram of ``context_dict[column]`` on a
    twin axis) plus a ``predictions.csv`` are written into a per-column
    directory below ``save_path``.

    Args:
        context_col: Sequence of context column names to evaluate.
        save_path: Base output directory.
        new_context: Mapping from column name to a tensor of requested values
            (``.cpu().detach().numpy()`` is called on each entry).
        generated_smiles: SMILES strings produced for those requested values.
        temperature: Sampling temperature; used in directory name and title.
        context_dict: Mapping from column name to the dataset values that are
            shown as a histogram.
        context_scaler: Must be ``None``; scaling is not implemented.

    Raises:
        ValueError: If a column name is not one of the supported properties.
        NotImplementedError: If ``context_scaler`` is given.
    """
    for con_col in context_col:
        # Derive each directory from the caller's base path. Reassigning
        # ``save_path`` here would nest every later column inside the
        # previous column's directory.
        out_dir = os.path.join(
            save_path, f"{con_col}_{'-'.join(context_col)}_temp{temperature}"
        )
        os.makedirs(out_dir, exist_ok=True)

        current_context = new_context[con_col].cpu().detach().numpy()
        if con_col == "mol_weight":
            predicted_context = calcContextMolWeight(generated_smiles)
        elif con_col == "logp":
            predicted_context = calcContextLogP(generated_smiles)
        elif con_col == "sascore":
            predicted_context = calcContextSAScore(generated_smiles)
        elif con_col == "energy":
            # TODO: Change to something better
            predicted_context = calcContextEnergy(generated_smiles)
        else:
            # Without this, an unknown column would silently reuse the values
            # computed for the previous column (or fail with a NameError).
            raise ValueError(f"Unsupported context column: {con_col!r}")

        if context_scaler is not None:
            raise NotImplementedError("Not implemented yet")

        mean_vals_pred = []
        labels = np.unique(current_context)
        mse_value = []
        mad_value = []
        for label in labels:
            # Select the molecules that were generated for this target value.
            mask = (current_context == label).reshape(-1)
            mean_val = np.mean(predicted_context[mask])
            mean_vals_pred.append(mean_val)
            mse_value.extend((predicted_context[mask] - label) ** 2)
            mad_value.extend(abs(predicted_context[mask] - label))

        mse = np.mean(mse_value)
        mad = np.mean(mad_value)
        logger.info(f"MSE {mse}")
        logger.info(f"MAD {mad}")
        logger.info(f"SD: {np.std(mad_value)}")

        current_context = current_context.reshape(-1)

        fig, ax1 = plt.subplots()

        # Scatter of every generated molecule plus the y=x reference line.
        ax1.scatter(
            current_context,
            predicted_context,
            label="Ground Truth vs Prediction",
            c="blue",
            alpha=0.5,
        )
        diagonal = np.arange(np.min(current_context), np.max(current_context) + 1)
        ax1.plot(diagonal, diagonal, label="y=x", c="black")
        ax1.scatter(labels, mean_vals_pred, label="Mean predicted values", c="red")
        ax1.set_xlabel("Ground Truth")
        ax1.set_ylabel("Prediction")

        # Histogram of the dataset values on a twin axis sharing the x-axis.
        ax2 = ax1.twinx()
        sns.histplot(
            context_dict[con_col],
            label="Dataset distribution",
            alpha=0.5,
            ax=ax2,
        )
        ax2.set_ylabel("Frequency")

        # Combine the legends of both axes into one.
        handles1, labels1 = ax1.get_legend_handles_labels()
        handles2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(handles1 + handles2, labels1 + labels2)

        plt.xlim((np.min(current_context), np.max(current_context) + 1))
        # Fall back to the raw column name when no display label is registered.
        display_name = COL_TO_DISPLAY_NAME.get(con_col, con_col)
        plt.title(f"{display_name} - temperature: {temperature} - mse: {round(mse, 4)}")

        out_df = pd.DataFrame(
            {
                "smiles": generated_smiles,
                f"{con_col}": predicted_context.tolist(),
                f"target_{con_col}": current_context.tolist(),
            }
        )
        out_df.to_csv(os.path.join(out_dir, "predictions.csv"), index=False)
        out_path = os.path.join(out_dir, "graph.png")
        print(f"Saved to {out_path}")
        plt.savefig(out_path)
        # Close rather than clear: plt.subplots() opens a new figure on every
        # iteration and cleared figures would otherwise stay alive.
        plt.close(fig)
194
+
195
+
196
def plot_2D_condition(
    context_col,
    save_path,
    new_context,
    generated_smiles,
    temperature,
    label: Union[str, None] = None,
):
    """Scatter-plot two recomputed properties against their requested targets.

    Properties are recomputed from ``generated_smiles`` for every column in
    ``context_col``; the first two columns become the x and y axes. For each
    combination of unique requested values the generated molecules and the
    target point are drawn in one random colour. MSE/MAD per axis are logged
    and written to ``metrics.txt``; the figure is saved as ``graph.png``.

    Args:
        context_col: Column names; at least two are required.
        save_path: Base output directory.
        new_context: Mapping from column name to a tensor of requested values.
        generated_smiles: SMILES strings produced for those requested values.
        temperature: Sampling temperature; used in the directory name.
        label: Optional extra sub-directory name.

    Returns:
        The directory the outputs were written to.

    Raises:
        ValueError: If a column name is not one of the supported properties.
    """
    save_path = os.path.join(
        save_path, f"multicond2_{'-'.join(context_col)}_temp={temperature}"
    )
    if label is not None:
        save_path = os.path.join(save_path, label)

    os.makedirs(save_path, exist_ok=True)

    predicted_context_dict = {}
    for con_col in context_col:
        if con_col == "mol_weight":
            predicted_context = calcContextMolWeight(generated_smiles)
        elif con_col == "logp":
            predicted_context = calcContextLogP(generated_smiles)
        elif con_col == "sascore":
            predicted_context = calcContextSAScore(generated_smiles)
        elif con_col == "energy":
            # TODO: Change to something better
            predicted_context = calcContextEnergy(generated_smiles)
        else:
            # Without this, an unknown column would silently reuse the values
            # computed for the previous column (or fail with a NameError).
            raise ValueError(f"Unsupported context column: {con_col!r}")

        predicted_context_dict[con_col] = np.array(predicted_context)

    real_values_prop1 = new_context[context_col[0]].cpu().numpy()
    real_values_prop2 = new_context[context_col[1]].cpu().numpy()

    mse_vals_x = []
    mad_vals_x = []
    mse_vals_y = []
    mad_vals_y = []
    fig = plt.figure()
    ax = plt.subplot(111)
    for v1 in np.unique(real_values_prop1):
        for v2 in np.unique(real_values_prop2):
            # Molecules generated for exactly this (v1, v2) target pair.
            mask = (real_values_prop1 == v1) & (real_values_prop2 == v2)
            indices = np.nonzero(mask)[0]
            color = np.random.rand(
                3,
            )

            x_pred = predicted_context_dict[context_col[0]][indices].ravel()
            y_pred = predicted_context_dict[context_col[1]][indices].ravel()
            mse_vals_x.extend((x_pred - v1) ** 2)
            mad_vals_x.extend(np.abs(x_pred - v1))

            mse_vals_y.extend((y_pred - v2) ** 2)
            mad_vals_y.extend(np.abs(y_pred - v2))

            ax.scatter(x_pred, y_pred, color=color, alpha=0.5)
            # Triangle marks the requested target; it also carries the legend entry.
            ax.scatter(v1, v2, color=color, label=f"({v1}, {v2})", marker="^", s=20.0)

    mse_x = np.mean(mse_vals_x)
    mad_x = np.mean(mad_vals_x)
    mse_y = np.mean(mse_vals_y)
    mad_y = np.mean(mad_vals_y)

    logger.info(f"MSE {context_col[0]}: {mse_x}")
    logger.info(f"MAD {context_col[0]}: {mad_x}")
    logger.info(f"MSE {context_col[1]}: {mse_y}")
    logger.info(f"MAD {context_col[1]}: {mad_y}")

    file_path = os.path.join(save_path, "metrics.txt")

    with open(file_path, "w") as f:
        f.write(f"MSE {context_col[0]}: {mse_x} \n")
        f.write(f"MAD {context_col[0]}: {mad_x} \n")
        f.write(f"MSE {context_col[1]}: {mse_y} \n")
        f.write(f"MAD {context_col[1]}: {mad_y} \n")

    # Fall back to the raw column name when no display label is registered.
    ax.set_xlabel(COL_TO_DISPLAY_NAME.get(context_col[0], context_col[0]))
    ax.set_ylabel(COL_TO_DISPLAY_NAME.get(context_col[1], context_col[1]))
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

    # Put a legend to the right of the current axis
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    ax.set_title("Multi Property Distribution of Generated Molecules")
    out_path = os.path.join(save_path, "graph.png")
    logger.info(f"Saved to {out_path}")
    plt.savefig(out_path)
    # Close rather than clear so the figure created above is released.
    plt.close(fig)
    return save_path
310
+
311
+
312
def plot_3D_condition(
    context_col, save_path, new_context, generated_smiles, temperature
):
    """3D scatter of three recomputed properties against their requested targets.

    Properties are recomputed from ``generated_smiles`` via
    ``calc_context_from_smiles``; the first three columns of ``context_col``
    become the x, y and z axes. For each combination of unique requested
    values the target point and the generated molecules are drawn in one
    random colour. MSE/MAD per axis are logged and written to
    ``metrics.txt``; the figure is saved as ``graph.png``.

    Args:
        context_col: Column names; at least three are required.
        save_path: Base output directory.
        new_context: Mapping from column name to a tensor of requested values.
        generated_smiles: SMILES strings produced for those requested values.
        temperature: Sampling temperature; used in the directory name.

    Returns:
        The directory the outputs were written to.
    """
    save_path = os.path.join(
        save_path, f"multicond3_{'-'.join(context_col)}_temp={temperature}"
    )
    os.makedirs(save_path, exist_ok=True)

    predicted_context_dict = {}
    for con_col in context_col:
        predicted_context = calc_context_from_smiles(generated_smiles, con_col)
        predicted_context_dict[con_col] = np.array(predicted_context)

    real_values_prop1 = new_context[context_col[0]].cpu().numpy()
    real_values_prop2 = new_context[context_col[1]].cpu().numpy()
    real_values_prop3 = new_context[context_col[2]].cpu().numpy()

    mse_vals_x = []
    mad_vals_x = []
    mse_vals_y = []
    mad_vals_y = []
    mse_vals_z = []
    mad_vals_z = []

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    for v1 in np.unique(real_values_prop1):
        for v2 in np.unique(real_values_prop2):
            for v3 in np.unique(real_values_prop3):
                # Molecules generated for exactly this (v1, v2, v3) target.
                mask = (
                    (real_values_prop1 == v1)
                    & (real_values_prop2 == v2)
                    & (real_values_prop3 == v3)
                )
                indices = np.nonzero(mask)[0]
                color = np.random.rand(
                    3,
                )

                x_pred = predicted_context_dict[context_col[0]][indices].ravel()
                y_pred = predicted_context_dict[context_col[1]][indices].ravel()
                z_pred = predicted_context_dict[context_col[2]][indices].ravel()

                mse_vals_x.extend((x_pred - v1) ** 2)
                mad_vals_x.extend(np.abs(x_pred - v1))

                mse_vals_y.extend((y_pred - v2) ** 2)
                mad_vals_y.extend(np.abs(y_pred - v2))

                mse_vals_z.extend((z_pred - v3) ** 2)
                mad_vals_z.extend(np.abs(z_pred - v3))

                # Target point first (it carries the legend entry), then the samples.
                ax.scatter(v1, v2, v3, color=color, label=f"({v1}, {v2}, {v3})", s=20.0)
                ax.scatter(
                    x_pred,
                    y_pred,
                    z_pred,
                    color=color,
                )

    mse_x = np.mean(mse_vals_x)
    mad_x = np.mean(mad_vals_x)
    mse_y = np.mean(mse_vals_y)
    mad_y = np.mean(mad_vals_y)
    mse_z = np.mean(mse_vals_z)
    mad_z = np.mean(mad_vals_z)

    logger.info(f"MSE {context_col[0]}: {mse_x}")
    logger.info(f"MAD {context_col[0]}: {mad_x}")
    logger.info(f"MSE {context_col[1]}: {mse_y}")
    logger.info(f"MAD {context_col[1]}: {mad_y}")
    logger.info(f"MSE {context_col[2]}: {mse_z}")
    logger.info(f"MAD {context_col[2]}: {mad_z}")

    file_path = os.path.join(save_path, "metrics.txt")

    with open(file_path, "w") as f:
        f.write(f"MSE {context_col[0]}: {mse_x} \n")
        f.write(f"MAD {context_col[0]}: {mad_x} \n")

        f.write(f"MSE {context_col[1]}: {mse_y} \n")
        f.write(f"MAD {context_col[1]}: {mad_y} \n")

        f.write(f"MSE {context_col[2]}: {mse_z} \n")
        f.write(f"MAD {context_col[2]}: {mad_z} \n")

    # Fall back to the raw column name when no display label is registered.
    ax.set_xlabel(COL_TO_DISPLAY_NAME.get(context_col[0], context_col[0]))
    ax.set_ylabel(COL_TO_DISPLAY_NAME.get(context_col[1], context_col[1]))
    ax.set_zlabel(COL_TO_DISPLAY_NAME.get(context_col[2], context_col[2]))

    # Legend is anchored in figure coordinates; the subplot is shrunk to make room.
    plt.legend(
        bbox_to_anchor=(1.035, 0.5),
        loc="center right",
        bbox_transform=plt.gcf().transFigure,
    )
    plt.subplots_adjust(left=0.05, bottom=0.1, right=0.775)

    plt.title("Multi Property Distribution of Generated Molecules")
    out_path = os.path.join(save_path, "graph.png")
    print(f"Saved to {out_path}")
    plt.savefig(out_path)
    # Close rather than clear so the figure created above is released.
    plt.close(fig)

    return save_path
426
+
427
+
428
def calc_context_from_smiles(generated_smiles, con_col):
    """Compute the property named ``con_col`` for every SMILES string.

    Args:
        generated_smiles: Iterable of SMILES strings.
        con_col: One of ``"mol_weight"``, ``"logp"``, ``"sascore"`` or
            ``"energy"``.

    Returns:
        The array returned by the matching ``calcContext*`` helper.

    Raises:
        ValueError: If ``con_col`` is not a supported property name
            (previously this surfaced as an UnboundLocalError).
    """
    if con_col == "mol_weight":
        return calcContextMolWeight(generated_smiles)
    if con_col == "logp":
        return calcContextLogP(generated_smiles)
    if con_col == "sascore":
        return calcContextSAScore(generated_smiles)
    if con_col == "energy":
        # TODO: Change to something better
        return calcContextEnergy(generated_smiles)
    raise ValueError(f"Unsupported context column: {con_col!r}")
439
+
440
+
441
def plot_unconditional(
    out_path: Union[str, None] = None,
    smiles: Union[List[str], None] = None,
    temperature: float = 0.8,
    cmp_context_dict: Union[Dict[str, np.array], None] = None,
    context_cols: Union[List[str], None] = None,
):
    """Plot density histograms of properties of unconditionally generated molecules.

    One PNG per column is written to ``<out_path>/unconditional``. When
    ``cmp_context_dict`` is given, its values for the column are drawn as a
    second histogram for comparison.

    Args:
        out_path: Base output directory. ``None`` means the working directory
            at call time.
        smiles: Generated SMILES strings. ``None`` means an empty list.
        temperature: Sampling temperature; used in title and file name.
        cmp_context_dict: Optional mapping from column name to reference values.
        context_cols: Columns to plot. ``None`` means
            ``["logp", "sascore", "mol_weight"]``.
    """
    # Resolve defaults at call time: avoids shared mutable default lists and
    # a working directory frozen at import time.
    if out_path is None:
        out_path = os.getcwd()
    if smiles is None:
        smiles = []
    if context_cols is None:
        context_cols = ["logp", "sascore", "mol_weight"]

    out_path = os.path.join(out_path, "unconditional")
    os.makedirs(out_path, exist_ok=True)

    for c in context_cols:
        plt.clf()

        context_cal = calc_context_from_smiles(smiles, c)

        if cmp_context_dict is not None:
            sns.histplot(
                cmp_context_dict[c],
                stat="density",
                label="Dataset Distribution",
                alpha=0.75,
                color="blue",
            )
        sns.histplot(
            context_cal,
            stat="density",
            label="Generated Molecules Distribution",
            alpha=0.5,
            color="orange",
        )

        if c == "logp":
            plt.xlim((-6, 8))
        else:
            plt.xlim((0, 10))

        # Fall back to the raw column name when no display label is registered.
        display_name = COL_TO_DISPLAY_NAME.get(c, c)
        plt.xlabel(display_name)
        plt.title(
            f"Unconditional Distribution {display_name} \nwith Temperature {temperature}"
        )
        plt.legend()

        out_file = os.path.join(out_path, f"unc_{c}_temp={temperature}.png")
        plt.savefig(out_file)
        logger.info(f"Saved Unconditional to {out_file}")
486
+
487
+
488
def novelty(gen, train):
    """Return the fraction of distinct generated entries that are not in ``train``.

    ``None`` entries in ``gen`` are ignored. When no non-``None`` entries
    remain, ``0.0`` is returned instead of dividing by zero.
    """
    gen_smiles_set = set(gen) - {None}
    if not gen_smiles_set:
        return 0.0
    train_set = set(train)
    return len(gen_smiles_set - train_set) / len(gen_smiles_set)
492
+
493
+
494
def unique_at(gen, k=1000):
    """Return the fraction of distinct entries among the first ``k`` of ``gen``.

    An empty selection yields ``0.0`` instead of dividing by zero.
    """
    head = gen[:k]
    if not head:
        return 0.0
    return len(set(head)) / len(head)
498
+
499
+
500
def check_metrics(generated_smiles: List[str], dataset_smiles: List[str]):
    """Compute novelty, uniqueness@1k/10k and validity of generated SMILES.

    ``None`` entries are counted as invalid and removed before novelty and
    uniqueness are computed. Validity is the fraction of non-``None``
    entries.

    Returns:
        Dict with the keys ``novelty``, ``unique_at_1k``, ``unique_at_10k``
        and ``validity``. When no valid entry exists (including an empty
        input) every metric is ``0.0`` rather than raising ZeroDivisionError.
    """
    len_before = len(generated_smiles)
    valid_smiles = [g for g in generated_smiles if g is not None]

    # Nothing valid to score: every ratio below would divide by zero.
    if not valid_smiles:
        return dict(
            novelty=0.0,
            unique_at_1k=0.0,
            unique_at_10k=0.0,
            validity=0.0,
        )

    return dict(
        novelty=novelty(valid_smiles, dataset_smiles),
        unique_at_1k=unique_at(valid_smiles, k=1000),
        unique_at_10k=unique_at(valid_smiles, k=10000),
        validity=len(valid_smiles) / float(len_before),
    )
preprocess_dataset.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import pickle
5
+ import random
6
+ from functools import partial
7
+
8
+ import pandas as pd
9
+ import numpy as np
10
+ import requests
11
+ import torch
12
+ import torch.distributed as dist
13
+ from tqdm import tqdm
14
+ import multiprocessing
15
+ from multiprocessing import Pool
16
+ from fragment_creator import BaseFragmentCreator, BricksFragmentCreator, Fragment
17
+ from tokenizer import SmilesTokenizer
18
+ from torch.utils.data.distributed import DistributedSampler
19
+ from rdkit import Chem
20
+ from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
21
+ from tqdm.contrib.concurrent import process_map, thread_map
22
+ from typing import List
23
+ import swifter
24
+
25
+ DATA_CACHE_DIR = "data"
26
+
27
+
28
def _tokenize_smiles(
    smi: List[str],
    tokenizer: SmilesTokenizer = None,
    max_smiles_len=256,
    log_output=True,
):
    """Encode ``smi`` with ``tokenizer`` and drop it when it is too long.

    Returns the result of ``tokenizer.encode(smi)``, or ``None`` when that
    result has more than ``max_smiles_len`` entries (optionally printing a
    message). A tokenizer must be supplied despite the ``None`` default.
    """
    tokens = tokenizer.encode(smi)
    if len(tokens) <= max_smiles_len:
        return tokens

    if log_output:
        print(f"Removing to long {smi} with smiles len of {len(tokens)} ")
    return None
42
+
43
+ # except Exception as e:
44
+ # print(e)
45
+ # return None
46
+
47
+
48
def _tokenize_scaffolds(smi: str, tokenizer=None, max_smiles_len=256, log_output=True):
    """Encode the Murcko scaffold of ``smi`` and drop it when it is too long.

    The scaffold SMILES is encoded with ``tokenizer.encode`` and the first
    and last token are stripped. Returns the remaining tokens, or ``None``
    when more than ``max_smiles_len`` remain (optionally printing a message).
    A tokenizer must be supplied despite the ``None`` default.
    """
    smi = MurckoScaffoldSmiles(smi)
    # Strip the first and last token ([SEP] and [CLS] per the original comment).
    tokens = tokenizer.encode(smi)[1:-1]
    if len(tokens) <= max_smiles_len:
        return tokens

    if log_output:
        print(f"Removing to long {smi} with smiles len of {len(tokens)} ")
    return None
60
+
61
+ # except Exception as e:
62
+ # print(e)
63
+ # return None
64
+
65
+
66
def pad_batch(src, pad_idx):
    """Right-pad the sequences in ``src`` with ``pad_idx`` and transpose.

    Returns a float NumPy array of shape ``(longest_sequence, len(src))``:
    column ``i`` holds sequence ``i`` followed by padding.
    """
    longest = max(len(seq) for seq in src)
    padded = np.ones([len(src), longest]) * pad_idx

    for row, seq in enumerate(src):
        padded[row, : len(seq)] = seq

    # Callers expect the sequence dimension first.
    return padded.T
78
+
79
+
80
def pretokenize(
    data_file=os.path.join(
        DATA_CACHE_DIR, "FULL_combined_zinc_pubchemqc_qm9_pc9_reddb_chembl.parquet"
    ),
    tokenizer=SmilesTokenizer(),
    limit=None,
    context=["logp", "sascore", "mol_weight"],
    out_name: str = "processed_dataset",
    remove_nan_context_rows: bool = False,
):
    """Tokenize a parquet file of SMILES and pickle the result into ``.cache``.

    Steps: read ``data_file``, sample/shuffle the rows, null out SMILES that
    contain ``"."``, optionally map the ``origin`` column to integer codes
    (saved as JSON in ``DATA_CACHE_DIR``), filter invalid rows, tokenize,
    shuffle again and dump tokens, SMILES, scaffold tokens and context
    columns to ``<this dir>/.cache/<out_name>_<limit>.pkl``.

    Args:
        data_file: Path of the parquet file; needs a ``smiles`` column.
        tokenizer: Object whose ``encode`` is used for tokenization.
        limit: Number of rows to sample, or ``None`` for all rows.
        context: Context column names to export, or ``None`` for none.
        out_name: Prefix of the output pickle file name.
        remove_nan_context_rows: Drop rows with NaN in any context column.
    """
    df = pd.read_parquet(data_file)

    # sample() doubles as the shuffle; with a limit it also subsamples.
    df = df.sample(frac=1.0) if limit is None else df.sample(n=limit)

    cpu_count = multiprocessing.cpu_count()
    print(f"Running on {cpu_count} CPUs ")

    tqdm.pandas()

    # SMILES containing "." are replaced by None and filtered out below.
    df["scaffolds"] = df["smiles"].progress_map(lambda s: None if "." in s else s)
    df["smiles"] = df["scaffolds"].copy()
    orig_len = len(df)

    if context is not None:
        if df.get("origin") is not None:
            # Replace each origin label by its index and persist the mapping.
            origin_dics = {}
            for code, origin in enumerate(df["origin"].unique()):
                df.loc[df["origin"] == origin, "origin"] = code
                origin_dics[origin] = code
            df["origin"] = df["origin"].astype(float)

            origins_file = os.path.join(
                DATA_CACHE_DIR, os.path.basename(data_file) + "_origins.json"
            )
            with open(origins_file, "w") as f:
                json.dump(origin_dics, f)

        smiles_ok = ~df["smiles"].isna()
        if remove_nan_context_rows:
            context_ok = (~df[context].isna()).all(axis=1)
        else:
            context_ok = np.ones(len(df["smiles"]), dtype=bool)
        mask = smiles_ok & context_ok & ~df["scaffolds"].isna()
    else:
        mask = ~df["smiles"].isna()

    error_count = np.count_nonzero(~mask)
    df = df[mask]

    encode = partial(_tokenize_smiles, tokenizer=tokenizer, log_output=False)
    df["tokens"] = df["smiles"].swifter.apply(encode)
    # The "scaffolds" column is overwritten with a copy of the full tokens.
    df["scaffolds"] = df["tokens"].copy()

    mask = ~df["tokens"].isna() & ~df["scaffolds"].isna()
    df = df[mask]
    error_count += np.count_nonzero(~mask)

    # Shuffle the data
    df = df.sample(frac=1).reset_index(drop=True)

    if context is not None:
        context_list = df[context].to_numpy()
        context_dict = {k: context_list[:, i] for i, k in enumerate(context)}
    else:
        context_dict = {}

    print(f"Error count: {error_count} / {orig_len} = {error_count/orig_len}")

    cache_path = os.path.join(os.path.dirname(__file__), ".cache")
    os.makedirs(cache_path, exist_ok=True)
    out_path = os.path.join(cache_path, f"{out_name}_{limit}.pkl")

    payload = {
        "tokens": df["tokens"].tolist(),
        "smiles": df["smiles"].tolist(),
        "scaf": df["scaffolds"].tolist(),
        **context_dict,
    }
    with open(out_path, "wb") as f:
        pickle.dump(payload, f)

    print(f"Saved to {out_path}")
    print("Done.")
196
+
197
+
198
class PretokDataset(torch.utils.data.Dataset):
    """Dataset over a pretokenized pickle in ``.cache`` next to this file.

    The first 90% of the entries form the ``"train"`` split and the
    remainder the ``"val"`` split. Each item is a dict of tensors for a
    single molecule.
    """

    def __init__(self, split, pad_token_id, dataset="processed_dataset.pkl"):
        """Load ``dataset`` from the cache directory and keep only ``split``.

        Raises:
            RuntimeError: If ``split`` is neither ``"train"`` nor ``"val"``.
        """
        super().__init__()
        self.split = split
        self.dataset = dataset
        self.pad_token_id = pad_token_id

        cache_path = os.path.join(os.path.dirname(__file__), ".cache")
        # NOTE(review): pickle.load executes arbitrary code - trusted files only.
        with open(os.path.join(cache_path, self.dataset), "rb") as f:
            self.data_dict = pickle.load(f)

        # split out 10% of the data for validation
        split_ix = int(len(self.data_dict["tokens"]) * 0.9)
        if self.split == "train":
            window = slice(None, split_ix)
        elif self.split == "val":
            window = slice(split_ix, None)
        else:
            raise RuntimeError(f"Could not find split for: self.split={self.split}")

        self.data_dict = {
            key: self.data_dict[key][window] for key in self.data_dict
        }

    def __len__(self):
        """Return the number of token sequences in the selected split."""
        return len(self.data_dict["tokens"])

    def __getitem__(self, idx):
        """Return the item at ``idx`` as a dict.

        Keys: ``"seq"`` and ``"scaf"`` (int64 tensors produced by
        ``pad_batch``), ``"smiles"`` (one-element slice), plus one float32
        tensor per remaining key of the stored dict.
        """
        m = self.data_dict
        window = slice(idx, idx + 1)

        padded_tokens = pad_batch(m["tokens"][window], self.pad_token_id)
        chunk = torch.from_numpy(padded_tokens.astype(np.int64))

        padded_scaffolds = torch.from_numpy(
            pad_batch(m["scaf"][window], self.pad_token_id).astype(np.int64)
        )

        item = {
            "seq": chunk,
            "scaf": padded_scaffolds,
            "smiles": m["smiles"][window],
        }
        # Everything that is not a sequence/SMILES column is a context value.
        for key in m:
            if key not in ("scaf", "tokens", "smiles"):
                item[key] = torch.tensor(m[key][window], dtype=torch.float32)

        return item
248
+
249
+
250
def padding_collate_fn(
    data, tokenizer: SmilesTokenizer, fragment_creator: BaseFragmentCreator
):
    """Collate dataset items into padded ``src``/``fragment`` tensors.

    Args:
        data: List of item dicts with ``"seq"``, ``"scaf"``, ``"smiles"`` and
            any number of context entries.
        tokenizer: Supplies ``pad_token_id`` and ``encode``.
        fragment_creator: When ``None`` the stored ``"scaf"`` tokens are used
            as fragment; otherwise a fragment is created per item.

    Returns:
        Dict with ``"src"`` and ``"fragment"`` (long tensors, padded and
        transposed) and ``"context"`` (per-key concatenation over the batch).
    """
    pad_idx = tokenizer.pad_token_id

    sequences = [sample["seq"] for sample in data]
    longest = max(len(seq) for seq in sequences)
    padded_src = np.ones([len(sequences), longest]) * pad_idx
    for row, seq in enumerate(sequences):
        padded_src[row][0 : len(seq)] = seq.ravel()

    if fragment_creator is None:
        smiles_context = [sample["scaf"] for sample in data]
    else:
        smiles_context = []
        for sample in data:
            tokens = sample["seq"]
            frag = fragment_creator.create_fragment(
                Fragment(smiles=sample["smiles"][0], tokens=tokens)
            )
            if frag.tokens is not None:
                smiles_context.append(frag.tokens)
                continue

            # No ready-made tokens: encode the fragment SMILES and strip the
            # start and end token with [1:-1].
            smiles_context.append(
                torch.tensor(
                    tokenizer.encode(frag.smiles)[1:-1],
                    dtype=torch.long,
                    device=tokens.device,
                )
            )

    longest_ctx = max(len(ctx) for ctx in smiles_context)
    padded_smiles_context = np.ones([len(smiles_context), longest_ctx]) * pad_idx
    for row, ctx in enumerate(smiles_context):
        padded_smiles_context[row][0 : len(ctx)] = ctx.ravel()

    context_keys = [
        k for k in data[0].keys() if k not in ("seq", "scaf", "smiles")
    ]
    context_out_dict = {
        k: torch.concat([sample[k] for sample in data], dim=0)
        for k in context_keys
    }

    return {
        "src": torch.tensor(padded_src.T, dtype=torch.long),  # for (seq_len, batch_size)
        "fragment": torch.tensor(padded_smiles_context.T, dtype=torch.long),
        "context": context_out_dict,
    }
308
+
309
+
310
class SmilesTask:
    """Namespace for the batch iterator used with the pretokenized dataset."""

    @staticmethod
    def iter_batches(
        split,
        batch_size,
        device,
        context_keys: List[str],
        num_workers=0,
        dataset="processed_dataset.pkl",
        fragment_creator: BaseFragmentCreator = BricksFragmentCreator(),
    ):
        """Yield batches from ``PretokDataset`` moved to ``device``.

        Each yielded dict contains ``"src"`` (all but the last token),
        ``"tgt"`` (all but the first token), ``"fragment"`` and ``"context"``
        restricted to ``context_keys``. A ``DistributedSampler`` is used when
        the ``RANK`` environment variable is set.
        """
        tokenizer = SmilesTokenizer()
        ds = PretokDataset(split, tokenizer.pad_token_id, dataset=dataset)

        is_ddp = int(os.environ.get("RANK", -1)) != -1
        sampler = DistributedSampler(ds) if is_ddp else None
        collate = partial(
            padding_collate_fn,
            tokenizer=tokenizer,
            fragment_creator=fragment_creator,
        )
        dl = torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=num_workers,
            shuffle=False,
            sampler=sampler,
            collate_fn=collate,
        )

        for data in dl:
            tokens = data["src"].to(device, non_blocking=True)

            # Shift by one position: inputs drop the last token, targets the first.
            data["src"] = tokens[:-1, :].T  # batch_size, seq_len
            data["tgt"] = tokens[1:, :].T  # batch_size, seq_len

            data["fragment"] = (
                data["fragment"].to(device, non_blocking=True).T
            )  # batch_size, seq_len

            # Iterate over a snapshot of the keys because entries are deleted.
            for key in list(data["context"].keys()):
                if key in context_keys:
                    data["context"][key] = data["context"][key].to(
                        device, non_blocking=True
                    )
                else:
                    del data["context"][key]

            yield data
356
+
357
+
358
if __name__ == "__main__":
    # Preprocess the OrganiX13 dataset with the three default context columns.
    dataset_path = os.path.join(DATA_CACHE_DIR, "OrganiX13.parquet")

    pretokenize(
        data_file=dataset_path,
        # How many molecules to process; None processes all of them.
        limit=None,
        context=["logp", "sascore", "mol_weight"],
        out_name="processed_dataset",
        remove_nan_context_rows=False,
    )
370
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ numpy==1.23.5
2
+ pytest==7.4.0
3
+ Requests==2.31.0
4
+ sentencepiece==0.1.99
5
+ tiktoken==0.3.3
6
+ torch==2.0.1
7
+ tqdm==4.64.1
8
+ wandb==0.15.5
sample.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import nullcontext
3
+ import sys
4
+ import time
5
+ import pandas as pd
6
+ import torch
7
+ from tqdm.auto import tqdm
8
+
9
+ # from tqdm.notebook import tqdm
10
+ from model import Transformer
11
+ from plot_utils import (
12
+ check_metrics,
13
+ plot_1D_condition,
14
+ plot_2D_condition,
15
+ plot_3D_condition,
16
+ plot_unconditional,
17
+ )
18
+ from tokenizer import SmilesTokenizer
19
+ import numpy as np
20
+ from typing import Dict, List, Tuple, Union
21
+ import re
22
+
23
+ from rdkit import Chem
24
+ from rdkit import DataStructs
25
+ from rdkit.Chem.Fingerprints import FingerprintMols
26
+
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
class Sampler:
    """Sample SMILES strings from a trained ``Transformer`` checkpoint.

    The checkpoint is loaded once in the constructor. Afterwards ``generate``
    (plain sampling) and ``generate_with_evaluation`` (sampling plus plots and
    fragment statistics) can be called repeatedly.
    """

    def __init__(
        self,
        load_path: str,
        device: str = "cpu",
        seed: int = 1337,
        dtype: str = "float16",
        compile: bool = True,
        quantize: bool = False,
    ) -> None:
        """Store the configuration and load model and tokenizer.

        Params:
        - load_path : Path of the checkpoint handed to ``Transformer.load``.
        - device : Torch device string, e.g. "cpu" or "cuda".
        - seed : Seed applied to the NumPy and torch random number generators.
        - dtype : One of "float32", "bfloat16" or "float16".
        - compile : Whether the model is wrapped with ``torch.compile``.
        - quantize : Not supported; ``True`` raises ``NotImplementedError``.
        """
        self.load_path = load_path
        self.device = device
        self.dtype = dtype
        self.compile = compile
        self.quantize = quantize
        self.seed = seed
        self._init_model()

    def _init_model(self):
        """Seed the RNGs, set up autocast and load the model and the tokenizer."""
        if self.quantize:
            # Dynamic quantization was never finished; fail before spending
            # time on loading the checkpoint.
            raise NotImplementedError("Not properly implemented for CPU / GPU")

        np.random.seed(self.seed)
        # torch.manual_seed also covers the CPU generator, which the sampling
        # uses when running on CPU; the CUDA-only call alone left it unseeded.
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
        # for later use in torch.autocast
        self.device_type = "cuda" if "cuda" in self.device else "cpu"
        self.ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[self.dtype]

        self.ctx = self._autocast()
        # init from a model saved in a specific directory
        self.model = Transformer.load(self.load_path, device=self.device)
        self.model.eval()

        if self.compile:
            logger.info("Compiling the model...")
            self.model = torch.compile(self.model)  # requires PyTorch 2.0 (optional)

        self.model = self.model.to(self.device)
        # load the tokenizer
        self.tokenizer = SmilesTokenizer()

    def _encode_fragment(self, context_smi: str, num_examples: int) -> torch.Tensor:
        """Tokenize ``context_smi`` and repeat it ``num_examples`` times as a long tensor."""
        logger.debug(
            f"context_smiles: {context_smi}",
        )
        # NOTE: Remove beginning [CLS] and end token [SEP]
        token_ids = self.tokenizer.encode(context_smi)[1:-1]
        return torch.tensor(
            [token_ids] * num_examples,
            dtype=torch.long,
            device=self.device,
        )

    def get_context(
        self,
        context_col: List[str],
        context_smi: str,
        num_examples: int = 50,
    ):
        """
        Returns a dictionary in the form of
        {
            "fragment": torch.tensor,
            "context": {
                "logp": torch.tensor,
                "sascore": torch.tensor,
                "mol_weight": torch.tensor
            }
        }

        When context_smi is set to a string, the "fragment" field is populated with
        its token ids (repeated num_examples times); otherwise it stays None.
        Every supported property listed in context_col ("logp", "energy", "sascore",
        "mol_weight") becomes a key whose value is a tensor of num_examples values
        drawn from a small fixed set / range for that property.
        When context_col is None the "context" dictionary stays empty.
        """
        output_dict = {"context": {}, "fragment": None}

        if context_smi is not None:
            output_dict["fragment"] = self._encode_fragment(context_smi, num_examples)

        if context_col is None:
            return output_dict

        if "logp" in context_col:
            output_dict["context"]["logp"] = torch.tensor(
                np.random.choice([-2, 0, 2], (num_examples,)),
                device=self.device,
                dtype=self.ptdtype,
            )

        if "energy" in context_col:
            energy = 0.1 * torch.randint(
                -15, 15, (num_examples,), device=self.device, dtype=torch.float
            )
            energy, _ = torch.sort(energy, 0)
            output_dict["context"]["energy"] = energy

        if "sascore" in context_col:
            output_dict["context"]["sascore"] = torch.tensor(
                np.random.choice([2, 3, 4], (num_examples,)),
                device=self.device,
                dtype=torch.float,
            )

        if "mol_weight" in context_col:
            output_dict["context"]["mol_weight"] = torch.tensor(
                np.random.choice([2.0, 3.0, 4.0], (num_examples,)),
                device=self.device,
                dtype=torch.float,
            )

        return output_dict

    def _autocast(self):
        """Return the autocast context manager for the configured device and dtype.

        On CPU this is a no-op context. On CUDA, bfloat16 is only used when the
        device supports it; any other combination than float16 falls back to float32.
        """
        if "cuda" not in self.device:
            return nullcontext()
        if self.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
        if self.dtype == "float16":
            return torch.cuda.amp.autocast(dtype=torch.float16)
        return torch.cuda.amp.autocast(dtype=torch.float32)

    @staticmethod
    def _relax_bond_orders(smarts: str) -> str:
        """Replace every bond symbol outside of atom brackets by the any-bond ``~``.

        The hyphen is placed first in the character class so that it is a literal
        single-bond symbol; written as ``[:-=#]`` it formed the range ``:``..``=``
        and single bonds were never relaxed.
        """
        pattern = r"(?<!\[)([-:=#])(?!\])(?![^\[]*?\])"
        return re.sub(pattern, "~", smarts)

    @torch.no_grad()
    def generate(
        self,
        context_cols: Union[List[str], None, Dict[str, torch.Tensor]] = None,
        context_smi: Union[str, None] = None,
        start_smiles: Union[str, None] = None,
        num_samples: int = 50,
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_k: Union[int, None] = None,
        return_context: bool = False,
        total_gen_steps: int = 1,
        use_kv_cache: bool = False,
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        """
        Generates a list of SMILES. With the default options it would generate them unconditionally.
        Only samples that RDKit can parse are returned.
        Params:
        - context_cols : When a list, the context is randomly sampled by the get_context method; when a dictionary,
            the context values are taken from the dictionary instead (sliced per generation step).
        - context_smi : Further conditioning by the usage of a molecular fragment.
        - start_smiles : Can be used to start the SMILES with a specific string, the model then generates the next tokens including that start sequence.
        - num_samples : Controls how many SMILES in total will be generated by the model. Each step draws
            num_samples // total_gen_steps samples, so a remainder is not generated.
        - max_new_tokens : Controls the maximum length of each SMILES (in tokens) that is generated.
        - temperature : Controls the randomness of the model. A temperature = 1.0 means it is the trained distribution. A temperature < 1 is more deterministic and temperature > 1 is more random.
        - top_k : Clamps the probability distribution to the top k tokens. From these the next token is then sampled.
        - return_context : Whether the context that was given to the model (restricted to the valid samples) should be returned.
        - total_gen_steps : In how many sub steps the generation should be split up. Useful when generating 10k+ SMILES and wanting to chunk these into for example 10 * 1k generations with total_gen_steps = 10.
        - use_kv_cache : Runs the generation using kv-caching. It is faster, but takes more memory.
        """

        with self.ctx:
            gens_per_step = num_samples // total_gen_steps
            logger.debug(f"Gens per Step: {gens_per_step}")

            # A user supplied context stays the source for *every* step; do not
            # rebind context_cols inside the loop, otherwise later steps would
            # silently fall back to randomly sampled context.
            fixed_context = context_cols if isinstance(context_cols, dict) else None

            context = None
            out_smiles = []
            with tqdm(total=total_gen_steps, desc="Batch") as pbar:
                for step in range(total_gen_steps):
                    if fixed_context is not None:
                        # TODO: Test if same length
                        start = step * gens_per_step
                        stop = start + gens_per_step
                        context_dict = {
                            "context": {
                                name: values[start:stop]
                                for name, values in fixed_context.items()
                            },
                            "fragment": None,
                        }
                        if context_smi is not None:
                            context_dict["fragment"] = self._encode_fragment(
                                context_smi, gens_per_step
                            )
                    else:
                        context_dict = self.get_context(
                            context_cols, context_smi, num_examples=gens_per_step
                        )

                    y = self.model.generate(
                        self.tokenizer,
                        context=context_dict["context"],
                        fragments=context_dict["fragment"],
                        start_smiles=start_smiles,
                        num_gen=gens_per_step,
                        temperature=temperature,
                        top_k=top_k,
                        max_length=max_new_tokens,
                        device=self.device,
                        cache_kv=use_kv_cache,
                    )

                    # Keep only chemically valid samples together with the
                    # context value each of them was generated from.
                    kept = {name: [] for name in context_dict["context"]}
                    for sample_idx, sample in enumerate(y):
                        if Chem.MolFromSmiles(sample) is None:
                            continue
                        out_smiles.append(sample)
                        for name, rows in kept.items():
                            rows.append(
                                context_dict["context"][name][sample_idx].unsqueeze(-1)
                            )

                    # torch.concat rejects an empty list, so a step without any
                    # valid sample contributes an empty slice instead.
                    new_context = {
                        name: (
                            torch.concat(rows, dim=0)
                            if rows
                            else context_dict["context"][name][:0]
                        )
                        for name, rows in kept.items()
                    }

                    if context is None:
                        context = new_context
                    else:
                        for name in context:
                            context[name] = torch.concat(
                                [context[name], new_context[name]], dim=0
                            )

                    pbar.update(1)

            # Relate the valid samples to what was actually drawn, which is
            # smaller than num_samples when it is not divisible by the steps.
            num_attempted = gens_per_step * total_gen_steps
            valid_percent = (
                len(out_smiles) / num_attempted * 100 if num_attempted else 0.0
            )
            logger.info(f"Number valid generated: {valid_percent} %")
            logger.info("---------------")

            if return_context:
                return (out_smiles, context)

            return out_smiles

    def _log_fragment_metrics(self, context_smi: str, out_smiles: List[str]) -> None:
        """Log fingerprint similarity and substructure hit rate w.r.t. the fragment."""
        # replace [14*] etc
        stripped_smi = re.sub(r"\[\d+\*\]", "", context_smi)

        context_mol = Chem.MolFromSmiles(stripped_smi)
        if context_mol is None:
            logger.warning(
                f"Could not parse context SMILES {stripped_smi}, skipping fragment statistics"
            )
            return

        context_smarts = self._relax_bond_orders(Chem.MolToSmarts(context_mol))
        logger.info(f"context_smarts {context_smarts}")

        if not out_smiles:
            logger.warning("No valid molecules generated, skipping fragment statistics")
            return

        # The query is identical for every molecule, so parse it only once.
        query = Chem.MolFromSmarts(context_smarts)
        context_fingerprint = FingerprintMols.FingerprintMol(context_mol)

        all_sim = []
        all_sub = []
        for smi in out_smiles:
            out_mol = Chem.MolFromSmiles(smi)
            out_fing = FingerprintMols.FingerprintMol(out_mol)
            all_sim.append(
                DataStructs.TanimotoSimilarity(context_fingerprint, out_fing)
            )
            all_sub.append(out_mol.HasSubstructMatch(query))

        logger.info(f"Mean sim {np.mean(all_sim)}")
        logger.info(
            f"Has Sub: {np.count_nonzero(all_sub)} or {round(np.count_nonzero(all_sub) / len(all_sub) * 100, 4)} %"
        )

    @torch.no_grad()
    def generate_with_evaluation(
        self,
        context_cols: Union[List[str], None] = None,
        context_smi: Union[str, None] = None,
        start_smiles: Union[str, None] = None,
        num_samples: int = 50,
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_k: Union[int, None] = None,
        cmp_context_dict: Union[Dict[str, torch.Tensor], None] = None,
        total_gen_steps: int = 1,
        use_kv_cache: bool = False,
    ):
        """
        Runs generate and evaluates the result.

        Plots are written to a "plots" directory next to the checkpoint: a 1D/2D/3D
        condition plot depending on the number of context columns, or the unconditional
        plot when context_cols is None. When context_smi is given, the mean fingerprint
        similarity to the fragment and the share of samples containing it are logged.
        cmp_context_dict is forwarded to the 1D and the unconditional plot; all other
        parameters are forwarded to generate.

        Returns the tuple (valid SMILES, context of the valid SMILES).
        Raises NotImplementedError when context_cols has a length other than 1, 2 or 3.
        """
        out_smiles, new_context = self.generate(
            context_cols=context_cols,
            context_smi=context_smi,
            start_smiles=start_smiles,
            num_samples=num_samples,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            return_context=True,
            total_gen_steps=total_gen_steps,
            use_kv_cache=use_kv_cache,
        )

        plot_dir = os.path.join(os.path.dirname(self.load_path), "plots")

        if context_cols is None:
            # Unconditional Case
            plot_unconditional(
                out_path=plot_dir,
                smiles=out_smiles,
                temperature=temperature,
                cmp_context_dict=cmp_context_dict,
            )
        elif len(context_cols) == 1:
            plot_1D_condition(
                context_cols,
                plot_dir,
                new_context,
                out_smiles,
                temperature,
                cmp_context_dict,
                context_scaler=None,
            )
        elif len(context_cols) == 2:
            plot_2D_condition(
                context_cols,
                plot_dir,
                new_context,
                out_smiles,
                temperature,
                label=context_smi,
            )
        elif len(context_cols) == 3:
            plot_3D_condition(
                context_cols,
                plot_dir,
                new_context,
                out_smiles,
                temperature,
            )
        else:
            raise NotImplementedError(
                "Currently not implemented for len(context_col) > 3"
            )

        if context_smi is not None:
            self._log_fragment_metrics(context_smi, out_smiles)

        return out_smiles, new_context
431
+
432
+
433
if __name__ == "__main__":
    # Command line entry point: parse the sampling options, generate molecules,
    # optionally compare them against a reference dataset and print the SMILES.
    import argparse
    import rdkit.rdBase as rkrb
    import rdkit.RDLogger as rkl

    def _str2bool(value):
        """Parse a command line boolean such as "True", "false", "1" or "no".

        argparse's ``type=bool`` treats every non-empty string (including
        "False") as True, so the flags need an explicit converter.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    # Silence RDKit: invalid SMILES are expected during sampling.
    rdkit_logger = rkl.logger()
    rdkit_logger.setLevel(rkl.ERROR)
    rkrb.DisableLog("rdApp.error")

    torch.set_num_threads(8)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(
        description="Generate SMILES strings using a trained model."
    )
    parser.add_argument(
        "--context_cols",
        type=str,
        nargs="+",
        default=None,
        help="The given conditions are sampled from a fixed interval and given to the modeĺ.",
    )
    parser.add_argument(
        "--context_smi",
        type=str,
        default=None,
        help="This SMILES is given as context to the model and should be integrated in the generated molecules.",
    )
    parser.add_argument(
        "--start_smiles",
        type=str,
        default=None,
        help="This SMILES is placed at the front of each sample, from which on the generation continues.",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default=os.path.join(os.path.dirname(__file__), "out", "llama2-M-Full-RSS.pt"),
        help="Which model should be used in the generation",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=50,
        help="Controls how many samples should be generated",
    )
    parser.add_argument(
        "--num_samples_per_step",
        type=int,
        default=1000,
        help="Works in conjunction with num_samples, by splitting the total into num_samples_per_step jobs. When num_samples > num_samples_per_step then it is split up into multiple seperate generation steps.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=256,
        help="Sets how many tokens should be generated from the model. We only trained with a max size of 256, but it is possible to generate longer molecules. However, these might be worse in quality.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sets the randomness of the generation - A temperature of 0 would be deterministic and a temperature of > 1 is more random.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=None,
        help="The top_k of the sampling. Per default it is None, but can be set to an integer to have a more focused generation.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="Random number generator seed, to make sampling consistent.",
    )
    parser.add_argument(
        "--cmp_dataset_path",
        type=str,
        default=None,
        help="A dataset in parquet or csv format to be used in the sample plots and to compute the metrics such as the novelty.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Change the device the model and generation is run on",
    )
    # The default is resolved after parsing so that it follows the device that
    # was actually requested instead of the auto-detected one.
    parser.add_argument(
        "--dtype",
        type=str,
        default=None,
        help="Change the datatype of the computation. Per default it is float32 on CPU and float16 on GPU",
    )
    parser.add_argument(
        "--compile",
        type=_str2bool,
        default=True,
        help="Use torch.compile to compile the model. Only works on torch>=2.0, but should make the inference faster.",
    )
    parser.add_argument(
        "--quantize",
        type=_str2bool,
        default=False,
        help="(CURRENTLY NOT WORKING) Enable quantization to in8.",
    )
    parser.add_argument(
        "--kv_caching",
        action="store_true",
        default=False,
        help="Makes the attention mechanism linear, because the old keys and values are cached. The drawback is higher memory consumption.",
    )
    args = parser.parse_args()

    if args.dtype is None:
        on_gpu = "cuda" in args.device and torch.cuda.is_available()
        args.dtype = "float16" if on_gpu else "float32"

    logger.info("Sampling with the following parameters:")
    logger.info(f"Checkpoint: {args.ckpt_path}")
    logger.info(f"Context columns: {args.context_cols}")
    logger.info(f"Context SMILES: {args.context_smi}")
    logger.info(f"Start SMILES: {args.start_smiles}")
    logger.info(f"Number of samples: {args.num_samples}")
    logger.info(f"Max new tokens: {args.max_new_tokens}")
    logger.info(f"Temperature: {args.temperature}")
    logger.info(f"Top k: {args.top_k}")
    logger.info(f"Seed: {args.seed}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Data type: {args.dtype}")
    logger.info(f"Compile: {args.compile}")
    logger.info(f"Comparison dataset path: {args.cmp_dataset_path}")
    logger.info(f"Quantize: {args.quantize}")
    logger.info(f"Key Value Caching Enabled: {args.kv_caching}")

    sampler = Sampler(
        load_path=os.path.join(os.path.dirname(__file__), args.ckpt_path),
        device=args.device,
        seed=args.seed,
        dtype=args.dtype,
        compile=args.compile,
        quantize=args.quantize,
    )

    comp_context_dict = None
    comp_smiles = None
    if args.cmp_dataset_path is not None:
        if args.cmp_dataset_path.lower().endswith(".csv"):
            df_comp = pd.read_csv(args.cmp_dataset_path)
        else:
            df_comp = pd.read_parquet(args.cmp_dataset_path)
        # Cap the comparison set at 2.5M rows; sampling more rows than the
        # frame holds would raise a ValueError.
        df_comp = df_comp.sample(n=min(len(df_comp), 2_500_000))
        comp_context_dict = {
            c: df_comp[c].to_numpy() for c in ["logp", "sascore", "mol_weight"]
        }
        comp_smiles = df_comp["smiles"]

    measure_time = True
    start_time = time.time()
    smiles, context = sampler.generate_with_evaluation(
        context_cols=args.context_cols,
        context_smi=args.context_smi,
        start_smiles=args.start_smiles,
        num_samples=args.num_samples,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        cmp_context_dict=comp_context_dict,
        total_gen_steps=int(np.ceil(args.num_samples / args.num_samples_per_step)),
        use_kv_cache=args.kv_caching,
    )
    end_time = time.time()
    if measure_time:
        logger.info(f"Generation took: {end_time - start_time} sec")
    if comp_smiles is not None:
        res_metrics = check_metrics(smiles, comp_smiles)
        logger.info(f"Metrics: {res_metrics}")
    logger.info("Generated Molecules:")
    for s in smiles:
        print(s)
tokenizer.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Requirements - transformers, tokenizers
2
+ # Right now, the Smiles Tokenizer uses an existing vocab file from rxnfp that is fairly comprehensive and from the USPTO dataset.
3
+ # The vocab may be expanded in the near future
4
+
5
+ # Code taken from here: https://github.com/deepchem/deepchem/blob/2.4.0/deepchem/feat/smiles_tokenizer.py#L39-L282
6
+ import collections
7
+ import os
8
+ import re
9
+ import pkg_resources
10
+ from typing import List
11
+ from transformers import BertTokenizer
12
+ from logging import getLogger
13
+
14
+ logger = getLogger(__name__)
15
+ """
16
+ SMI_REGEX_PATTERN: str
17
+ SMILES regex pattern for tokenization. Designed by Schwaller et. al.
18
+
19
+ References
20
+
21
+ .. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
22
+ ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
23
+ 1572-1583 DOI: 10.1021/acscentsci.9b00576
24
+
25
+ """
26
+
27
+ SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
28
+
29
+ # add vocab_file dict
30
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
31
+
32
+
33
def get_default_tokenizer():
    """Return a ``SmilesTokenizer`` using the vocabulary shipped next to this module.

    ``SmilesTokenizer.__init__`` takes keyword arguments only and always reads
    its own ``data/vocab.txt``. Passing the deepchem vocabulary path
    positionally, as was done before, raised a ``TypeError`` (and required
    deepchem to be installed just to resolve the path).
    """
    return SmilesTokenizer()
38
+
39
+
40
class SmilesTokenizer(BertTokenizer):
    """
    Creates the SmilesTokenizer class. The tokenizer heavily inherits from the BertTokenizer
    implementation found in Huggingface's transformers library. It splits SMILES strings with
    the tokenisation SMILES regex developed by Schwaller et. al. and maps the resulting tokens
    to ids using the vocabulary file ``data/vocab.txt`` located next to this module.

    Please see https://github.com/huggingface/transformers
    and https://github.com/rxn4chemistry/rxnfp for more details.

    Examples
    --------
    >>> tokenizer = SmilesTokenizer()
    >>> ids = tokenizer.encode("CC(=O)OC1=CC=CC=C1C(=O)O")
    >>> smiles = tokenizer.decode(ids, skip_special_tokens=True).replace(" ", "")

    References
    ----------
    .. [1] Schwaller, Philippe; Probst, Daniel; Vaucher, Alain C.; Nair, Vishnu H; Kreutter, David;
        Laino, Teodoro; et al. (2019): Mapping the Space of Chemical Reactions using Attention-Based Neural
        Networks. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.9897365.v3

    Notes
    ----
    This class requires huggingface's transformers and tokenizers libraries to be installed.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, **kwargs):
        """Constructs a SmilesTokenizer.

        The vocabulary is always read from ``data/vocab.txt`` next to this module
        (one SMILES token per line).

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``BertTokenizer.__init__``.

        Raises
        ------
        ValueError
            If the vocabulary file does not exist.
        """

        vocab_file = os.path.join(os.path.dirname(__file__), "data", "vocab.txt")

        # Validate before the base class tries to read the file, so that this
        # message (and not a base class error) is what the caller sees.
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocab file at path '{}'.".format(vocab_file))

        super().__init__(vocab_file, **kwargs)

        self.sos = "[SOS]"
        self.eos = "[EOS]"

        self.vocab = load_vocab(vocab_file)
        # Position of the last "[unused...]" placeholder token in the vocabulary.
        self.highest_unused_index = max(
            [i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")]
        )
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()]
        )
        self.basic_tokenizer = BasicSmilesTokenizer()

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary."""
        return len(self.vocab)

    @property
    def vocab_list(self):
        """All vocabulary tokens in file order."""
        return list(self.vocab.keys())

    def _tokenize(self, text: str):
        """
        Tokenize a string into a list of tokens.

        Parameters
        ----------
        text: str
            Input string sequence to be tokenized.
        """

        return list(self.basic_tokenizer.tokenize(text))

    def _convert_token_to_id(self, token):
        """
        Converts a token (str/unicode) in an id using the vocab.
        Unknown tokens map to the id of the unknown token.

        Parameters
        ----------
        token: str
            String token from a larger sequence to be converted to a numerical id.
        """

        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """
        Converts an index (integer) in a token (string/unicode) using the vocab.
        Unknown ids map to the unknown token.

        Parameters
        ----------
        index: int
            Integer index to be converted back to a string-based token as part of a larger sequence.
        """

        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]):
        """Converts a sequence of tokens (string) in a single string.

        Parameters
        ----------
        tokens: List[str]
            List of tokens for a given string sequence.

        Returns
        -------
        out_string: str
            Tokens joined by single spaces, with WordPiece continuation
            markers (" ##") removed.
        """

        out_string: str = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def add_special_tokens_ids_single_sequence(self, token_ids: List[int]):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        A BERT sequence has the following format: [CLS] X [SEP]

        Parameters
        ----------
        token_ids: list[int]
            list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
        """

        return [self.cls_token_id] + token_ids + [self.sep_token_id]

    def add_special_tokens_single_sequence(self, tokens: List[str]):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        A BERT sequence has the following format: [CLS] X [SEP]

        Parameters
        ----------
        tokens: List[str]
            List of tokens for a given string sequence.
        """
        return [self.cls_token] + tokens + [self.sep_token]

    def add_special_tokens_ids_sequence_pair(
        self, token_ids_0: List[int], token_ids_1: List[int]
    ) -> List[int]:
        """
        Adds special tokens to a sequence pair for sequence classification tasks.
        A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]

        Parameters
        ----------
        token_ids_0: List[int]
            List of ids for the first string sequence in the sequence pair (A).

        token_ids_1: List[int]
            List of ids for the second string sequence in the sequence pair (B).
        """

        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        return cls + token_ids_0 + sep + token_ids_1 + sep

    def add_padding_tokens(
        self, token_ids: List[int], length: int, right: bool = True
    ) -> List[int]:
        """
        Adds padding tokens to return a sequence of the given length.
        By default padding tokens are added to the right of the sequence.
        Sequences that are already at least ``length`` long are returned unpadded.

        Parameters
        ----------
        token_ids: list[int]
            list of tokenized input ids. Can be obtained using the encode or encode_plus methods.

        length: int
            Target length of the returned sequence.

        right: bool (True by default)
            Pad on the right when True, on the left otherwise.

        Returns
        ----------
        token_ids: list[int]
            The input ids extended by padding token ids.
        """
        padding = [self.pad_token_id] * (length - len(token_ids))

        if right:
            return token_ids + padding
        else:
            return padding + token_ids

    def save_vocabulary(
        self, vocab_path: str, filename_prefix=None
    ):  # -> tuple[str]: doctest issue raised with this return type annotation
        """
        Save the tokenizer vocabulary to a file.

        Parameters
        ----------
        vocab_path: str
            Either a directory, in which case the file is named ``vocab.txt``,
            or the path of the file to write.

        filename_prefix: str, optional
            Prefix prepended (followed by "-") to the file name. Accepted because
            Hugging Face's ``save_pretrained`` passes it as a keyword argument.

        Returns
        ----------
        vocab_file: :obj:`Tuple(str)`:
            One-element tuple with the path of the written SMILES token per line vocabulary file.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(
                vocab_path, prefix + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = prefix + vocab_path

        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(
                            vocab_file
                        )
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
288
+
289
+
290
class BasicSmilesTokenizer:
    """
    Regex based SMILES tokenizer using the pattern developed by Schwaller et. al.
    Use it when a tokenizer without a dependency on the HuggingFace transformers
    library is required.

    Examples
    --------
    >>> tokenizer = BasicSmilesTokenizer()
    >>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O"))
    ['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O']

    References
    ----------
    .. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
        ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
        1572-1583 DOI: 10.1021/acscentsci.9b00576
    """

    def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN):
        """Compile the tokenization regex.

        Parameters
        ----------
        regex_pattern: str
            SMILES token regex; defaults to ``SMI_REGEX_PATTERN``.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)

    def tokenize(self, text):
        """Return all regex matches found in ``text``, in order, as a list."""
        return self.regex.findall(text)
328
+
329
+
330
def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file into an ordered token -> line index mapping."""
    with open(vocab_file, "r", encoding="utf-8") as reader:
        lines = reader.readlines()
    # A token that occurs more than once keeps its first position in the
    # mapping but ends up with the index of its last occurrence.
    return collections.OrderedDict(
        (line.rstrip("\n"), position) for position, line in enumerate(lines)
    )
339
+
340
+
341
# NOTE: A second, verbatim copy of BasicSmilesTokenizer and load_vocab used to
# follow here. It only rebound the names to identical definitions, so it was
# removed; the definitions directly above are the ones in use.
390
+
391
+
392
if __name__ == "__main__":
    # Manual smoke test: encode a complex and a simple SMILES and show the
    # tokens, the ids and the decoded round trip. The tokenizer locates its
    # vocabulary itself, so no path has to be built here.
    tokenizer = SmilesTokenizer()

    tokens = tokenizer.encode(
        "CN1CC[C@]23[C@@H]4[C@H]1CC5=C2C(=C(C=C5)O)O[C@H]3[C@H](C=C4)O"
    )
    print([tokenizer._convert_id_to_token(t) for t in tokens])

    enc = tokenizer.encode("CC=O")
    print(enc)
    print(tokenizer.decode(enc, skip_special_tokens=True).replace(" ", ""))
torch2-env.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: torch2-llamol
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - python=3.8
9
+ - torchaudio
10
+ - pytorch
11
+ - torchvision
12
+ - pytorch-cuda
13
+ - rdkit
14
+ - ca-certificates
15
+ - certifi
16
+ - openssl
17
+ - openbabel
18
+ - ipykernel
19
+ pip:
20
+ - tqdm
21
+ - transformers
22
+ - pandas
23
+ - matplotlib
24
+ - seaborn
25
+ - hydra-core
26
+ - swifter
27
+ - pyarrow
28
+ - ipywidgets
29
+ - dask
train.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from trainer import (
2
+ IOConfig,
3
+ LoaderConfig,
4
+ Trainer,
5
+ TrainerArgs,
6
+ ModelArgs,
7
+ ContextArgs,
8
+ OptimizerConfig,
9
+ )
10
+ from torch.distributed.elastic.multiprocessing.errors import record
11
+
12
+ import hydra
13
+ from omegaconf import DictConfig, OmegaConf
14
+ import logging
15
+ import sys
16
+ import os
17
+ import torch
18
+
19
+
20
def setup_logger(run_name: str, log_path: str):
    """Configure the root logger to write to stdout and to a run log file.

    The log file is ``<log_path>/train_<run_name>.log``; ``log_path`` is
    created when missing. When the ``RANK`` environment variable is set, the
    run is treated as a DDP run and every record is tagged with
    ``DDP[rank,local_rank,world_size]``.

    Parameters
    ----------
    run_name: str
        Name embedded in the log file name.
    log_path: str
        Directory that receives the log file.

    Returns
    -------
    logging.Logger
        The root logger.
    """
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    if ddp:
        ddp_rank = int(os.environ["RANK"])
        ddp_local_rank = int(os.environ["LOCAL_RANK"])
        ddp_world_size = int(os.environ["WORLD_SIZE"])

        formatter = logging.Formatter(
            f"[%(levelname)s] DDP[{ddp_rank},{ddp_local_rank},{ddp_world_size}] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    else:
        formatter = logging.Formatter(
            r"[%(levelname)s] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)

    os.makedirs(log_path, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(log_path, f"train_{run_name}.log"))
    file_handler.setFormatter(formatter)

    # basicConfig() silently does nothing when the root logger already owns
    # handlers (a framework or test runner may have installed some before this
    # function runs). force=True removes and closes those first, so the two
    # handlers above are always the ones in effect and repeated calls do not
    # stack duplicates.
    logging.basicConfig(
        level=logging.INFO,
        handlers=[stream_handler, file_handler],
        force=True,
    )

    return logging.getLogger()
47
+
48
+
49
@record
@hydra.main(version_base=None, config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: set up logging, build the trainer and run it.

    Logging is configured from the top-level ``run_name`` and ``io.out_dir``
    entries; everything else is read from the ``train`` sub-config. When
    ``train.profile`` is truthy the training run is wrapped in
    ``torch.profiler`` and the aggregated table is printed afterwards.
    """
    run_label = cfg.get("run_name", "default")
    log_dir = cfg.get("io", {"out_dir": "out"})["out_dir"]
    logger = setup_logger(run_label, log_dir)

    logger.info("Using config")
    logger.info(cfg)

    # From here on only the "train" section of the config is relevant.
    cfg = cfg["train"]
    train_args = TrainerArgs(
        io_conf=IOConfig(**cfg.get("io", {})),
        loader_conf=LoaderConfig(**cfg.get("loader", {})),
        model_conf=ModelArgs(**cfg.get("model", {})),
        context_conf=ContextArgs(**cfg.get("context", {})),
        optimizer_conf=OptimizerConfig(**cfg.get("optimizer", {})),
        run_name=cfg.get("label", "train_run"),
    )

    # When training on cpu / testing to not max out all cpu cores
    torch.set_num_threads(8)

    trainer = Trainer(
        train_args=train_args,
        dtype=cfg.get("dtype", "float16"),
        compile=cfg.get("compile", False),
    )

    if not cfg.get("profile", False):
        trainer.train()
        return

    profiled_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with torch.profiler.profile(activities=profiled_activities) as p:
        trainer.train()

    print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
98
+
99
if __name__ == "__main__":
    # Example invocation with Hydra overrides on the command line:
    # python train.py train=llama2-M-Full train.model.dim=1024
    main()
trainLLamaMol.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --mem=32gb # Total memory limit
3
+ #SBATCH --nodes=1
4
+ #SBATCH --ntasks-per-node=1
5
+ #SBATCH --cpus-per-task=2
6
+ #SBATCH --partition=<YOUR PARTITION>
7
+ #SBATCH --gres=gpu:a100:1
8
+ #SBATCH --time=2-00:00:00 # Time limit 2-hrs:min:sec days
9
+
10
+ export CUDA_VISIBLE_DEVICES=0
11
+
12
+ # TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
13
+ conda activate FULL_PATH_TO_CONDA/torch2-llamol
14
+ module load CUDA/11.7.0
15
+ module load GCC/7.1.0-2.28
16
+
17
+ cd ~/llama2-mol
18
+
19
+ srun python train.py train=llama2-M-Full-RSS > "train_runs/run_$SLURM_JOB_ID.out"
trainLLamaMolDDPSingleNode.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --mem=32gb # Total memory limit
3
+ #SBATCH --nodes=1
4
+ #SBATCH --ntasks-per-node=<HOW MANY GPUS>
5
+ #SBATCH --cpus-per-task=2
6
+ #SBATCH --partition=<YOUR PARTITION>
7
+ #SBATCH --gres=gpu:a100:<HOW MANY GPUS>
8
+ #SBATCH --time=2-00:00:00 # Time limit 2-hrs:min:sec days
9
+
10
+ export WORLD_SIZE=2
11
+ export OMP_NUM_THREADS=8
12
+ ### get the first node name as master address - customized for vgg slurm
13
+ ### e.g. master(gnodee[2-5],gnoded1) == gnodee2
14
+ echo "NODELIST="${SLURM_NODELIST}
15
+ master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16
+ PORT=54357
17
+ export MASTER_ADDR="$master_addr:$PORT"
18
+
19
+
20
+ # TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
21
+ conda activate FULL_PATH_TO_CONDA/torch2-llamol
22
+ module load CUDA/11.7.0
23
+ module load GCC/8.3.0
24
+
25
+ # TODO: Change this to the folder you cloned the repo in
26
+ cd ~/llamol
27
+
28
+ srun torchrun --standalone --max_restarts=3 --nnodes=1 --nproc_per_node=2 --rdzv-id=$SLURM_JOB_ID --rdzv-backend=c10d --rdzv-endpoint="$master_addr:$PORT" train.py train=llama2-M-Full > "train_runs/run_$SLURM_JOB_ID.out"
trainer.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional, Tuple, List, Union
3
+ from fragment_creator import fragment_creator_factory
4
+
5
+ from model import ContextArgs, ModelArgs
6
+ from tqdm import tqdm
7
+ import math
8
+ import os
9
+ import time
10
+ from contextlib import nullcontext
11
+ from datetime import datetime
12
+ from functools import partial
13
+
14
+ import torch
15
+ import numpy as np
16
+ from model import ContextArgs, Transformer, ModelArgs
17
+ from torch.distributed import destroy_process_group, init_process_group
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+
20
+ from preprocess_dataset import SmilesTask
21
+ from tokenizer import SmilesTokenizer
22
+
23
+ import logging
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class IOConfig:
30
+ # I/O
31
+ out_dir: str = "out"
32
+ eval_interval: int = 500
33
+ log_interval: int = 10
34
+ eval_iters: int = 25
35
+ eval_only: bool = False # if True, script exits right after the first eval
36
+ always_save_checkpoint: bool = (
37
+ False # if True, always save a checkpoint after each eval
38
+ )
39
+ init_from: str = "scratch" # 'scratch' or 'resume'
40
+ resume_when_snapshot_available: bool = True
41
+
42
+
43
@dataclass
class LoaderConfig:
    """Batching and dataset selection for the data loader."""

    # if gradient_accumulation_steps > 1, this is the micro-batch size
    batch_size: int = 384
    max_seq_len: int = 768
    # Name of the task whose batches are iterated.
    dataset: str = "smiles"
    # Value forwarded to the task's batch iterator as its `dataset` argument.
    processed_dataset_ckpt: str = "processed_dataset_None.pkl"
    # Name handed to `fragment_creator_factory`; None is passed through as is.
    fragment_creator: Optional[str] = None
53
+
54
+
55
+ # dim = 256
56
+ # n_layers = 8
57
+ # n_heads = 8
58
+ # multiple_of = 128
59
+ # dropout = 0.1
60
+
61
+
62
@dataclass
class OptimizerConfig:
    """AdamW hyper-parameters and the learning-rate schedule settings."""

    # --- adamw optimizer ---
    # used to simulate larger batch sizes
    gradient_accumulation_steps: int = 4
    # max learning rate
    learning_rate: float = 1e-4
    # total number of training iterations
    max_iters: int = 100000
    weight_decay: float = 1e-1
    beta1: float = 0.9
    beta2: float = 0.95
    # clip gradients at this value, or disable if == 0.0
    grad_clip: float = 1.0

    # --- learning rate decay settings ---
    # whether to decay the learning rate
    decay_lr: bool = True
    # how many steps to warm up for
    warmup_iters: int = 1000
    # should be ~= max_iters per Chinchilla
    lr_decay_iters: int = 100000
    # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
    min_lr: float = 0.0
80
+
81
+
82
@dataclass
class TrainerArgs:
    """Bundle of every configuration object a ``Trainer`` needs.

    All fields are required; there are no defaults.
    """

    io_conf: IOConfig  # Input / Output
    loader_conf: LoaderConfig  # Loader Configs
    model_conf: ModelArgs  # Transformer Args
    context_conf: ContextArgs  # Transformer Args (context part)
    optimizer_conf: OptimizerConfig  # Optimizer

    # Used to derive the checkpoint and snapshot file names.
    run_name: str
98
+
99
+
100
class Trainer:
    """Training loop for the ``Transformer`` model, single process or DDP.

    Construction only stores the configuration. Distributed setup, model and
    optimizer creation (or resuming from ``out_dir``) and the optimisation
    loop all happen inside :meth:`train`.
    """

    def __init__(
        self, train_args: "TrainerArgs", dtype: str = "float16", compile: bool = False
    ) -> None:
        """Store the run configuration.

        Parameters
        ----------
        train_args:
            Bundle of IO, loader, model, context and optimizer configuration.
        dtype:
            One of ``"float32"``, ``"bfloat16"`` or ``"float16"``; selects the
            autocast dtype, and ``"float16"`` additionally enables the
            gradient scaler.
        compile:
            Whether to wrap the model with ``torch.compile``.
        """
        self.train_conf = train_args
        self.dtype = dtype
        self.compile = compile
        # system
        self.run_name = train_args.run_name
        # Default device; a DDP run replaces it with the rank-local GPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # File names (inside io_conf.out_dir) of the model checkpoint and of
        # the training-state snapshot.
        self.CKPT_PT = f"{self.run_name}.pt"
        self.SNAPSHOT_PT = f"snapshot_{self.run_name}.pt"

    def _init_ddp_if_possible(self):
        """Initialise the process group when ``RANK`` is set in the environment.

        Sets ``ddp``, ``master_process``, ``seed_offset`` and
        ``ddp_world_size``. In a DDP run it also selects the rank-local CUDA
        device and divides ``gradient_accumulation_steps`` (in place, on the
        config object) by the world size.
        """
        # various inits, derived attributes, I/O setup
        self.ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
        if self.ddp:
            logger.info("Using ddp!")
            init_process_group(backend="nccl")
            self.ddp_rank = int(os.environ["RANK"])
            self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
            self.ddp_world_size = int(os.environ["WORLD_SIZE"])
            logger.info(f"{self.ddp_rank}, {self.ddp_local_rank},{self.ddp_world_size}")

            self.device = f"cuda:{self.ddp_local_rank}"
            torch.cuda.set_device(self.device)
            # this process will do logging, checkpointing etc.
            self.master_process = self.ddp_rank == 0

            logger.info(f"Is master process {self.device}? {self.master_process}")
            self.seed_offset = self.ddp_rank  # each process gets a different seed
            # world_size number of processes will be training simultaneously, so we can scale
            # down the desired gradient accumulation iterations per process proportionally
            assert (
                self.train_conf.optimizer_conf.gradient_accumulation_steps
                % self.ddp_world_size
                == 0
            )
            self.train_conf.optimizer_conf.gradient_accumulation_steps //= (
                self.ddp_world_size
            )
        else:
            # if not ddp, we are running on a single gpu, and one process
            self.master_process = True
            self.seed_offset = 0
            self.ddp_world_size = 1

    def _init_train(self):
        """Build everything the loop needs: seeds, autocast, data, model, optimizer.

        A snapshot in ``out_dir`` is resumed when ``init_from == "resume"`` or
        ``resume_when_snapshot_available`` is set; otherwise a new model is
        created for ``init_from == "scratch"``.

        Raises
        ------
        FileNotFoundError
            ``init_from == "resume"`` but no snapshot file exists.
        ValueError
            ``init_from`` is neither ``"scratch"`` nor ``"resume"`` and there
            is no snapshot to resume from.
        """
        io_conf = self.train_conf.io_conf
        loader_conf = self.train_conf.loader_conf
        opt_conf = self.train_conf.optimizer_conf

        self.tokens_per_iter = (
            opt_conf.gradient_accumulation_steps
            * self.ddp_world_size
            * loader_conf.batch_size
            * loader_conf.max_seq_len
        )
        if self.master_process:
            logger.info(f"tokens per iteration will be: {self.tokens_per_iter:,}")
            logger.info(
                f"breaks down as: {opt_conf.gradient_accumulation_steps} grad accum steps * {self.ddp_world_size} processes * {loader_conf.batch_size} batch size * {loader_conf.max_seq_len} max seq len"
            )
            os.makedirs(io_conf.out_dir, exist_ok=True)

        torch.manual_seed(1337 + self.seed_offset)
        np.random.seed(1337 + self.seed_offset)
        torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
        # for later use in torch.autocast
        self.device_type = "cuda" if "cuda" in self.device else "cpu"
        # note: float16 data type will automatically use a GradScaler
        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[self.dtype]
        self.ctx = (
            nullcontext()
            if self.device_type == "cpu"
            else torch.amp.autocast(device_type=self.device_type, dtype=ptdtype)
        )
        # task-specific setup
        task = {"smiles": SmilesTask}[loader_conf.dataset]
        self.iter_batches = partial(
            task.iter_batches,
            batch_size=loader_conf.batch_size,
            device=self.device,
            context_keys=self.train_conf.context_conf.context_keys,
            num_workers=0,
            dataset=loader_conf.processed_dataset_ckpt,
            fragment_creator=fragment_creator_factory(loader_conf.fragment_creator),
        )
        # init these up here, can override if resuming (i.e. from a checkpoint)
        self.iter_num = 0
        self.best_val_loss = 1e9
        self.epoch = 1

        self.tokenizer = SmilesTokenizer()

        # `snapshot` stays None unless a snapshot file was actually loaded.
        snapshot = None
        if io_conf.init_from == "resume" or io_conf.resume_when_snapshot_available:
            snapshot_path = os.path.join(io_conf.out_dir, self.SNAPSHOT_PT)
            if os.path.exists(snapshot_path):
                logger.info(f"Resuming training from {io_conf.out_dir}")
                # resume training from a checkpoint.
                ckpt_path = os.path.join(io_conf.out_dir, self.CKPT_PT)
                self.model = Transformer.load(ckpt_path, device=self.device)
                snapshot = torch.load(snapshot_path, map_location=self.device)
                self.iter_num = snapshot["iter_num"]
                self.best_val_loss = snapshot["best_val_loss"]
                self.epoch = snapshot["epoch"]

        if snapshot is None:
            # Nothing was resumed: fail with a clear message instead of
            # hitting an undefined `self.model` further down.
            if io_conf.init_from == "resume":
                raise FileNotFoundError(
                    f"init_from='resume' but no snapshot was found in {io_conf.out_dir}"
                )
            if io_conf.init_from != "scratch":
                raise ValueError(
                    f"Could not find option: {io_conf.init_from}. Use either 'scratch' or 'resume'"
                )

            # init a new model from scratch
            logger.info("Initializing a new model from scratch")
            logger.info(self.device)

            model_conf = self.train_conf.model_conf
            model_conf.vocab_size = self.tokenizer.vocab_size

            self.model = Transformer(model_conf, self.train_conf.context_conf).to(
                self.device
            )
            logger.info(
                f"Number of params: {self.model.getNumberParams()} Number Trainable Params: {self.model.getNumberTrainableParams()}"
            )

        self.model = self.model.to(self.device)

        # initialize a GradScaler. If enabled=False scaler is a no-op
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == "float16"))

        # optimizer
        self.optimizer = self.model.configure_optimizers(
            opt_conf.weight_decay,
            opt_conf.learning_rate,
            (opt_conf.beta1, opt_conf.beta2),
            self.device_type,
        )

        # Restore the optimizer state whenever a snapshot was resumed, also
        # when the resume was triggered by `resume_when_snapshot_available`;
        # iteration counter and weights are restored in that case too, so the
        # optimizer moments should not silently restart from zero.
        if snapshot is not None and "optimizer_state" in snapshot:
            logger.info("Loading optimizer state from snapshot")
            self.optimizer.load_state_dict(snapshot["optimizer_state"])
        snapshot = None  # free up memory

        # compile the model
        if self.compile:
            logger.info("compiling the model... (takes a ~minute)")
            self.unoptimized_model = self.model
            # NOTE: This is REALLY REALLY slow in our case, as the shapes are different in each epoch.
            # So it recompiles every batch ._.
            self.model = torch.compile(self.model, dynamic=False)  # requires PyTorch 2.0

        # wrap model into DDP container
        if self.ddp:
            # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
            # construction time since NCCL does not support `ComplexFloat`.
            # The prefix must follow this trainer's own flag: the bare name
            # `compile` is the always-truthy builtin function.
            prefix = "_orig_mod." if self.compile else ""
            self.model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
            self.model = DDP(self.model, device_ids=[self.ddp_local_rank])

    # helps estimate an arbitrarily accurate loss over either split using many batches
    @torch.no_grad()
    def estimate_loss(self):
        """Return the mean loss over up to ``eval_iters`` batches per split.

        Returns
        -------
        dict
            ``{"train": tensor, "val": tensor}`` of 0-dim tensors. When a
            split runs out of batches early, only the batches actually seen
            are averaged; a split without any batch yields ``nan``.
        """
        out = {}
        eval_iters = self.train_conf.io_conf.eval_iters
        self.model.eval()
        for split in ["train", "val"]:
            batch_iter = self.iter_batches(split)
            losses = torch.zeros(eval_iters)  # keep on CPU
            completed = 0
            for k in tqdm(range(eval_iters), total=eval_iters, desc="Eval"):
                try:
                    X = next(batch_iter)
                except StopIteration:
                    # The split is exhausted; stop instead of retrying and
                    # averaging in the untouched zero entries.
                    logger.info("Early Eval Stop")
                    break
                with self.ctx:
                    self.model(
                        X["src"],
                        targets=X["tgt"],
                        context=X["context"],
                        fragment=X["fragment"],
                    )
                    loss = self.raw_model.last_loss
                losses[k] = loss.item()
                completed += 1

            if completed:
                out[split] = losses[:completed].mean()
            else:
                out[split] = torch.tensor(float("nan"))
        self.model.train()
        return out

    # learning rate decay scheduler (cosine with warmup)
    def get_lr(self, it: int):
        """Return the learning rate for iteration ``it``.

        Linear warm-up for ``warmup_iters`` steps, cosine decay down to
        ``min_lr`` until ``lr_decay_iters``, ``min_lr`` from there on.
        """
        warmup_iters = self.train_conf.optimizer_conf.warmup_iters
        learning_rate = self.train_conf.optimizer_conf.learning_rate
        lr_decay_iters = self.train_conf.optimizer_conf.lr_decay_iters
        min_lr = self.train_conf.optimizer_conf.min_lr

        # 1) linear warmup for warmup_iters steps
        if it < warmup_iters:
            return learning_rate * it / warmup_iters
        # 2) if it >= lr_decay_iters, return min learning rate. The cosine
        #    branch also yields min_lr at it == lr_decay_iters, so including
        #    the boundary here only avoids a zero division when
        #    warmup_iters == lr_decay_iters.
        if it >= lr_decay_iters:
            return min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
        return min_lr + coeff * (learning_rate - min_lr)

    def train(self):
        """Run the training loop until ``max_iters`` is exceeded.

        The master process evaluates every ``eval_interval`` iterations and
        writes checkpoint and snapshot when the validation loss improved (or
        ``always_save_checkpoint`` is set). With ``eval_only`` the loss is
        estimated once and the loop exits without training or saving.
        """
        self._init_ddp_if_possible()
        self._init_train()

        io_conf = self.train_conf.io_conf
        opt_conf = self.train_conf.optimizer_conf

        # training loop
        train_batch_iter = self.iter_batches("train")
        X = next(train_batch_iter)  # fetch the very first batch
        t0 = time.time()
        local_iter_num = 0  # number of iterations in the lifetime of this process
        # unwrap DDP container if needed
        self.raw_model = self.model.module if self.ddp else self.model
        running_mfu = -1.0

        gradient_accumulation_steps = opt_conf.gradient_accumulation_steps
        while True:
            # determine and set the learning rate for this iteration
            lr = (
                self.get_lr(self.iter_num)
                if opt_conf.decay_lr
                else opt_conf.learning_rate
            )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr

            # evaluate the loss on train/val sets and write checkpoints
            periodic_eval = (
                self.iter_num % io_conf.eval_interval == 0 and self.iter_num != 0
            )
            # eval_only promises exactly one evaluation before exiting, so it
            # must not depend on the periodic schedule (which skips step 0).
            forced_eval = io_conf.eval_only and local_iter_num == 0
            if (periodic_eval or forced_eval) and self.master_process:
                logger.info(
                    f"Estimating loss for master_process({self.master_process}) on iter {self.iter_num}"
                )
                losses = self.estimate_loss()
                logger.info(
                    f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
                )
                log_dict = {
                    "iter": self.iter_num,
                    "tokens": self.iter_num * self.tokens_per_iter,
                    "loss/train": losses["train"],
                    "loss/val": losses["val"],
                    "lr": lr,
                    "mfu": running_mfu * 100,  # convert to percentage
                }
                logger.info(f"{log_dict}")

                if (
                    losses["val"] < self.best_val_loss
                    or io_conf.always_save_checkpoint
                ):
                    self.best_val_loss = losses["val"]
                    # An evaluation-only run does not touch files on disk.
                    if self.iter_num > 0 and not io_conf.eval_only:
                        logger.info(f"saving checkpoint to {io_conf.out_dir}")
                        self.raw_model.save(
                            os.path.join(io_conf.out_dir, self.CKPT_PT)
                        )

                        torch.save(
                            {
                                "iter_num": self.iter_num,
                                "epoch": self.epoch,
                                "best_val_loss": self.best_val_loss,
                                "optimizer_state": self.optimizer.state_dict(),
                            },
                            os.path.join(io_conf.out_dir, self.SNAPSHOT_PT),
                        )

            if io_conf.eval_only:
                # Every rank leaves here, in the very first loop pass.
                break

            # forward backward update, with optional gradient accumulation to simulate larger batch size
            # and using the GradScaler if data type is float16
            for micro_step in range(gradient_accumulation_steps):
                if self.ddp:
                    # in DDP training we only need to sync gradients at the last micro step.
                    # the official way to do this is with model.no_sync() context manager, but
                    # I really dislike that this bloats the code and forces us to repeat code
                    # looking at the source of that context manager, it just toggles this variable
                    self.model.require_backward_grad_sync = (
                        micro_step == gradient_accumulation_steps - 1
                    )
                with self.ctx:
                    context = X["context"]
                    fragment = X["fragment"]

                    # SCL (Stochastic context learning) algorithm:
                    # drop the fragment with probability 0.15
                    if np.random.random() < 0.15 or fragment is None:
                        fragment = None

                    # NOTE: random delete one context or more context columns
                    # (iterate over a copy of the keys, the dict is mutated)
                    current_context_keys = list(context.keys())
                    for k in current_context_keys:
                        if np.random.random() < 0.15:
                            del context[k]

                    self.model(
                        X["src"], targets=X["tgt"], context=context, fragment=fragment
                    )
                    loss = self.raw_model.last_loss
                    loss = loss / gradient_accumulation_steps
                # immediately async prefetch next batch while model is doing the forward pass on the GPU
                try:
                    X = next(train_batch_iter)
                except StopIteration:
                    # StopIteration is thrown if dataset ends
                    # reinitialize data loader
                    logger.info(f"Done Epoch {self.epoch}")
                    train_batch_iter = self.iter_batches("train")
                    X = next(train_batch_iter)
                    self.epoch += 1

                # backward pass, with gradient scaling if training in fp16
                self.scaler.scale(loss).backward()
            # clip the gradient
            if opt_conf.grad_clip != 0.0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), opt_conf.grad_clip
                )
            # step the optimizer and scaler if training in fp16
            self.scaler.step(self.optimizer)
            self.scaler.update()
            # flush the gradients as soon as we can, no need for this memory anymore
            self.optimizer.zero_grad(set_to_none=True)

            # timing and logging
            t1 = time.time()
            dt = t1 - t0
            t0 = t1

            if self.iter_num % io_conf.log_interval == 0 and self.master_process:
                # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
                lossf = loss.item() * gradient_accumulation_steps
                if local_iter_num >= 5:  # let the training loop settle a bit
                    mfu = self.raw_model.estimate_mfu(
                        self.train_conf.loader_conf.batch_size
                        * gradient_accumulation_steps,
                        dt,
                    )
                    running_mfu = (
                        mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
                    )
                logger.info(
                    f"{self.iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
                )
            self.iter_num += 1
            local_iter_num += 1

            # termination conditions
            if self.iter_num > opt_conf.max_iters:
                logger.info("Done with training iters!")
                break

        if self.ddp:
            destroy_process_group()
510
+
511
+
512
if __name__ == "__main__":
    # Nothing to do when this module is executed directly.
    pass