vshirasuna committed
Commit 159c02c • 1 Parent(s): 32b373a
Added normalization of SMILES
models/smi_ted/smi_ted_light/load.py
CHANGED
@@ -19,6 +19,13 @@ from huggingface_hub import hf_hub_download
 # Data
 import numpy as np
 import pandas as pd
+import numpy as np
+
+# Chemistry
+from rdkit import Chem
+from rdkit.Chem import PandasTools
+from rdkit.Chem import Descriptors
+PandasTools.RenderImagesInAllDataFrames(True)
 
 # Standard library
 from functools import partial
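These imports make RDKit a hard dependency of load.py. A quick sanity check that the dependency resolves, assuming RDKit installed from PyPI (the `rdkit` package name is an assumption; a conda install works equally well):

```python
# Sanity check for the new RDKit dependency.
from rdkit import Chem
from rdkit.Chem import PandasTools

PandasTools.RenderImagesInAllDataFrames(True)       # render molecules inline in DataFrames
print(Chem.MolToSmiles(Chem.MolFromSmiles('OCC')))  # expected: CCO
```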
@@ -30,6 +37,17 @@ from tqdm import tqdm
 tqdm.pandas()
 
 
+# function to canonicalize SMILES
+def normalize_smiles(smi, canonical=True, isomeric=False):
+    try:
+        normalized = Chem.MolToSmiles(
+            Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
+        )
+    except:
+        normalized = None
+    return normalized
+
+
 class MolTranBertTokenizer(BertTokenizer):
     def __init__(self, vocab_file: str = '',
                  do_lower_case=False,
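The helper returns a canonical SMILES string, or None when RDKit cannot parse the input: Chem.MolFromSmiles returns None on a parse failure, the subsequent MolToSmiles call raises, and the bare except absorbs it. A standalone sketch of the same behavior, with expected outputs as comments:

```python
from rdkit import Chem

def normalize_smiles(smi, canonical=True, isomeric=False):
    try:
        normalized = Chem.MolToSmiles(
            Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
        )
    except Exception:  # mirrors the bare `except` in the diff
        normalized = None
    return normalized

print(normalize_smiles('OCC'))           # CCO  (canonical form)
print(normalize_smiles('C1=CC=CC=C1'))   # c1ccccc1  (benzene, aromatized)
print(normalize_smiles('not a smiles'))  # None (parse failure)
```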
@@ -477,9 +495,17 @@ class Smi_ted(nn.Module):
         if self.is_cuda_available:
             self.encoder.cuda()
             self.decoder.cuda()
+
+        # handle single str or a list of str
+        smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+        # SMILES normalization
+        smiles = smiles.apply(normalize_smiles)
+        null_idx = smiles[smiles.isnull()].index.to_list()  # keep track of SMILES that cannot normalize
+        smiles = smiles.dropna()
 
         # tokenizer
-        idx, mask = self.tokenize(smiles)
+        idx, mask = self.tokenize(smiles.to_list())
 
         ###########
         # Encoder #
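Because normalization happens on a pandas Series, the indices of unparseable entries can be recorded before dropping them, then used later to re-insert placeholder rows. A minimal demo of that pattern, with a stand-in normalizer:

```python
import pandas as pd

# Stand-in normalizer: maps an invalid entry to None, like normalize_smiles does.
smiles = pd.Series(['OCC', 'not a smiles', 'c1ccccc1'])
normalized = smiles.apply(lambda s: None if ' ' in s else s)

null_idx = normalized[normalized.isnull()].index.to_list()  # [1]
normalized = normalized.dropna()                            # keeps index labels 0 and 2
print(null_idx, list(normalized))
```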
@@ -515,6 +541,30 @@ class Smi_ted(nn.Module):
         # reconstruct tokens
         pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
         pred_ids = torch.argmax(pred_ids, axis=-1)
+
+        # replacing null SMILES with NaN values
+        for idx in null_idx:
+            true_ids = true_ids.tolist()
+            pred_ids = pred_ids.tolist()
+            true_cte = true_cte.tolist()
+            pred_cte = pred_cte.tolist()
+            true_set = true_set.tolist()
+            pred_set = pred_set.tolist()
+
+            true_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+            pred_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+            true_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+            pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+            true_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+            pred_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+
+        if len(null_idx) > 0:
+            true_ids = torch.tensor(true_ids)
+            pred_ids = torch.tensor(pred_ids)
+            true_cte = torch.tensor(true_cte)
+            pred_cte = torch.tensor(pred_cte)
+            true_set = torch.tensor(true_set)
+            pred_set = torch.tensor(pred_set)
 
         return ((true_ids, pred_ids), # tokens
                 (true_cte, pred_cte), # token embeddings
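The back-fill converts each output tensor to a Python list, inserts an all-NaN row per dropped index, and converts back to a tensor at the end. Note that the .tolist() calls sit inside the for loop, so with two or more null indices the second iteration would call .tolist() on a plain list and fail; a minimal sketch of the same idea with the conversions hoisted out of the loop (shapes and names are stand-ins):

```python
import numpy as np
import torch

max_len  = 4
null_idx = [1, 3]                    # indices dropped during normalization
true_ids = torch.zeros(3, max_len)   # stand-in for the decoded token ids

if len(null_idx) > 0:
    true_ids = true_ids.tolist()     # convert once, outside the loop
    for idx in null_idx:             # ascending order restores original positions
        true_ids.insert(idx, [np.nan] * max_len)
    true_ids = torch.tensor(true_ids)

print(true_ids.shape)  # torch.Size([5, 4]); rows 1 and 3 are all NaN
```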
@@ -548,9 +598,14 @@ class Smi_ted(nn.Module):
 
         # handle single str or a list of str
         smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
-        n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
-
+
+        # SMILES normalization
+        smiles = smiles.apply(normalize_smiles)
+        null_idx = smiles[smiles.isnull()].index.to_list()  # keep track of SMILES that cannot normalize
+        smiles = smiles.dropna()
+
         # process in batches
+        n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
         embeddings = [
             self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
             for batch in tqdm(np.array_split(smiles, n_split))
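Note that n_split is the number of chunks, not the chunk size: np.array_split produces n_split near-equal pieces, so an effective batch can be larger than batch_size. A small sketch of the arithmetic:

```python
import numpy as np
import pandas as pd

smiles     = pd.Series(['CCO'] * 250)
batch_size = 100
n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
print(n_split)                                            # 2
print([len(b) for b in np.array_split(smiles, n_split)])  # [125, 125]
```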
@@ -562,8 +617,13 @@ class Smi_ted(nn.Module):
         torch.cuda.empty_cache()
         gc.collect()
 
+        # replacing null SMILES with NaN values
+        for idx in null_idx:
+            flat_list.insert(idx, np.array([np.nan]*self.config['n_embd']))
+        flat_list = np.asarray(flat_list)
+
         if return_torch:
-            return torch.tensor(
+            return torch.tensor(flat_list)
         return pd.DataFrame(flat_list)
 
     def decode(self, smiles_embeddings):
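Collapsing the list of per-molecule NumPy rows into one ndarray before torch.tensor avoids PyTorch's slow element-by-element path for lists of arrays (recent PyTorch versions warn about exactly this). A sketch, with the embedding width as a stand-in value:

```python
import numpy as np
import torch

flat_list = [np.random.rand(768).astype(np.float32) for _ in range(4)]  # 768 is a stand-in width
flat_list = np.asarray(flat_list)   # (4, 768) in one contiguous array
emb = torch.tensor(flat_list)       # single fast copy, no per-row conversion
print(emb.shape)                    # torch.Size([4, 768])
```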
@@ -607,6 +667,7 @@ def load_smi_ted(folder="./smi_ted_light",
                  ):
     tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
     model = Smi_ted(tokenizer)
+
     repo_id = "ibm/materials.smi-ted"
     filename = "smi-ted-Light_40.pt"
     file_path = hf_hub_download(repo_id=repo_id, filename=filename)
@@ -614,6 +675,4 @@ def load_smi_ted(folder="./smi_ted_light",
     model.eval()
     print('Vocab size:', len(tokenizer.vocab))
     print(f'[INFERENCE MODE - {str(model)}]')
-    return model
-
-
+    return model
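Putting it together, a hypothetical end-to-end call after this commit; the import path and vocab filename are assumptions based on this repo's layout, and unparseable SMILES now come back as all-NaN rows instead of raising:

```python
# Hypothetical usage; `smi_ted_light.load` and `bert_vocab_curated.txt` are assumed.
from smi_ted_light.load import load_smi_ted

model = load_smi_ted(folder='./smi_ted_light', vocab_filename='bert_vocab_curated.txt')
emb = model.encode(['CCO', 'not a smiles'])  # DataFrame; the invalid row is all NaN
```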