vshirasuna commited on
Commit
159c02c
1 Parent(s): 32b373a

Added normalization of SMILES

Browse files
models/smi_ted/smi_ted_light/load.py CHANGED
@@ -19,6 +19,13 @@ from huggingface_hub import hf_hub_download
19
  # Data
20
  import numpy as np
21
  import pandas as pd
 
 
 
 
 
 
 
22
 
23
  # Standard library
24
  from functools import partial
@@ -30,6 +37,17 @@ from tqdm import tqdm
30
  tqdm.pandas()
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  class MolTranBertTokenizer(BertTokenizer):
34
  def __init__(self, vocab_file: str = '',
35
  do_lower_case=False,
@@ -477,9 +495,17 @@ class Smi_ted(nn.Module):
477
  if self.is_cuda_available:
478
  self.encoder.cuda()
479
  self.decoder.cuda()
 
 
 
 
 
 
 
 
480
 
481
  # tokenizer
482
- idx, mask = self.tokenize(smiles)
483
 
484
  ###########
485
  # Encoder #
@@ -515,6 +541,30 @@ class Smi_ted(nn.Module):
515
  # reconstruct tokens
516
  pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
517
  pred_ids = torch.argmax(pred_ids, axis=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
 
519
  return ((true_ids, pred_ids), # tokens
520
  (true_cte, pred_cte), # token embeddings
@@ -548,9 +598,14 @@ class Smi_ted(nn.Module):
548
 
549
  # handle single str or a list of str
550
  smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
551
- n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
552
-
 
 
 
 
553
  # process in batches
 
554
  embeddings = [
555
  self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
556
  for batch in tqdm(np.array_split(smiles, n_split))
@@ -562,8 +617,13 @@ class Smi_ted(nn.Module):
562
  torch.cuda.empty_cache()
563
  gc.collect()
564
 
 
 
 
 
 
565
  if return_torch:
566
- return torch.tensor(np.array(flat_list))
567
  return pd.DataFrame(flat_list)
568
 
569
  def decode(self, smiles_embeddings):
@@ -607,6 +667,7 @@ def load_smi_ted(folder="./smi_ted_light",
607
  ):
608
  tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
609
  model = Smi_ted(tokenizer)
 
610
  repo_id = "ibm/materials.smi-ted"
611
  filename = "smi-ted-Light_40.pt"
612
  file_path = hf_hub_download(repo_id=repo_id, filename=filename)
@@ -614,6 +675,4 @@ def load_smi_ted(folder="./smi_ted_light",
614
  model.eval()
615
  print('Vocab size:', len(tokenizer.vocab))
616
  print(f'[INFERENCE MODE - {str(model)}]')
617
- return model
618
-
619
-
 
19
  # Data
20
  import numpy as np
21
  import pandas as pd
22
+ import numpy as np
23
+
24
+ # Chemistry
25
+ from rdkit import Chem
26
+ from rdkit.Chem import PandasTools
27
+ from rdkit.Chem import Descriptors
28
+ PandasTools.RenderImagesInAllDataFrames(True)
29
 
30
  # Standard library
31
  from functools import partial
 
37
  tqdm.pandas()
38
 
39
 
40
+ # function to canonicalize SMILES
41
+ def normalize_smiles(smi, canonical=True, isomeric=False):
42
+ try:
43
+ normalized = Chem.MolToSmiles(
44
+ Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
45
+ )
46
+ except:
47
+ normalized = None
48
+ return normalized
49
+
50
+
51
  class MolTranBertTokenizer(BertTokenizer):
52
  def __init__(self, vocab_file: str = '',
53
  do_lower_case=False,
 
495
  if self.is_cuda_available:
496
  self.encoder.cuda()
497
  self.decoder.cuda()
498
+
499
+ # handle single str or a list of str
500
+ smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
501
+
502
+ # SMILES normalization
503
+ smiles = smiles.apply(normalize_smiles)
504
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
505
+ smiles = smiles.dropna()
506
 
507
  # tokenizer
508
+ idx, mask = self.tokenize(smiles.to_list())
509
 
510
  ###########
511
  # Encoder #
 
541
  # reconstruct tokens
542
  pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
543
  pred_ids = torch.argmax(pred_ids, axis=-1)
544
+
545
+ # replacing null SMILES with NaN values
546
+ for idx in null_idx:
547
+ true_ids = true_ids.tolist()
548
+ pred_ids = pred_ids.tolist()
549
+ true_cte = true_cte.tolist()
550
+ pred_cte = pred_cte.tolist()
551
+ true_set = true_set.tolist()
552
+ pred_set = pred_set.tolist()
553
+
554
+ true_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
555
+ pred_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
556
+ true_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
557
+ pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
558
+ true_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
559
+ pred_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
560
+
561
+ if len(null_idx) > 0:
562
+ true_ids = torch.tensor(true_ids)
563
+ pred_ids = torch.tensor(pred_ids)
564
+ true_cte = torch.tensor(true_cte)
565
+ pred_cte = torch.tensor(pred_cte)
566
+ true_set = torch.tensor(true_set)
567
+ pred_set = torch.tensor(pred_set)
568
 
569
  return ((true_ids, pred_ids), # tokens
570
  (true_cte, pred_cte), # token embeddings
 
598
 
599
  # handle single str or a list of str
600
  smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
601
+
602
+ # SMILES normalization
603
+ smiles = smiles.apply(normalize_smiles)
604
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
605
+ smiles = smiles.dropna()
606
+
607
  # process in batches
608
+ n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
609
  embeddings = [
610
  self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
611
  for batch in tqdm(np.array_split(smiles, n_split))
 
617
  torch.cuda.empty_cache()
618
  gc.collect()
619
 
620
+ # replacing null SMILES with NaN values
621
+ for idx in null_idx:
622
+ flat_list.insert(idx, np.array([np.nan]*self.config['n_embd']))
623
+ flat_list = np.asarray(flat_list)
624
+
625
  if return_torch:
626
+ return torch.tensor(flat_list)
627
  return pd.DataFrame(flat_list)
628
 
629
  def decode(self, smiles_embeddings):
 
667
  ):
668
  tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
669
  model = Smi_ted(tokenizer)
670
+
671
  repo_id = "ibm/materials.smi-ted"
672
  filename = "smi-ted-Light_40.pt"
673
  file_path = hf_hub_download(repo_id=repo_id, filename=filename)
 
675
  model.eval()
676
  print('Vocab size:', len(tokenizer.vocab))
677
  print(f'[INFERENCE MODE - {str(model)}]')
678
+ return model