Spaces:
Runtime error
Runtime error
| import torch | |
| import math | |
| import torch.nn as nn | |
| from rdkit import Chem | |
| from rdkit import rdBase | |
# Turn off every RDKit 'rdApp.*' log channel (rdApp.error, rdApp.warning, ...).
# Presumably this keeps Chem.MolFromSmiles quiet when validity() is fed
# invalid SMILES; it affects RDKit logging process-wide.
rdBase.DisableLog('rdApp.*')
# Split SMILES into words.
# Two-character tokens that split() keeps together: multi-letter element
# symbols (Cl, Br, Na, ...), aromatic si/se/te, and the charge tokens
# +2..+4 / -2..-4. Every other character (apart from the '%' ring-closure
# prefix) becomes a token on its own.
_TWO_CHAR_TOKENS = frozenset({
    'Cl', 'Ca', 'Cu', 'Br', 'Be', 'Ba', 'Bi', 'Si', 'Se', 'Sr',
    'Na', 'Ni', 'Rb', 'Ra', 'Xe', 'Li', 'Al', 'As', 'Ag', 'Au',
    'Mg', 'Mn', 'Te', 'Zn', 'si', 'se', 'te', 'He',
    '+2', '+3', '+4', '-2', '-3', '-4', 'Kr', 'Fe',
})


def split(sm):
    '''Split a SMILES string into space-separated words (tokens).

    Tokenization rules, applied left to right:
      * '%' starts a 3-character token (e.g. '%12', a two-digit ring closure);
        near the end of the string it takes whatever characters remain.
      * A pair listed in _TWO_CHAR_TOKENS (Cl, Br, Si, Se, Na, +2, ...) is
        one token.
      * Anything else is a single-character token.

    input: A SMILES string.
    output: A string with a single space between words ('' for '').

    >>> split('CCl')
    'C Cl'
    >>> split('[Fe+2]')
    '[ Fe +2 ]'
    '''
    arr = []
    i = 0
    n = len(sm)
    while i < n:
        if sm[i] == '%':
            width = 3
        elif sm[i:i + 2] in _TWO_CHAR_TOKENS:
            width = 2
        else:
            # Includes the last character, where sm[i:i + 2] is 1 char long.
            width = 1
        arr.append(sm[i:i + width])
        i += width
    return ' '.join(arr)
# Activation function
class GELU(nn.Module):
    """Gaussian Error Linear Unit, tanh approximation (element-wise).

    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    """

    def forward(self, x):
        cubic = torch.pow(x, 3)
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * cubic)
        return 0.5 * x * (1 + torch.tanh(inner))
# Position-wise feed-forward network (the same two linear layers are applied
# over the last dimension at every position)
class PositionwiseFeedForward(nn.Module):
    """Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model).

    Args:
        d_model: size of the input/output last dimension.
        d_ff: hidden size of the intermediate layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names (w_1, w_2, ...) are state_dict keys; keep them stable.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        hidden = self.activation(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
# Normalization layer
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift.

    Note: unlike torch.nn.LayerNorm, this uses ``x.std()`` (torch's default,
    Bessel-corrected standard deviation) and adds ``eps`` to the standard
    deviation rather than to the variance.

    Args:
        features: size of the last dimension (length of the scale/shift).
        eps: small constant added to the standard deviation.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # scale
        self.b_2 = nn.Parameter(torch.zeros(features))  # shift
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: returns x + dropout(sublayer(norm(x))).

    Args:
        size: feature size passed to LayerNorm (last dimension of x).
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # sublayer: any callable on a tensor; its output must be
        # broadcast-compatible with x for the residual add.
        normed = self.norm(x)
        return x + self.dropout(sublayer(normed))
# Sample SMILES tokens from a probability distribution
def sample(msms):
    """Draw one index per distribution row, independently for each item in msms.

    Args:
        msms: iterable of tensors (e.g. a batch tensor iterated over its first
            dimension). Each item is treated as log-weights: ``exp()`` makes
            them non-negative and ``torch.multinomial`` normalizes them.
            Items may be 1-D ``(vocab,)`` or 2-D ``(seq_len, vocab)``.
            Presumably these are log-softmax outputs; confirm with callers.

    Returns:
        LongTensor stacking one sample per item. Its shape is ``(batch,)``
        for 1-D items and ``(batch, seq_len)`` for 2-D items, including when
        seq_len == 1.

    Raises:
        RuntimeError: from ``torch.stack`` if msms is empty.
    """
    ret = []
    for msm in msms:
        drawn = torch.multinomial(msm.exp(), 1)
        # Drop only the num_samples axis. A bare squeeze() would also drop a
        # seq_len axis of size 1 and silently change the output shape.
        ret.append(drawn.squeeze(-1))
    return torch.stack(ret)
def validity(smiles):
    """Return the fraction of SMILES strings that RDKit can parse.

    Args:
        smiles: iterable of SMILES strings. Lists, tuples and generators are
            all accepted.

    Returns:
        Float in [0, 1]: (#parseable) / (#total). Empty input returns 0.0
        instead of raising ZeroDivisionError.
    """
    # Materialize the input so generators work with len() below.
    smiles = list(smiles)
    if not smiles:
        return 0.0
    # Chem.MolFromSmiles returns None (rather than raising) for unparseable SMILES.
    invalid = sum(1 for sm in smiles if Chem.MolFromSmiles(sm) is None)
    return 1 - invalid / len(smiles)