diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..5195c1705731d972c425d3cfc0977f8240b3f3bd
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,13 @@
+FROM python:3.9.7
+
+WORKDIR /app
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+# preload models
+RUN python -c '\
+from transformers import BartForConditionalGeneration, AutoTokenizer;\
+AutoTokenizer.from_pretrained("ibm/materials.selfies-ted");\
+BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")'
+COPY . .
+
+CMD ["python", "app.py"]
\ No newline at end of file
diff --git a/Dockerfile-conda b/Dockerfile-conda
new file mode 100644
index 0000000000000000000000000000000000000000..98d0dfefb3c6b01820abeaea921c62a8246e8e6a
--- /dev/null
+++ b/Dockerfile-conda
@@ -0,0 +1,13 @@
+FROM condaforge/miniforge3
+
+WORKDIR /app
+SHELL ["/bin/bash", "-i", "-c"]
+RUN apt-get update && \
+ apt-get install -y build-essential libxrender1 libxext-dev
+RUN conda create --name fm4m python=3.9.7
+RUN conda activate fm4m
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+COPY . .
+
+CMD ["python", "app.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 6a59e2ce9c9b05ed6a611dc4c9165d04ab0ba522..ef50b227222b4d50684a8c88b96950448e2aeaa5 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
---
-title: Fm4m Kit
+title: Fix Fm4m Kit
emoji: 🐢
colorFrom: indigo
colorTo: blue
diff --git a/app.py b/app.py
index 97f591651f05a4719c289d31bd91e424e5becaab..97d96b99fd8f6f9bccb0d1b90af7db7c2fd93af5 100644
--- a/app.py
+++ b/app.py
@@ -1,142 +1,103 @@
import gradio as gr
-from huggingface_hub import InferenceClient
import matplotlib.pyplot as plt
-from PIL import Image
-from rdkit.Chem import Descriptors, QED, Draw
-from rdkit.Chem.Crippen import MolLogP
+import numpy as np
+import os
import pandas as pd
-from rdkit.Contrib.SA_Score import sascorer
-from rdkit.Chem import DataStructs, AllChem
-from transformers import BartForConditionalGeneration, AutoTokenizer, AutoModel
-from transformers.modeling_outputs import BaseModelOutput
+import re
import selfies as sf
-from rdkit import Chem
import torch
-import numpy as np
-import umap
-import pickle
import xgboost as xgb
-from sklearn.svm import SVR
-from sklearn.linear_model import LinearRegression
+from PIL import Image
+from rdkit import Chem, RDLogger
+from rdkit.Chem import DataStructs, AllChem, Descriptors, QED, Draw
+from rdkit.Chem.Crippen import MolLogP
+from rdkit.Contrib.SA_Score import sascorer
from sklearn.kernel_ridge import KernelRidge
-import json
-
-import os
+from sklearn.linear_model import LinearRegression
+from sklearn.svm import SVR
+from transformers import BartForConditionalGeneration, AutoTokenizer
+from transformers.modeling_outputs import BaseModelOutput
os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
-# my_theme = gr.Theme.from_hub("ysharma/steampunk")
-# my_theme = gr.themes.Glass()
-
-"""
-# カスタムテーマ設定
-theme = gr.themes.Default().set(
- body_background_fill="#000000", # 背景色を黒に設定
- text_color="#FFFFFF", # テキスト色を白に設定
-)
-"""
-"""
-import sys
-sys.path.append("models")
-sys.path.append("../models")
-sys.path.append("../")"""
-
-
-# Get the current file's directory
-base_dir = os.path.dirname(__file__)
-print("Base Dir : ", base_dir)
-
import models.fm4m as fm4m
+RDLogger.logger().setLevel(RDLogger.ERROR)
+
# Function to display molecule image from SMILES
def smiles_to_image(smiles):
mol = Chem.MolFromSmiles(smiles)
- if mol:
- img = Draw.MolToImage(mol)
- return img
- return None
-
-
-# Function to get canonical SMILES
-def get_canonical_smiles(smiles):
- mol = Chem.MolFromSmiles(smiles)
- if mol:
- return Chem.MolToSmiles(mol, canonical=True)
- return None
+ return Draw.MolToImage(mol) if mol else None
# Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
smiles_image_mapping = {
- "Mol 1": {"smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1", "image": "img/img1.png"},
+ "Mol 1": {
+ "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
+ "image": "img/img1.png",
+ },
# Example SMILES for ethanol
- "Mol 2": {"smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1", "image": "img/img2.png"},
+ "Mol 2": {
+ "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
+ "image": "img/img2.png",
+ },
# Example SMILES for butane
- "Mol 3": {"smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
- "image": "img/img3.png"}, # Example SMILES for ethylamine
- "Mol 4": {"smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1", "image": "img/img4.png"},
+ "Mol 3": {
+ "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
+ "image": "img/img3.png",
+ }, # Example SMILES for ethylamine
+ "Mol 4": {
+ "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
+ "image": "img/img4.png",
+ },
# Example SMILES for diethyl ether
- "Mol 5": {"smiles": "C=CCS[C@@H](C)CC(=O)OCC", "image": "img/img5.png"} # Example SMILES for chloroethane
+ "Mol 5": {
+ "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
+ "image": "img/img5.png",
+ }, # Example SMILES for chloroethane
}
datasets = [" ", "BACE", "ESOL", "Load Custom Dataset"]
-models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"]
+models_enabled = [
+ "SELFIES-TED",
+ "MHG-GED",
+ "MolFormer",
+ "SMI-TED",
+ "Mordred",
+ "MorganFingerprint",
+]
fusion_available = ["Concat"]
-global log_df
-log_df = pd.DataFrame(columns=["Selected Models", "Dataset", "Task", "Result"])
-
-
-def log_selection(models, dataset, task_type, result, log_df):
- # Append the new entry to the DataFrame
- new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_type, "Result": result}
- updated_log_df = log_df.append(new_entry, ignore_index=True)
- return updated_log_df
-
# Function to handle evaluation and logging
-def save_rep(models, dataset, task_type, eval_output):
- return
-def evaluate_and_log(models, dataset, task_type, eval_output):
+def evaluate_and_log(models, dataset, task_type, eval_output, state):
task_dic = {'Classification': 'CLS', 'Regression': 'RGR'}
- result = f"{eval_output}"#display_eval(models, dataset, task_type, fusion_type=None)
+ result = f"{eval_output}"
result = result.replace(" Score", "")
- new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_dic[task_type], "Result": result}
+ new_entry = {
+ "Selected Models": str(models),
+ "Dataset": dataset,
+ "Task": task_dic[task_type],
+ "Result": result,
+ }
new_entry_df = pd.DataFrame([new_entry])
- log_df = pd.read_csv('log.csv', index_col=0)
- log_df = pd.concat([new_entry_df, log_df])
-
- log_df.to_csv('log.csv')
-
- return log_df
-
-
-try:
- log_df = pd.read_csv('log.csv', index_col=0)
-except:
- log_df = pd.DataFrame({"":[],
- 'Selected Models': [],
- 'Dataset': [],
- 'Task': [],
- 'Result': []
- })
- csv_file_path = 'log.csv'
- log_df.to_csv(csv_file_path, index=False)
+ state["log_df"] = pd.concat([new_entry_df, state["log_df"]])
+ return state["log_df"]
# Load images for selection
def load_image(path):
try:
- return Image.open(smiles_image_mapping[path]["image"])# Image.1open(path)
+ return Image.open(smiles_image_mapping[path]["image"])
except:
pass
-
# Function to handle image selection
def handle_image_selection(image_key):
smiles = smiles_image_mapping[image_key]["smiles"]
@@ -160,49 +121,55 @@ def calculate_tanimoto(smiles1, smiles2):
mol1 = Chem.MolFromSmiles(smiles1)
mol2 = Chem.MolFromSmiles(smiles2)
if mol1 and mol2:
- # fp1 = FingerprintMols.FingerprintMol(mol1)
- # fp2 = FingerprintMols.FingerprintMol(mol2)
fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2)
fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2)
return round(DataStructs.FingerprintSimilarity(fp1, fp2), 2)
return None
-#with open("models/selfies_model/bart-2908.pickle", "rb") as input_file:
-# gen_model, gen_tokenizer = pickle.load(input_file)
-
gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")
def generate(latent_vector, mask):
encoder_outputs = BaseModelOutput(latent_vector)
- decoder_output = gen_model.generate(encoder_outputs=encoder_outputs, attention_mask=mask,
- max_new_tokens=64, do_sample=True, top_k=5, top_p=0.95, num_return_sequences=1)
+ decoder_output = gen_model.generate(
+ encoder_outputs=encoder_outputs,
+ attention_mask=mask,
+ max_new_tokens=64,
+ do_sample=True,
+ top_k=5,
+ top_p=0.95,
+ num_return_sequences=1,
+ )
selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
- outs = []
- for i in selfies:
- outs.append(sf.decoder(i.replace("] [", "][")))
- return outs
+ return [sf.decoder(re.sub(r'\]\s*(.*?)\s*\[', r']\1[', i)) for i in selfies]
def perturb_latent(latent_vecs, noise_scale=0.5):
- modified_vec = torch.tensor(np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
- dtype=torch.float32) + latent_vecs
- return modified_vec
+ return (
+ torch.tensor(
+ np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
+ dtype=torch.float32,
+ )
+ + latent_vecs
+ )
def encode(selfies):
- encoding = gen_tokenizer(selfies, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
+ encoding = gen_tokenizer(
+ selfies,
+ return_tensors='pt',
+ max_length=128,
+ truncation=True,
+ padding='max_length',
+ )
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']
- outputs = gen_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
+ outputs = gen_model.model.encoder(
+ input_ids=input_ids, attention_mask=attention_mask
+ )
model_output = outputs.last_hidden_state
-
- """input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
- sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
- sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
- model_output = sum_embeddings / sum_mask"""
return model_output, attention_mask
@@ -217,8 +184,13 @@ def generate_canonical(smiles):
noise = i / 10
perturbed_latent = perturb_latent(latent_vec, noise_scale=noise)
gen = generate(perturbed_latent, mask)
- gen_mol = Chem.MolToSmiles(Chem.MolFromSmiles(gen[0]))
- if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): break
+ mol = Chem.MolFromSmiles(gen[0])
+ if mol:
+ gen_mol = Chem.MolToSmiles(mol)
+ if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)):
+ break
+ else:
+ print('Abnormal molecule:', gen[0])
if gen_mol:
# Calculate properties for ref and gen molecules
@@ -230,9 +202,20 @@ def generate_canonical(smiles):
# Prepare the table with ref mol and gen mol
data = {
"Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
- "Reference Mol": [ref_properties[0], ref_properties[1], ref_properties[2], ref_properties[3],
- tanimoto_similarity],
- "Generated Mol": [gen_properties[0], gen_properties[1], gen_properties[2], gen_properties[3], ""]
+ "Reference Mol": [
+ ref_properties[0],
+ ref_properties[1],
+ ref_properties[2],
+ ref_properties[3],
+ tanimoto_similarity,
+ ],
+ "Generated Mol": [
+ gen_properties[0],
+ gen_properties[1],
+ gen_properties[2],
+ gen_properties[3],
+ "",
+ ],
}
df = pd.DataFrame(data)
@@ -245,7 +228,7 @@ def generate_canonical(smiles):
# Function to display evaluation score
-def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
+def display_eval(selected_models, dataset, task_type, downstream, fusion_type, state):
result = None
try:
@@ -260,72 +243,87 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
downstream_model = downstream_model.rstrip()
params = None
-
-
-
try:
if not selected_models:
return "Please select at least one enabled model."
- if task_type == "Classification":
- global roc_auc, fpr, tpr, x_batch, y_batch
- elif task_type == "Regression":
- global RMSE, y_batch_test, y_prob
-
if len(selected_models) > 1:
if task_type == "Classification":
- #result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
- # downstream_model="XGBClassifier",
- # dataset=dataset.lower())
if downstream_model == "Default Settings":
downstream_model = "DefaultClassifier"
params = None
- result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
- downstream_model=downstream_model,
- params = params,
- dataset=dataset)
- elif task_type == "Regression":
- #result, RMSE, y_batch_test, y_prob = fm4m.multi_modal(model_list=selected_models,
- # downstream_model="XGBRegressor",
- # dataset=dataset.lower())
+ (
+ result,
+ state["roc_auc"],
+ state["fpr"],
+ state["tpr"],
+ state["x_batch"],
+ state["y_batch"],
+ ) = fm4m.multi_modal(
+ model_list=selected_models,
+ downstream_model=downstream_model,
+ params=params,
+ dataset=dataset,
+ )
+ elif task_type == "Regression":
if downstream_model == "Default Settings":
downstream_model = "DefaultRegressor"
params = None
- result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
- downstream_model=downstream_model,
- params=params,
- dataset=dataset)
+ (
+ result,
+ state["RMSE"],
+ state["y_batch_test"],
+ state["y_prob"],
+ state["x_batch"],
+ state["y_batch"],
+ ) = fm4m.multi_modal(
+ model_list=selected_models,
+ downstream_model=downstream_model,
+ params=params,
+ dataset=dataset,
+ )
else:
if task_type == "Classification":
- #result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
- # downstream_model="XGBClassifier",
- # dataset=dataset.lower())
if downstream_model == "Default Settings":
downstream_model = "DefaultClassifier"
params = None
- result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
- downstream_model=downstream_model,
- params=params,
- dataset=dataset)
+ (
+ result,
+ state["roc_auc"],
+ state["fpr"],
+ state["tpr"],
+ state["x_batch"],
+ state["y_batch"],
+ ) = fm4m.single_modal(
+ model=selected_models[0],
+ downstream_model=downstream_model,
+ params=params,
+ dataset=dataset,
+ )
elif task_type == "Regression":
- #result, RMSE, y_batch_test, y_prob = fm4m.single_modal(model=selected_models[0],
- # downstream_model="XGBRegressor",
- # dataset=dataset.lower())
-
if downstream_model == "Default Settings":
downstream_model = "DefaultRegressor"
params = None
- result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
- downstream_model=downstream_model,
- params=params,
- dataset=dataset)
+ (
+ result,
+ state["RMSE"],
+ state["y_batch_test"],
+ state["y_prob"],
+ state["x_batch"],
+ state["y_batch"],
+ ) = fm4m.single_modal(
+ model=selected_models[0],
+ downstream_model=downstream_model,
+ params=params,
+ dataset=dataset,
+ )
if result == None:
result = "Data & Model Setting is incorrect"
@@ -335,23 +333,15 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
# Function to handle plot display
-def display_plot(plot_type):
+def display_plot(plot_type, state):
fig, ax = plt.subplots()
if plot_type == "Latent Space":
- global x_batch, y_batch
+ x_batch, y_batch = state.get("x_batch"), state.get("y_batch")
ax.set_title("T-SNE Plot")
- # reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
- # features_umap = reducer.fit_transform(x_batch[:500])
- # x = y_batch.values[:500]
- # index_0 = [index for index in range(len(x)) if x[index] == 0]
- # index_1 = [index for index in range(len(x)) if x[index] == 1]
- class_0 = x_batch # features_umap[index_0]
- class_1 = y_batch # features_umap[index_1]
-
- """with open("latent_multi_bace.pkl", "rb") as f:
- class_0, class_1 = pickle.load(f)
- """
+ class_0 = x_batch
+ class_1 = y_batch
+
plt.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1')
plt.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0')
@@ -360,10 +350,16 @@ def display_plot(plot_type):
ax.set_title('Dataset Distribution')
elif plot_type == "ROC-AUC":
- global roc_auc, fpr, tpr
+ roc_auc, fpr, tpr = state.get("roc_auc"), state.get("fpr"), state.get("tpr")
ax.set_title("ROC-AUC Curve")
try:
- ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})')
+ ax.plot(
+ fpr,
+ tpr,
+ color='darkorange',
+ lw=2,
+ label=f'ROC curve (area = {roc_auc:.4f})',
+ )
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
@@ -375,7 +371,11 @@ def display_plot(plot_type):
ax.legend(loc='lower right')
elif plot_type == "Parity Plot":
- global RMSE, y_batch_test, y_prob
+ RMSE, y_batch_test, y_prob = (
+ state.get("RMSE"),
+ state.get("y_batch_test"),
+ state.get("y_prob"),
+ )
ax.set_title("Parity plot")
# change format
@@ -384,7 +384,12 @@ def display_plot(plot_type):
print(y_prob)
y_batch_test = np.array(y_batch_test, dtype=float)
y_prob = np.array(y_prob, dtype=float)
- ax.scatter(y_batch_test, y_prob, color="blue", label=f"Predicted vs Actual (RMSE: {RMSE:.4f})")
+ ax.scatter(
+ y_batch_test,
+ y_prob,
+ color="blue",
+ label=f"Predicted vs Actual (RMSE: {RMSE:.4f})",
+ )
min_val = min(min(y_batch_test), min(y_prob))
max_val = max(max(y_batch_test), max(y_prob))
ax.plot([min_val, max_val], [min_val, max_val], 'r-')
@@ -397,10 +402,6 @@ def display_plot(plot_type):
print(y_batch_test)
print(y_prob)
-
-
-
-
ax.set_xlabel('Actual Values')
ax.set_ylabel('Predicted Values')
@@ -419,13 +420,25 @@ predefined_datasets = {
# Function to load a predefined dataset from the local path
def load_predefined_dataset(dataset_name):
val = predefined_datasets.get(dataset_name)
- try: file_path = val.split(",")[0]
- except:file_path=False
+ try:
+ file_path = val.split(",")[0]
+ except:
+ file_path = False
if file_path:
df = pd.read_csv(file_path)
- return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns)), f"{dataset_name.lower()}"
- return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[]), f"Dataset not found"
+ return (
+ df.head(),
+ gr.update(choices=list(df.columns)),
+ gr.update(choices=list(df.columns)),
+ f"{dataset_name.lower()}",
+ )
+ return (
+ pd.DataFrame(),
+ gr.update(choices=[]),
+ gr.update(choices=[]),
+ f"Dataset not found",
+ )
# Function to display the head of the uploaded CSV file
@@ -433,7 +446,11 @@ def display_csv_head(file):
if file is not None:
# Load the CSV file into a DataFrame
df = pd.read_csv(file.name)
- return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns))
+ return (
+ df.head(),
+ gr.update(choices=list(df.columns)),
+ gr.update(choices=list(df.columns)),
+ )
return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
@@ -441,28 +458,54 @@ def display_csv_head(file):
def handle_dataset_selection(selected_dataset):
if selected_dataset == "Custom Dataset":
# Show file upload fields for train and test datasets if "Custom Dataset" is selected
- return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
- visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
+ return (
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ )
else:
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(
- visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+ return (
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ )
# Function to select input and output columns and display a message
-def select_columns(input_column, output_column, train_data, test_data,dataset_name):
+def select_columns(input_column, output_column, train_data, test_data, dataset_name):
if input_column and output_column:
return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
return "Please select both input and output columns."
-def set_dataname(dataset_name, dataset_selector ):
+
+def set_dataname(dataset_name, dataset_selector):
if dataset_selector == "Custom Dataset":
return f"{dataset_name}"
return f"{dataset_selector}"
+
# Function to create model based on user input
-def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None):
+def create_model(
+ model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None
+):
if model_name == "XGBClassifier":
- model = xgb.XGBClassifier(objective='binary:logistic',eval_metric= 'auc', max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
+ model = xgb.XGBClassifier(
+ objective='binary:logistic',
+ eval_metric='auc',
+ max_depth=max_depth,
+ n_estimators=n_estimators,
+ alpha=alpha,
+ )
elif model_name == "SVR":
model = SVR(degree=degree, kernel=kernel)
elif model_name == "Kernel Ridge":
@@ -476,224 +519,339 @@ def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degr
return "Model not supported."
return f"{model_name} * {model.get_params()}"
-def model_selector(model_name):
- # Dynamically return the appropriate hyperparameter components based on the selected model
- if model_name == "XGBClassifier":
- return (
- gr.Slider(1, 10, label="max_depth"),
- gr.Slider(50, 500, label="n_estimators"),
- gr.Slider(0.1, 10.0, step=0.1, label="alpha")
- )
- elif model_name == "SVR":
- return (
- gr.Slider(1, 5, label="degree"),
- gr.Dropdown(["rbf", "poly", "linear"], label="kernel")
- )
- elif model_name == "Kernel Ridge":
- return (
- gr.Slider(0.1, 10.0, step=0.1, label="alpha"),
- gr.Slider(1, 5, label="degree"),
- gr.Dropdown(["rbf", "poly", "linear"], label="kernel")
- )
- elif model_name == "Linear Regression":
- return () # No hyperparameters for Linear Regression
- else:
- return ()
-
# Define the Gradio layout
-# with gr.Blocks(theme=my_theme) as demo:
with gr.Blocks() as demo:
+ log_df = pd.DataFrame(
+ {"": [], 'Selected Models': [], 'Dataset': [], 'Task': [], 'Result': []}
+ )
+ state = gr.State({"log_df": log_df})
with gr.Row():
# Left Column
with gr.Column():
- gr.HTML('''
+ gr.HTML(
+ '''
Data & Model Setting
- ''')
- # gr.Markdown("## Data & Model Setting")
- #dataset_dropdown = gr.Dropdown(choices=datasets, label="Select Dat")
-
+ '''
+ )
# Dropdown menu for predefined datasets including "Custom Dataset" option
- dataset_selector = gr.Dropdown(label="Select Dataset",
- choices=list(predefined_datasets.keys()) + ["Custom Dataset"])
+ dataset_selector = gr.Dropdown(
+ label="Select Dataset",
+ choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
+ )
# Display the message for selected columns
- selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=False)
+ selected_columns_message = gr.Textbox(
+ label="Selected Columns Info", visible=False
+ )
with gr.Accordion("Dataset Settings", open=True):
# File upload options for custom dataset (train and test)
dataset_name = gr.Textbox(label="Dataset Name", visible=False)
- train_file = gr.File(label="Upload Custom Train Dataset", file_types=[".csv"], visible=False)
- train_display = gr.Dataframe(label="Train Dataset Preview (First 5 Rows)", visible=False, interactive=False)
+ train_file = gr.File(
+ label="Upload Custom Train Dataset",
+ file_types=[".csv"],
+ visible=False,
+ )
+ train_display = gr.Dataframe(
+ label="Train Dataset Preview (First 5 Rows)",
+ visible=False,
+ interactive=False,
+ )
- test_file = gr.File(label="Upload Custom Test Dataset", file_types=[".csv"], visible=False)
- test_display = gr.Dataframe(label="Test Dataset Preview (First 5 Rows)", visible=False, interactive=False)
+ test_file = gr.File(
+ label="Upload Custom Test Dataset",
+ file_types=[".csv"],
+ visible=False,
+ )
+ test_display = gr.Dataframe(
+ label="Test Dataset Preview (First 5 Rows)",
+ visible=False,
+ interactive=False,
+ )
# Predefined dataset displays
- predefined_display = gr.Dataframe(label="Predefined Dataset Preview (First 5 Rows)", visible=False,
- interactive=False)
-
-
+ predefined_display = gr.Dataframe(
+ label="Predefined Dataset Preview (First 5 Rows)",
+ visible=False,
+ interactive=False,
+ )
# Dropdowns for selecting input and output columns for the custom dataset
- input_column_selector = gr.Dropdown(label="Select Input Column", choices=[], visible=False)
- output_column_selector = gr.Dropdown(label="Select Output Column", choices=[], visible=False)
-
- #selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=True)
+ input_column_selector = gr.Dropdown(
+ label="Select Input Column", choices=[], visible=False
+ )
+ output_column_selector = gr.Dropdown(
+ label="Select Output Column", choices=[], visible=False
+ )
# When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
- dataset_selector.change(handle_dataset_selection,
- inputs=dataset_selector,
- outputs=[dataset_name, train_file, train_display, test_file, test_display, predefined_display,
- input_column_selector, output_column_selector])
+ dataset_selector.change(
+ handle_dataset_selection,
+ inputs=dataset_selector,
+ outputs=[
+ dataset_name,
+ train_file,
+ train_display,
+ test_file,
+ test_display,
+ predefined_display,
+ input_column_selector,
+ output_column_selector,
+ ],
+ )
# When a predefined dataset is selected, load its head and update column selectors
- dataset_selector.change(load_predefined_dataset,
- inputs=dataset_selector,
- outputs=[predefined_display, input_column_selector, output_column_selector, selected_columns_message])
+ dataset_selector.change(
+ load_predefined_dataset,
+ inputs=dataset_selector,
+ outputs=[
+ predefined_display,
+ input_column_selector,
+ output_column_selector,
+ selected_columns_message,
+ ],
+ )
# When a custom train file is uploaded, display its head and update column selectors
- train_file.change(display_csv_head, inputs=train_file,
- outputs=[train_display, input_column_selector, output_column_selector])
+ train_file.change(
+ display_csv_head,
+ inputs=train_file,
+ outputs=[
+ train_display,
+ input_column_selector,
+ output_column_selector,
+ ],
+ )
# When a custom test file is uploaded, display its head
- test_file.change(display_csv_head, inputs=test_file,
- outputs=[test_display, input_column_selector, output_column_selector])
+ test_file.change(
+ display_csv_head,
+ inputs=test_file,
+ outputs=[
+ test_display,
+ input_column_selector,
+ output_column_selector,
+ ],
+ )
- dataset_selector.change(set_dataname,
- inputs=[dataset_name, dataset_selector],
- outputs=dataset_name)
+ dataset_selector.change(
+ set_dataname,
+ inputs=[dataset_name, dataset_selector],
+ outputs=dataset_name,
+ )
# Update the selected columns information when dropdown values are changed
- input_column_selector.change(select_columns,
- inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
- outputs=selected_columns_message)
-
- output_column_selector.change(select_columns,
- inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
- outputs=selected_columns_message)
+ input_column_selector.change(
+ select_columns,
+ inputs=[
+ input_column_selector,
+ output_column_selector,
+ train_file,
+ test_file,
+ dataset_name,
+ ],
+ outputs=selected_columns_message,
+ )
- model_checkbox = gr.CheckboxGroup(choices=models_enabled, label="Select Model")
+ output_column_selector.change(
+ select_columns,
+ inputs=[
+ input_column_selector,
+ output_column_selector,
+ train_file,
+ test_file,
+ dataset_name,
+ ],
+ outputs=selected_columns_message,
+ )
- # Add disabled checkboxes for GNN and FNN
- # gnn_checkbox = gr.Checkbox(label="GNN (Disabled)", value=False, interactive=False)
- # fnn_checkbox = gr.Checkbox(label="FNN (Disabled)", value=False, interactive=False)
+ model_checkbox = gr.CheckboxGroup(
+ choices=models_enabled, label="Select Model"
+ )
- task_radiobutton = gr.Radio(choices=["Classification", "Regression"], label="Task Type")
+ task_radiobutton = gr.Radio(
+ choices=["Classification", "Regression"], label="Task Type"
+ )
####### adding hyper parameter tuning ###########
- model_name = gr.Dropdown(["Default - Auto", "XGBClassifier", "SVR", "Kernel Ridge", "Linear Regression"], label="Select Downstream Model")
+ model_name = gr.Dropdown(
+ [
+ "Default - Auto",
+ "XGBClassifier",
+ "SVR",
+ "Kernel Ridge",
+ "Linear Regression",
+ ],
+ label="Select Downstream Model",
+ )
with gr.Accordion("Downstream Hyperparameter Settings", open=True):
# Create placeholders for hyperparameter components
- max_depth = gr.Slider(1, 20, step=1,visible=False, label="max_depth")
- n_estimators = gr.Slider(100, 5000, step=100, visible=False, label="n_estimators")
+ max_depth = gr.Slider(1, 20, step=1, visible=False, label="max_depth")
+ n_estimators = gr.Slider(
+ 100, 5000, step=100, visible=False, label="n_estimators"
+ )
alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
- degree = gr.Slider(1, 20, step=1,visible=False, label="degree")
- kernel = gr.Dropdown(choices=["rbf", "poly", "linear"], visible=False, label="kernel")
+ degree = gr.Slider(1, 20, step=1, visible=False, label="degree")
+ kernel = gr.Dropdown(
+ choices=["rbf", "poly", "linear"], visible=False, label="kernel"
+ )
# Output textbox
output = gr.Textbox(label="Loaded Parameters")
-
# Dynamically show relevant hyperparameters based on selected model
def update_hyperparameters(model_name):
if model_name == "XGBClassifier":
- return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
- visible=False), gr.update(visible=False)
+ return (
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ )
elif model_name == "SVR":
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
- visible=True), gr.update(visible=True)
+ return (
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ )
elif model_name == "Kernel Ridge":
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(
- visible=True), gr.update(visible=True)
+ return (
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ )
elif model_name == "Linear Regression":
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
- visible=False), gr.update(visible=False)
+ return (
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ )
elif model_name == "Default - Auto":
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
- visible=False), gr.update(visible=False)
-
+ return (
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ gr.update(visible=False),
+ )
# When model is selected, update which hyperparameters are visible
- model_name.change(update_hyperparameters, inputs=[model_name],
- outputs=[max_depth, n_estimators, alpha, degree, kernel])
+ model_name.change(
+ update_hyperparameters,
+ inputs=[model_name],
+ outputs=[max_depth, n_estimators, alpha, degree, kernel],
+ )
# Submit button to create the model with selected hyperparameters
submit_button = gr.Button("Create Downstream Model")
-
# Function to handle model creation based on input parameters
def on_submit(model_name, max_depth, n_estimators, alpha, degree, kernel):
if model_name == "XGBClassifier":
- return create_model(model_name, max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
+ return create_model(
+ model_name,
+ max_depth=max_depth,
+ n_estimators=n_estimators,
+ alpha=alpha,
+ )
elif model_name == "SVR":
return create_model(model_name, degree=degree, kernel=kernel)
elif model_name == "Kernel Ridge":
- return create_model(model_name, alpha=alpha, degree=degree, kernel=kernel)
+ return create_model(
+ model_name, alpha=alpha, degree=degree, kernel=kernel
+ )
elif model_name == "Linear Regression":
return create_model(model_name)
elif model_name == "Default - Auto":
return create_model(model_name)
# When the submit button is clicked, run the on_submit function
- submit_button.click(on_submit, inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
- outputs=output)
+ submit_button.click(
+ on_submit,
+ inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
+ outputs=output,
+ )
###### End of hyper param tuning #########
fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
-
-
eval_button = gr.Button("Train downstream model")
- #eval_button.style(css_class="custom-button-left")
# Middle Column
with gr.Column():
- gr.HTML('''
+ gr.HTML(
+ '''
Downstream Task 1: Property Prediction
- ''')
- # gr.Markdown("## Downstream task Result")
+ '''
+ )
eval_output = gr.Textbox(label="Train downstream model")
- plot_radio = gr.Radio(choices=["ROC-AUC", "Parity Plot", "Latent Space"], label="Select Plot Type")
- plot_output = gr.Plot(label="Visualization")#, height=250, width=250)
-
- #download_rep = gr.Button("Download representation")
+ plot_radio = gr.Radio(
+ choices=["ROC-AUC", "Parity Plot", "Latent Space"],
+ label="Select Plot Type",
+ )
+ plot_output = gr.Plot(label="Visualization")
create_log = gr.Button("Store log")
- log_table = gr.Dataframe(value=log_df, label="Log of Selections and Results", interactive=False)
-
- eval_button.click(display_eval,
- inputs=[model_checkbox, selected_columns_message, task_radiobutton, output, fusion_radiobutton],
- outputs=eval_output)
-
- plot_radio.change(display_plot, inputs=plot_radio, outputs=plot_output)
-
+ log_table = gr.Dataframe(
+ value=log_df, label="Log of Selections and Results", interactive=False
+ )
+
+ eval_button.click(
+ display_eval,
+ inputs=[
+ model_checkbox,
+ selected_columns_message,
+ task_radiobutton,
+ output,
+ fusion_radiobutton,
+ state,
+ ],
+ outputs=eval_output,
+ )
+
+ plot_radio.change(
+ display_plot, inputs=[plot_radio, state], outputs=plot_output
+ )
# Function to gather selected models
def gather_selected_models(*models):
selected = [model for model in models if model]
return selected
-
- create_log.click(evaluate_and_log, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
- outputs=log_table)
- #download_rep.click(save_rep, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
- # outputs=None)
-
+ create_log.click(
+ evaluate_and_log,
+ inputs=[
+ model_checkbox,
+ dataset_name,
+ task_radiobutton,
+ eval_output,
+ state,
+ ],
+ outputs=log_table,
+ )
# Right Column
with gr.Column():
- gr.HTML('''
+ gr.HTML(
+ '''
Downstream Task 2: Molecule Generation
- ''')
- # gr.Markdown("## Molecular Generation")
+ '''
+ )
smiles_input = gr.Textbox(label="Input SMILES String")
image_display = gr.Image(label="Molecule Image", height=250, width=250)
# Show images for selection
@@ -702,24 +860,32 @@ with gr.Blocks() as demo:
choices=list(smiles_image_mapping.keys()),
label="Select from sample molecules",
value=None,
- #item_images=[load_image(smiles_image_mapping[key]["image"]) for key in smiles_image_mapping.keys()]
)
image_selector.change(load_image, image_selector, image_display)
generate_button = gr.Button("Generate")
- gen_image_display = gr.Image(label="Generated Molecule Image", height=250, width=250)
+ gen_image_display = gr.Image(
+ label="Generated Molecule Image", height=250, width=250
+ )
generated_output = gr.Textbox(label="Generated Output")
property_table = gr.Dataframe(label="Molecular Properties Comparison")
-
-
# Handle image selection
- image_selector.change(handle_image_selection, inputs=image_selector, outputs=[smiles_input, image_display])
- smiles_input.change(smiles_to_image, inputs=smiles_input, outputs=image_display)
+ image_selector.change(
+ handle_image_selection,
+ inputs=image_selector,
+ outputs=[smiles_input, image_display],
+ )
+ smiles_input.change(
+ smiles_to_image, inputs=smiles_input, outputs=image_display
+ )
# Generate button to display canonical SMILES and molecule image
- generate_button.click(generate_canonical, inputs=smiles_input,
- outputs=[property_table, generated_output, gen_image_display])
+ generate_button.click(
+ generate_canonical,
+ inputs=smiles_input,
+ outputs=[property_table, generated_output, gen_image_display],
+ )
if __name__ == "__main__":
- demo.launch(share=True)
+ demo.launch(server_name="0.0.0.0")
diff --git a/data/lce/test.csv b/data/lce/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..95272b1fce743d785b366bce48b49270184452b7
--- /dev/null
+++ b/data/lce/test.csv
@@ -0,0 +1,31 @@
+smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE
+C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0.0,O,0.0,O,0.0,O,0.0,1.629
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,1.085
+COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0.0,O,0.0,O,0.0,2.056
+COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],,O,0.0,O,0.0,1.658
+C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0.0,1.638
+C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.276
+O1CCOC1,0.368,COCCOC,0.547,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.076,CSi(C)(C)([N+]).C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.008,O,0.0,O,0.0,1.569
+COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0.0,O,0.0,O,0.0,2.268
+C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,1.602
+C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0.0,1.678
+O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0.0,O,0.0,2.0
+C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0.0,O,0.0,O,0.0,0.921
+C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0.0,O,0.0,O,0.0,1.301
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0.0,O,0.0,0.854
+C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0.0,O,0.0,1.108
+O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0.0,O,0.0,O,0.0,1.523
+CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0.0,O,0.0,O,0.0,O,0.0,1.921
+CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0.0,O,0.0,O,0.0,O,0.0,1.602
+O1CCOC1,0.375,COCCOC,0.557,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.061,O,0.0,1.523
+COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0.0,O,0.0,O,0.0,2.155
+C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.26
+CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0.0,O,0.0,O,0.0,2.155
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.225
+COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155
+COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0.0,O,0.0,O,0.0,2.155
+O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0.0,O,0.0,O,0.0,2.301
+COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0.0,O,0.0,O,0.0,O,0.0,1.991
+COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,2.301
+C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0.0,O,0.0,O,0.0,1.398
+COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0.0,O,0.0,O,0.0,1.268
diff --git a/data/lce/test_data.csv b/data/lce/test_data.csv
new file mode 100644
index 0000000000000000000000000000000000000000..ddbd100e66ad449ecfaf7026b091fb84ded3fce8
--- /dev/null
+++ b/data/lce/test_data.csv
@@ -0,0 +1,14 @@
+smiles1,conc1,mol1,smiles2,conc2,mol2,smiles3,conc3,mol3,smiles4,conc4,mol4,smiles5,conc5,mol5,smiles6,conc6,LCE_Predicted,LCE
+C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.187,1.094
+COCCOC,0.596,59.5609428,COCCOCCOCCOCCOC,0.281,28.07124115,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.124,12.36781605,O,0,0,O,0,0,O,0,1.691,1.384
+C1COC(=O)O1,0.285,28.50894036,C1C(OC(=O)O1)F,0.261,26.07552384,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.228,22.82322096,COC(=O)OC,0.226,22.59231484,O,0,0,O,0,1.508,1.468
+COCCOC,0.434,43.4423376,COCCOCCOCCOCCOC,0.205,20.47449683,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.361,36.08316557,O,0,0,O,0,0,O,0,1.882,1.71
+C1C(OC(=O)O1)F,0.187,18.72872664,COC(=O)OC,0.162,16.22691423,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.109,10.92850826,FC(F)C(F)(F)COC(F)(F)C(F)F,0.541,54.11585087,O,0,0,O,0,2.103,1.832
+C1COC(=O)O1,0.134,13.35070843,C1C(OC(=O)O1)F,0.122,12.2111419,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.107,10.72028474,COC(=O)OC,0.106,10.57995858,FC(F)C(F)(F)COC(F)(F)C(F)F,0.531,53.13790635,O,0,2.077,2.104
+COCCOC,0.096,9.614613177,COCCOCCOCCOCCOC,0.045,4.53139444,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.12,12.01491409,C1COCO1,0.143,14.28400162,FC(F)C(F)(F)COC(F)(F)C(F)F,0.596,59.55507668,O,0,2.211,2.274
+C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].F[P-](F)(F)(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.17,1.071
+C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.077,1.166
+C1COC(=O)O1,0.519,51.85215842,COC(=O)OC,0.411,41.09097965,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.918492083,[Li+].[N+](=O)([O-])[O-],0.001,0.138369842,O,0,0,O,0,1.19,1.335
+C1COC(=O)O1,0.513,51.33049845,COC(=O)OC,0.407,40.6775828,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.9173773,C1=COC(=O)O1,0.011,1.07454145,O,0,0,O,0,1.114,1.129
+COCCOC,0.53,53.00533987,COCCOCCOCCOCCOC,0.25,24.98156691,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.22,22.01309322,O,0,0,O,0,0,O,0,1.758,1.501
+COCCOC,0.477,47.74974224,COCCOCCOCCOCCOC,0.225,22.50458884,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.297,29.74566892,O,0,0,O,0,0,O,0,1.821,1.663
\ No newline at end of file
diff --git a/data/lce/train.csv b/data/lce/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..3ba4d26d7016d2390f934922abf3cd650f734da9
--- /dev/null
+++ b/data/lce/train.csv
@@ -0,0 +1,121 @@
+smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE
+C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,1.155
+C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.046
+O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0.0,O,0.0,O,0.0,O,0.0,1.569
+C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0.0,O,0.0,O,0.0,0.886
+COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0.0,O,0.0,O,0.0,O,0.0,1.367
+COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,2.301
+C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0.0,O,0.0,O,0.0,O,0.0,1.489
+COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0.0,O,0.0,O,0.0,1.244
+C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0.0,O,0.0,0.745
+COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0.0,O,0.0,O,0.0,1.292
+CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0.0,O,0.0,O,0.0,2.301
+O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,O,0.0,1.745
+COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,1.745
+C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0.0,O,0.0,1.076
+C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.854
+C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0.0,1.678
+FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155
+CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222
+COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,O,0.0,1.638
+CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0.0,O,0.0,O,0.0,O,0.0,2.0
+COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,1.854
+O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0.0,O,0.0,O,0.0,1.959
+C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0.0,O,0.0,O,0.0,O,0.0,1.587
+CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0.0,O,0.0,O,0.0,0.699
+C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0.0,O,0.0,2.097
+C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0.0,O,0.0,O,0.0,1.59
+C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252
+COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,1.337
+C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0.0,O,0.0,O,0.0,O,0.0,1.377
+C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0.0,O,0.0,O,0.0,O,0.0,1.544
+CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0.0,O,0.0,O,0.0,O,0.0,2.097
+COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0.0,O,0.0,O,0.0,1.215
+COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0.0,O,0.0,O,0.0,O,0.0,1.222
+C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.194
+O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0.0,O,0.0,1.824
+C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.333
+O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0.0,O,0.0,1.824
+COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0.0,O,0.0,O,0.0,2.051
+COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0.0,O,0.0,O,0.0,1.444
+O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0.0,O,0.0,1.854
+CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0.0,O,0.0,O,0.0,O,0.0,2.046
+C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0.0,O,0.0,1.301
+C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222
+O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,1.903
+COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.561
+C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0.0,O,0.0,1.735
+FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0.0,O,0.0,O,0.0,O,0.0,2.301
+C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.498
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0.0,O,0.0,0.745
+O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0.0,O,0.0,O,0.0,1.824
+CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,O,0.0,2.0
+O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0.0,O,0.0,O,0.0,1.456
+COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0.0,O,0.0,O,0.0,O,0.0,1.301
+COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0.0,O,0.0,O,0.0,1.678
+C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0.0,O,0.0,O,0.0,O,0.0,1.646
+O1CCOC1,0.397,COCCOC,0.589,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.012,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.002,O,0.0,1.301
+C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0.0,O,0.0,O,0.0,2.046
+C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0.0,O,0.0,O,0.0,0.788
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.373
+O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0.0,O,0.0,O,0.0,O,0.0,1.602
+CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,1.854
+COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0.0,O,0.0,O,0.0,O,0.0,2.097
+O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0.0,O,0.0,O,0.0,O,0.0,1.699
+FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155
+CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0.0,O,0.0,O,0.0,O,0.0,2.208
+COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,1.77
+CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0.0,O,0.0,O,0.0,O,0.0,0.824
+C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0.0,O,0.0,0.924
+CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,2.097
+COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0.0,O,0.0,O,0.0,O,0.0,2.108
+CC1COC(=O)O1,0.922,[LI+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0.0,O,0.0,O,0.0,O,0.0,0.712
+C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0.0,O,0.0,1.081
+C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0.0,O,0.0,1.319
+COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0.0,O,0.0,O,0.0,1.62
+C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0.0,O,0.0,2.222
+C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.777
+CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0.0,O,0.0,O,0.0,O,0.0,2.018
+COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0.0,O,0.0,O,0.0,O,0.0,1.886
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0.0,O,0.0,0.699
+CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0.0,O,0.0,O,0.0,O,0.0,1.569
+C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0.0,O,0.0,O,0.0,1.523
+COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155
+C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.488
+O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0.0,2.046
+C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0.0,O,0.0,O,0.0,O,0.0,1.41
+COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0.0,O,0.0,O,0.0,2.222
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.559
+COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.301
+CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,O,0.0,1.672
+C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538
+CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0.0,O,0.0,O,0.0,O,0.0,1.796
+C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.355
+C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0.0,O,0.0,O,0.0,1.523
+COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0.0,O,0.0,1.78
+O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0.0,O,0.0,O,0.0,O,0.0,1.456
+O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0.0,O,0.0,O,0.0,O,0.0,1.745
+O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0.0,O,0.0,O,0.0,1.967
+COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0.0,O,0.0,O,0.0,O,0.0,2.097
+COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0.0,O,0.0,O,0.0,O,0.0,1.143
+O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0.0,O,0.0,1.523
+COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0.0,O,0.0,O,0.0,2.301
+CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,2.155
+C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,1.301
+COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.222
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0.0,O,0.0,0.699
+COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,1.495
+C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0.0,O,0.0,O,0.0,2.155
+C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0.0,O,0.0,1.921
+COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0.0,O,0.0,1.886
+CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,2.046
+COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,O,0.0,1.745
+C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0.0,O,0.0,O,0.0,O,0.0,1.633
+C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0.0,O,0.0,O,0.0,2.097
+FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.301
+C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0.0,O,0.0,1.108
+C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0.0,O,0.0,O,0.0,1.62
+CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.959
+C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,O,0.0,1.013
+C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0.0,1.824
+O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0.0,O,0.0,O,0.0,1.921
diff --git a/data/lce/train_data.csv b/data/lce/train_data.csv
new file mode 100644
index 0000000000000000000000000000000000000000..26cdcb3434b884dde32d05a77c2f112c72214680
--- /dev/null
+++ b/data/lce/train_data.csv
@@ -0,0 +1,148 @@
+smiles1,conc1,smiles2,conc2,smiles3,conc3,smiles4,conc4,smiles5,conc5,smiles6,conc6,LCE
+CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0,O,0,O,0,0.699
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0,O,0,0.699
+FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0,O,0,O,0,O,0,2.301
+FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.301
+CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0,O,0,O,0,2.155
+COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155
+CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,2.155
+O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0,O,0,O,0,1.967
+COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0,O,0,O,0,O,0,1.991
+C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0,O,0,O,0,O,0,1.646
+COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],0.035,O,0,O,0,1.658
+CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,O,0,1.672
+C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252
+C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0,O,0,O,0,2.155
+COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0,O,0,O,0,2.155
+COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0,O,0,O,0,2.155
+FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155
+FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155
+CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0,O,0,O,0,O,0,2.208
+C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0,O,0,2.222
+CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.054,O,0,O,0,O,0,2.222
+C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.029,O,0,O,0,O,0,2.222
+COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0,O,0,O,0,2.222
+COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.222
+COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0,O,0,O,0,2.268
+CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0,O,0,O,0,2.301
+COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,2.301
+COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0,O,0,O,0,2.301
+O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0,O,0,O,0,2.301
+COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,2.301
+COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.301
+O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0,O,0,2
+CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,O,0,2
+CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0,O,0,O,0,O,0,2.018
+CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0,O,0,O,0,O,0,2.046
+C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0,O,0,O,0,2.046
+O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0,2.046
+CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,2.046
+COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0,O,0,O,0,2.051
+COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0,O,0,O,0,2.056
+CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,2.097
+COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0,O,0,O,0,O,0,2.097
+C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0,O,0,O,0,2.097
+C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0,O,0,2.097
+CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0,O,0,O,0,O,0,2.097
+COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0,O,0,O,0,O,0,2.097
+COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0,O,0,O,0,O,0,2.108
+COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155
+CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0,O,0,O,0,O,0,2
+COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0,O,0,O,0,1.678
+C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0,1.678
+C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0,1.678
+O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0,O,0,O,0,O,0,1.699
+C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0,O,0,1.735
+O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,O,0,1.745
+COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,1.745
+COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,O,0,1.745
+O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0,O,0,O,0,O,0,1.745
+COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,1.77
+COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0,O,0,1.78
+CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0,O,0,O,0,O,0,1.796
+C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0,1.824
+O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0,O,0,O,0,1.824
+O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0,O,0,1.824
+O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0,O,0,1.824
+COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,1.854
+O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0,O,0,1.854
+CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,1.854
+COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0,O,0,1.886
+COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0,O,0,O,0,O,0,1.886
+O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,1.903
+O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0,O,0,O,0,1.921
+C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0,O,0,1.921
+CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0,O,0,O,0,O,0,1.921
+O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0,O,0,O,0,1.959
+CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.959
+C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0,O,0,O,0,O,0,1.377
+C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0,O,0,O,0,1.398
+C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0,O,0,O,0,O,0,1.41
+COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0,O,0,O,0,1.444
+O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0,O,0,O,0,O,0,1.456
+O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0,O,0,O,0,1.456
+C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.488
+C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0,O,0,O,0,O,0,1.489
+COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,1.495
+C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.498
+C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0,O,0,O,0,1.523
+O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0,O,0,O,0,1.523
+O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0,O,0,1.523
+C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0,O,0,O,0,1.523
+C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538
+C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0,O,0,O,0,O,0,1.544
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.559
+COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.561
+CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0,O,0,O,0,O,0,1.569
+O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0,O,0,O,0,O,0,1.569
+C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0,O,0,O,0,O,0,1.587
+C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0,O,0,O,0,1.59
+C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,1.602
+CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0,O,0,O,0,O,0,1.602
+O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0,O,0,O,0,O,0,1.602
+C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0,O,0,O,0,1.62
+COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0,O,0,O,0,1.62
+C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0,O,0,O,0,O,0,1.629
+C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0,O,0,O,0,O,0,1.633
+COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,O,0,1.638
+C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0,1.638
+C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.26
+COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0,O,0,O,0,1.268
+C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.276
+COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0,O,0,O,0,1.292
+C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0,O,0,1.301
+COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0,O,0,O,0,O,0,1.301
+C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0,O,0,O,0,1.301
+C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,1.301
+C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0,O,0,1.319
+C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.333
+COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,1.337
+C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.355
+COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0,O,0,O,0,O,0,1.367
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.373
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0,O,0,0.699
+CC1COC(=O)O1,0.922,[Li+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0,O,0,O,0,O,0,0.712
+C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0,O,0,0.745
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0,O,0,0.745
+C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.777
+C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0,O,0,O,0,0.788
+CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0,O,0,O,0,O,0,0.824
+C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0,O,0,0.854
+C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.854
+C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0,O,0,O,0,0.886
+C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0,O,0,O,0,0.921
+C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0,O,0,0.924
+C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,O,0,1.013
+C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.046
+C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0,O,0,1.076
+C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0,O,0,1.081
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,1.085
+C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0,O,0,1.108
+C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0,O,0,1.108
+COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0,O,0,O,0,O,0,1.143
+C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,1.155
+C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.194
+COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0,O,0,O,0,1.215
+COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0,O,0,O,0,O,0,1.222
+C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.225
+COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0,O,0,O,0,1.244
\ No newline at end of file
diff --git a/models/.gitattributes b/models/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..7aa4044c51cb3d662ba09fbc6be3c5a681e8e99f
--- /dev/null
+++ b/models/.gitattributes
@@ -0,0 +1,3 @@
+*.csv filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
diff --git a/models/fm4m.py b/models/fm4m.py
index 8f6fb86431768a45bfe1d32cae6498649d6fb385..15d98be8fb2f261a2a68cd340ae81fd03f0982e5 100644
--- a/models/fm4m.py
+++ b/models/fm4m.py
@@ -25,9 +25,17 @@ from sklearn.preprocessing import MinMaxScaler
import torch
from transformers import AutoTokenizer, AutoModel
-from .selfies_model.load import SELFIES as bart
-from .mhg_model import load as mhg
-from .smi_ted.smi_ted_light.load import load_smi_ted
+import sys
+sys.path.append("models/")
+
+from models.selfies_ted.load import SELFIES as bart
+from models.mhg_model import load as mhg
+from models.smi_ted.smi_ted_light.load import load_smi_ted
+
+import mordred
+from mordred import Calculator, descriptors
+from rdkit import Chem
+from rdkit.Chem import AllChem
datasets = {}
models = {}
@@ -48,7 +56,7 @@ def avail_models_data():
models = [{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality", "Timestamp": "2024-06-21 12:32:20"},
- {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
+ {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
{"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"},
{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}]
@@ -58,8 +66,10 @@ def avail_models(raw=False):
models = [{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model"},
{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"},
- {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality"},
+ {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality"},
{"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"},
+ {"Name": "Mordred", "Model Name": "Mordred","Description": "Baseline: A descriptor-calculation software application that can calculate more than 1800 two- and three-dimensional descriptors"},
+ {"Name": "MorganFingerprint", "Model Name": "MorganFingerprint","Description": "Baseline: Circular atom environments based descriptor"}
]
@@ -70,12 +80,22 @@ def avail_models(raw=False):
return models
-def avail_downstream_models():
+def avail_downstream_models(raw=False):
global downstream_models
- with open("downstream_models.json", "r") as outfile:
- downstream_models = json.load(outfile)
- return downstream_models
+ downstream_models = [{"Name": "XGBClassifier", "Task Type": "Classfication"},
+ {"Name": "DefaultClassifier", "Task Type": "Classfication"},
+ {"Name": "SVR", "Task Type": "Regression"},
+ {"Name": "Kernel Ridge", "Task Type": "Regression"},
+ {"Name": "Linear Regression", "Task Type": "Regression"},
+ {"Name": "DefaultRegressor", "Task Type": "Regression"},
+ ]
+
+ if raw: return downstream_models
+ else:
+ return pd.DataFrame(downstream_models)
+
+
def avail_datasets():
global datasets
@@ -178,13 +198,15 @@ def update_downstream_model_list(list_model):
avail_models_data()
+
+
def get_representation(train_data,test_data,model_type, return_tensor=True):
alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
if model_type in alias.keys():
model_type = alias[model_type]
if model_type == "mhg":
- model = mhg.load("models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle")
+ model = mhg.load("../models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle")
with torch.no_grad():
train_emb = model.encode(train_data)
x_batch = torch.stack(train_emb)
@@ -196,7 +218,6 @@ def get_representation(train_data,test_data,model_type, return_tensor=True):
x_batch_test = pd.DataFrame(x_batch_test)
-
elif model_type == "bart":
model = bart()
model.load()
@@ -204,7 +225,7 @@ def get_representation(train_data,test_data,model_type, return_tensor=True):
x_batch_test = model.encode(test_data, return_tensor=return_tensor)
elif model_type == "smi-ted":
- model = load_smi_ted(folder='./models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
+ model = load_smi_ted(folder='../models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
with torch.no_grad():
x_batch = model.encode(train_data, return_torch=return_tensor)
x_batch_test = model.encode(test_data, return_torch=return_tensor)
@@ -237,35 +258,78 @@ def get_representation(train_data,test_data,model_type, return_tensor=True):
if not return_tensor:
x_batch = pd.DataFrame(x_batch)
x_batch_test = pd.DataFrame(x_batch_test)
-
+
+ elif model_type == 'Mordred':
+ all_data = train_data + test_data
+ calc = Calculator(descriptors, ignore_3D=True)
+ mol_list = [Chem.MolFromSmiles(sm) for sm in all_data]
+ x_all = calc.pandas(mol_list)
+ print (f'original mordred fv dim: {x_all.shape}')
+
+ for j in x_all.columns:
+ for k in range(len(x_all[j])):
+ i = x_all.loc[k, j]
+ if type(i) is mordred.error.Missing or type(i) is mordred.error.Error:
+ x_all.loc[k, j] = np.nan
+
+ x_all.dropna(how="any", axis = 1, inplace=True)
+ print (f'Nan excluded mordred fv dim: {x_all.shape}')
+
+ x_batch = x_all.iloc[:len(train_data)]
+ x_batch_test = x_all.iloc[len(train_data):]
+ # print(f'x_batch: {len(x_batch)}, x_batch_test: {len(x_batch_test)}')
+
+ elif model_type == 'MorganFingerprint':
+ params = {'radius':2, 'nBits':1024}
+
+ mol_train = [Chem.MolFromSmiles(sm) for sm in train_data]
+ mol_test = [Chem.MolFromSmiles(sm) for sm in test_data]
+
+ x_batch = []
+ for mol in mol_train:
+ info = {}
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info)
+ vector = list(fp)
+ x_batch.append(vector)
+ x_batch = pd.DataFrame(x_batch)
+
+ x_batch_test = []
+ for mol in mol_test:
+ info = {}
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info)
+ vector = list(fp)
+ x_batch_test.append(vector)
+ x_batch_test = pd.DataFrame(x_batch_test)
return x_batch, x_batch_test
-def single_modal(model,dataset, downstream_model,params):
+def single_modal(model,dataset=None, downstream_model=None, params=None, x_train=None, x_test=None, y_train=None, y_test=None):
print(model)
- alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED": "smi-ted"}
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
data = avail_models(raw=True)
df = pd.DataFrame(data)
- print(list(df["Name"].values))
- if alias[model] in list(df["Name"].values):
- if model in alias.keys():
+ #print(list(df["Name"].values))
+
+ if model in list(df["Name"].values):
+ model_type = model
+ elif alias[model] in list(df["Name"].values):
model_type = alias[model]
- else:
- model_type = model
else:
print("Model not available")
return
+
data = avail_datasets()
df = pd.DataFrame(data)
- print(list(df["Dataset"].values))
+ #print(list(df["Dataset"].values))
if dataset in list(df["Dataset"].values):
task = dataset
- with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
+ with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
print(f" Representation loaded successfully")
- else:
+
+ elif x_train==None:
print("Custom Dataset")
#return
@@ -283,14 +347,40 @@ def single_modal(model,dataset, downstream_model,params):
print(f" Representation loaded successfully")
+ else:
-
-
+ y_batch = y_train
+ y_batch_test = y_test
+ x_batch, x_batch_test = get_representation(x_train, x_test, model_type)
+
+ # exclude row containing Nan value
+ if isinstance(x_batch, torch.Tensor):
+ x_batch = pd.DataFrame(x_batch)
+ nan_indices = x_batch.index[x_batch.isna().any(axis=1)]
+ if len(nan_indices) > 0:
+ x_batch.dropna(inplace = True)
+ for index in sorted(nan_indices, reverse=True):
+ del y_batch[index]
+ print(f'x_batch Nan index: {nan_indices}')
+ print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}')
+
+ if isinstance(x_batch_test, torch.Tensor):
+ x_batch_test = pd.DataFrame(x_batch_test)
+ nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)]
+ if len(nan_indices) > 0:
+ x_batch_test.dropna(inplace = True)
+ for index in sorted(nan_indices, reverse=True):
+ del y_batch_test[index]
+ print(f'x_batch_test Nan index: {nan_indices}')
+ print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}')
print(f" Calculating ROC AUC Score ...")
if downstream_model == "XGBClassifier":
- xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10
+ if params == None:
+ xgb_predict_concat = XGBClassifier()
+ else:
+ xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10
xgb_predict_concat.fit(x_batch, y_batch)
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
@@ -300,21 +390,26 @@ def single_modal(model,dataset, downstream_model,params):
print(f"ROC-AUC Score: {roc_auc:.4f}")
try:
- with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
+ with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
class_0,class_1 = pickle.load(f1)
except:
print("Generating latent plots")
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
verbose=False)
n_samples = np.minimum(1000, len(x_batch))
- features_umap = reducer.fit_transform(x_batch[:n_samples])
+
try:x = y_batch.values[:n_samples]
except: x = y_batch[:n_samples]
index_0 = [index for index in range(len(x)) if x[index] == 0]
index_1 = [index for index in range(len(x)) if x[index] == 1]
- class_0 = features_umap[index_0]
- class_1 = features_umap[index_1]
+ try:
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
+ class_0 = features_umap[index_0]
+ class_1 = features_umap[index_1]
+ except:
+ class_0 = []
+ class_1 = []
print("Generating latent plots : Done")
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
@@ -334,20 +429,29 @@ def single_modal(model,dataset, downstream_model,params):
print(f"ROC-AUC Score: {roc_auc:.4f}")
try:
- with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
+ with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
class_0,class_1 = pickle.load(f1)
except:
print("Generating latent plots")
reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
n_samples = np.minimum(1000,len(x_batch))
- features_umap = reducer.fit_transform(x_batch[:n_samples])
- try:x = y_batch.values[:n_samples]
- except:x = y_batch[:n_samples]
- index_0 = [index for index in range(len(x)) if x[index] == 0]
- index_1 = [index for index in range(len(x)) if x[index] == 1]
- class_0 = features_umap[index_0]
- class_1 = features_umap[index_1]
+ try:
+ x = y_batch.values[:n_samples]
+ except:
+ x = y_batch[:n_samples]
+
+ try:
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
+
+ class_0 = features_umap[index_0]
+ class_1 = features_umap[index_1]
+ except:
+ class_0 = []
+ class_1 = []
+
print("Generating latent plots : Done")
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
@@ -355,16 +459,19 @@ def single_modal(model,dataset, downstream_model,params):
result = f"ROC-AUC Score: {roc_auc:.4f}"
return result, roc_auc,fpr, tpr, class_0, class_1
-
+
elif downstream_model == "SVR":
- regressor = SVR(**params)
+ if params == None:
+ regressor = SVR()
+ else:
+ regressor = SVR(**params)
model = TransformedTargetRegressor(regressor= regressor,
transformer = MinMaxScaler(feature_range=(-1, 1))
).fit(x_batch,y_batch)
-
+
y_prob = model.predict(x_batch_test)
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
-
+
print(f"RMSE Score: {RMSE_score:.4f}")
result = f"RMSE Score: {RMSE_score:.4f}"
@@ -372,20 +479,28 @@ def single_modal(model,dataset, downstream_model,params):
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
verbose=False)
n_samples = np.minimum(1000, len(x_batch))
- features_umap = reducer.fit_transform(x_batch[:n_samples])
- try:x = y_batch.values[:n_samples]
- except:x = y_batch[:n_samples]
+
+ try: x = y_batch.values[:n_samples]
+ except: x = y_batch[:n_samples]
#index_0 = [index for index in range(len(x)) if x[index] == 0]
#index_1 = [index for index in range(len(x)) if x[index] == 1]
- class_0 = features_umap#[index_0]
- class_1 = features_umap#[index_1]
+ try:
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
+ class_0 = features_umap#[index_0]
+ class_1 = features_umap#[index_1]
+ except:
+ class_0 = []
+ class_1 = []
print("Generating latent plots : Done")
-
+
return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
elif downstream_model == "Kernel Ridge":
- regressor = KernelRidge(**params)
+ if params == None:
+ regressor = KernelRidge()
+ else:
+ regressor = KernelRidge(**params)
model = TransformedTargetRegressor(regressor=regressor,
transformer=MinMaxScaler(feature_range=(-1, 1))
).fit(x_batch, y_batch)
@@ -401,8 +516,8 @@ def single_modal(model,dataset, downstream_model,params):
verbose=False)
n_samples = np.minimum(1000, len(x_batch))
features_umap = reducer.fit_transform(x_batch[:n_samples])
- try:x = y_batch.values[:n_samples]
- except:x = y_batch[:n_samples]
+ try: x = y_batch.values[:n_samples]
+ except: x = y_batch[:n_samples]
# index_0 = [index for index in range(len(x)) if x[index] == 0]
# index_1 = [index for index in range(len(x)) if x[index] == 1]
@@ -414,7 +529,10 @@ def single_modal(model,dataset, downstream_model,params):
elif downstream_model == "Linear Regression":
- regressor = LinearRegression(**params)
+ if params == None:
+ regressor = LinearRegression()
+ else:
+ regressor = LinearRegression(**params)
model = TransformedTargetRegressor(regressor=regressor,
transformer=MinMaxScaler(feature_range=(-1, 1))
).fit(x_batch, y_batch)
@@ -431,7 +549,7 @@ def single_modal(model,dataset, downstream_model,params):
n_samples = np.minimum(1000, len(x_batch))
features_umap = reducer.fit_transform(x_batch[:n_samples])
try:x = y_batch.values[:n_samples]
- except:x = y_batch[:n_samples]
+ except: x = y_batch[:n_samples]
# index_0 = [index for index in range(len(x)) if x[index] == 0]
# index_1 = [index for index in range(len(x)) if x[index] == 1]
@@ -460,7 +578,7 @@ def single_modal(model,dataset, downstream_model,params):
n_samples = np.minimum(1000, len(x_batch))
features_umap = reducer.fit_transform(x_batch[:n_samples])
try:x = y_batch.values[:n_samples]
- except:x = y_batch[:n_samples]
+ except: x = y_batch[:n_samples]
# index_0 = [index for index in range(len(x)) if x[index] == 0]
# index_1 = [index for index in range(len(x)) if x[index] == 1]
@@ -469,10 +587,10 @@ def single_modal(model,dataset, downstream_model,params):
print("Generating latent plots : Done")
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
+
-
-def multi_modal(model_list,dataset, downstream_model,params):
- print(model_list)
+def multi_modal(model_list,dataset=None, downstream_model=None,params=None, x_train=None, x_test=None, y_train=None, y_test=None):
+ #print(model_list)
data = avail_datasets()
df = pd.DataFrame(data)
list(df["Dataset"].values)
@@ -480,7 +598,7 @@ def multi_modal(model_list,dataset, downstream_model,params):
if dataset in list(df["Dataset"].values):
task = dataset
predefined = True
- else:
+ elif x_train==None:
predefined = False
components = dataset.split(",")
train_data = pd.read_csv(components[0])[components[2]]
@@ -490,13 +608,18 @@ def multi_modal(model_list,dataset, downstream_model,params):
y_batch_test = pd.read_csv(components[1])[components[3]]
print("Custom Dataset loaded")
-
+ else:
+ predefined = False
+ y_batch = y_train
+ y_batch_test = y_test
+ train_data = x_train
+ test_data = x_test
data = avail_models(raw=True)
df = pd.DataFrame(data)
list(df["Name"].values)
- alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED":"smi-ted"}
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl","SMI-TED":"smi-ted", "Mordred": "Mordred", "MorganFingerprint": "MorganFingerprint"}
#if set(model_list).issubset(list(df["Name"].values)):
if set(model_list).issubset(list(alias.keys())):
for i, model in enumerate(model_list):
@@ -507,7 +630,7 @@ def multi_modal(model_list,dataset, downstream_model,params):
if i == 0:
if predefined:
- with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
+ with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
print(f" Loaded representation/{task}_{model_type}.pkl")
else:
@@ -517,7 +640,7 @@ def multi_modal(model_list,dataset, downstream_model,params):
else:
if predefined:
- with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
+ with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
print(f" Loaded representation/{task}_{model_type}.pkl")
else:
@@ -528,7 +651,6 @@ def multi_modal(model_list,dataset, downstream_model,params):
x_batch = pd.concat([x_batch, x_batch_1], axis=1)
x_batch_test = pd.concat([x_batch_test, x_batch_test_1], axis=1)
-
else:
print("Model not available")
return
@@ -538,11 +660,31 @@ def multi_modal(model_list,dataset, downstream_model,params):
num_columns = x_batch.shape[1]
x_batch.columns = [f'{i + 1}' for i in range(num_columns)]
-
+
+ # exclude row containing Nan value
+ if isinstance(x_batch, torch.Tensor):
+ x_batch = pd.DataFrame(x_batch)
+ nan_indices = x_batch.index[x_batch.isna().any(axis=1)]
+ if len(nan_indices) > 0:
+ x_batch.dropna(inplace = True)
+ for index in sorted(nan_indices, reverse=True):
+ del y_batch[index]
+ print(f'x_batch Nan index: {nan_indices}')
+ print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}')
+
+ if isinstance(x_batch_test, torch.Tensor):
+ x_batch_test = pd.DataFrame(x_batch_test)
+ nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)]
+ if len(nan_indices) > 0:
+ x_batch_test.dropna(inplace = True)
+ for index in sorted(nan_indices, reverse=True):
+ del y_batch_test[index]
+ print(f'x_batch_test Nan index: {nan_indices}')
+ print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}')
print(f"Representations loaded successfully")
try:
- with open(f"./plot_emb/{task}_multi.pkl", "rb") as f1:
+ with open(f"plot_emb/{task}_multi.pkl", "rb") as f1:
class_0, class_1 = pickle.load(f1)
except:
print("Generating latent plots")
@@ -552,8 +694,8 @@ def multi_modal(model_list,dataset, downstream_model,params):
features_umap = reducer.fit_transform(x_batch[:n_samples])
if "Classifier" in downstream_model:
- try:x = y_batch.values[:n_samples]
- except:x = y_batch[:n_samples]
+ try: x = y_batch.values[:n_samples]
+ except: x = y_batch[:n_samples]
index_0 = [index for index in range(len(x)) if x[index] == 0]
index_1 = [index for index in range(len(x)) if x[index] == 1]
@@ -570,7 +712,10 @@ def multi_modal(model_list,dataset, downstream_model,params):
if downstream_model == "XGBClassifier":
- xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10)
+ if params == None:
+ xgb_predict_concat = XGBClassifier()
+ else:
+ xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10)
xgb_predict_concat.fit(x_batch, y_batch)
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
@@ -608,21 +753,27 @@ def multi_modal(model_list,dataset, downstream_model,params):
return result, roc_auc,fpr, tpr, class_0, class_1
elif downstream_model == "SVR":
- regressor = SVR(**params)
+ if params == None:
+ regressor = SVR()
+ else:
+ regressor = SVR(**params)
model = TransformedTargetRegressor(regressor= regressor,
transformer = MinMaxScaler(feature_range=(-1, 1))
).fit(x_batch,y_batch)
-
+
y_prob = model.predict(x_batch_test)
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
-
+
print(f"RMSE Score: {RMSE_score:.4f}")
result = f"RMSE Score: {RMSE_score:.4f}"
-
+
return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
elif downstream_model == "Linear Regression":
- regressor = LinearRegression(**params)
+ if params == None:
+ regressor = LinearRegression()
+ else:
+ regressor = LinearRegression(**params)
model = TransformedTargetRegressor(regressor=regressor,
transformer=MinMaxScaler(feature_range=(-1, 1))
).fit(x_batch, y_batch)
@@ -636,7 +787,10 @@ def multi_modal(model_list,dataset, downstream_model,params):
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
elif downstream_model == "Kernel Ridge":
- regressor = KernelRidge(**params)
+ if params == None:
+ regressor = KernelRidge()
+ else:
+ regressor = KernelRidge(**params)
model = TransformedTargetRegressor(regressor=regressor,
transformer=MinMaxScaler(feature_range=(-1, 1))
).fit(x_batch, y_batch)
@@ -665,6 +819,144 @@ def multi_modal(model_list,dataset, downstream_model,params):
+def finetune_optuna(x_batch,y_batch, x_batch_test, y_test ):
+ print(f" Finetuning with Optuna and calculating ROC AUC Score ...")
+ X_train = x_batch.values
+ y_train = y_batch.values
+ X_test = x_batch_test.values
+ y_test = y_test.values
+ def objective(trial):
+ # Define parameters to be optimized
+ params = {
+ # 'objective': 'binary:logistic',
+ 'eval_metric': 'auc',
+ 'verbosity': 0,
+ 'n_estimators': trial.suggest_int('n_estimators', 1000, 10000),
+ # 'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
+ # 'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0),
+ 'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0),
+ 'max_depth': trial.suggest_int('max_depth', 1, 12),
+ # 'eta': trial.suggest_loguniform('eta', 1e-8, 1.0),
+ # 'gamma': trial.suggest_loguniform('gamma', 1e-8, 1.0),
+ # 'grow_policy': trial.suggest_categorical('grow_policy', ['depthwise', 'lossguide']),
+ # "subsample": trial.suggest_float("subsample", 0.05, 1.0),
+ # "colsample_bytree": trial.suggest_float("colsample_bytree", 0.05, 1.0),
+ }
+
+ # Train XGBoost model
+ dtrain = xgb.DMatrix(X_train, label=y_train)
+ dtest = xgb.DMatrix(X_test, label=y_test)
+
+ model = xgb.train(params, dtrain)
+
+ # Predict probabilities
+ y_pred = model.predict(dtest)
+
+ # Calculate ROC AUC score
+ roc_auc = roc_auc_score(y_test, y_pred)
+ print("ROC_AUC : ", roc_auc)
+
+ return roc_auc
+
+def add_new_model():
+ models = avail_models(raw=True)
+
+ # Function to display models
+ def display_models():
+ for model in models:
+ model_display = f"Name: {model['Name']}, Description: {model['Description']}, Timestamp: {model['Timestamp']}"
+ print(model_display)
+
+ # Function to update models
+ def update_models(new_name, new_description, new_path):
+ new_model = {
+ "Name": new_name,
+ "Description": new_description,
+ "Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+ #"path": new_path
+ }
+ models.append(new_model)
+ with open("models.json", "w") as outfile:
+ json.dump(models, outfile)
+
+ print("Model uploaded and updated successfully!")
+ list_models()
+ #display_models()
+
+ # Widgets
+ name_text = widgets.Text(description="Name:", layout=Layout(width='50%'))
+ description_text = widgets.Text(description="Description:", layout=Layout(width='50%'))
+ path_text = widgets.Text(description="Path:", layout=Layout(width='50%'))
+
+ def browse_callback(b):
+ root = tk.Tk()
+ root.withdraw() # Hide the main window
+ file_path = filedialog.askopenfilename(title="Select a Model File")
+ if file_path:
+ path_text.value = file_path
+
+ browse_button = widgets.Button(description="Browse")
+ browse_button.on_click(browse_callback)
+
+ def submit_callback(b):
+ update_models(name_text.value, description_text.value, path_text.value)
+
+ submit_button = widgets.Button(description="Submit")
+ submit_button.on_click(submit_callback)
+
+ # Display widgets
+ display(VBox([name_text, description_text, path_text, browse_button, submit_button]))
+
+
+def add_new_dataset():
+ # Sample data
+ datasets = avail_datasets()
+
+ # Function to display models
+ def display_datasets():
+ for dataset in datasets:
+ dataset_display = f"Name: {dataset['Dataset']}, Input: {dataset['Input']},Output: {dataset['Output']},Path: {dataset['Path']}, Timestamp: {dataset['Timestamp']}"
+
+ # Function to update models
+ def update_datasets(new_dataset, new_input, new_output, new_path):
+ new_model = {
+ "Dataset": new_dataset,
+ "Input": new_input,
+ "Output": new_output,
+ "Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+ "Path": os.path.basename(new_path)
+ }
+ datasets.append(new_model)
+ with open("datasets.json", "w") as outfile:
+ json.dump(datasets, outfile)
+
+ print("Dataset uploaded and updated successfully!")
+ list_data()
+
+
+ # Widgets
+ dataset_text = widgets.Text(description="Dataset:", layout=Layout(width='50%'))
+ input_text = widgets.Text(description="Input:", layout=Layout(width='50%'))
+ output_text = widgets.Text(description="Output:", layout=Layout(width='50%'))
+ path_text = widgets.Text(description="Path:", layout=Layout(width='50%'))
+
+ def browse_callback(b):
+ root = tk.Tk()
+ root.withdraw() # Hide the main window
+ file_path = filedialog.askopenfilename(title="Select a Dataset File")
+ if file_path:
+ path_text.value = file_path
+
+ browse_button = widgets.Button(description="Browse")
+ browse_button.on_click(browse_callback)
+
+ def submit_callback(b):
+ update_datasets(dataset_text.value, input_text.value, output_text.value, path_text.value)
+
+ submit_button = widgets.Button(description="Submit")
+ submit_button.on_click(submit_callback)
+
+ display(VBox([dataset_text, input_text, output_text, path_text, browse_button, submit_button]))
diff --git a/models/mhg_model/README.md b/models/mhg_model/README.md
index b855ff28edd655aedc5097cae88fbb812dd06f76..339698f2033bd48e9e66a67c7c8ba6ce5cb9a626 100644
--- a/models/mhg_model/README.md
+++ b/models/mhg_model/README.md
@@ -27,7 +27,7 @@ In addition, the decoder inherits the theoretical guarantee of MHG on always gen
### Pretrained Models and Training Logs
-We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
+We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.mhg-ged/blob/main/mhggnn_pretrained_model_0724_2023.pickle)
Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
diff --git a/models/mhg_model/images/mhg_example.png b/models/mhg_model/images/mhg_example.png
index 3a7dd8ce73476fba75ed242e67147946d99740eb..816da8f712d997ef5f45ae365bd32b2da0ca62f1 100644
Binary files a/models/mhg_model/images/mhg_example.png and b/models/mhg_model/images/mhg_example.png differ
diff --git a/models/mhg_model/images/mhg_example1.png b/models/mhg_model/images/mhg_example1.png
index 150b71f10580655433a6f59a60cbc2afc07d8dc8..089cdde868fc15c8c9dfce84f3bcbbb650901da1 100644
Binary files a/models/mhg_model/images/mhg_example1.png and b/models/mhg_model/images/mhg_example1.png differ
diff --git a/models/mhg_model/images/mhg_example2.png b/models/mhg_model/images/mhg_example2.png
index b00f97a7fb3bec25c0e6e42990d18aaa216eff2d..87c8ebad807ef7dff641d217a0997ae47ca24ed5 100644
Binary files a/models/mhg_model/images/mhg_example2.png and b/models/mhg_model/images/mhg_example2.png differ
diff --git a/models/mhg_model/load.py b/models/mhg_model/load.py
index 09c8042cdf58b3f657d056955c450d4962a0fe52..322b43ce2e683a8d977dba6412c020553b469838 100644
--- a/models/mhg_model/load.py
+++ b/models/mhg_model/load.py
@@ -17,6 +17,7 @@ from typing_extensions import Self
from .graph_grammar.io.smi import hg_to_mol
from .models.mhgvae import GrammarGINVAE
+
from huggingface_hub import hf_hub_download
@@ -73,12 +74,30 @@ class PretrainedModelWrapper:
return output
-def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
+def load(model_name: str = "mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
PretrainedModelWrapper]:
+
repo_id = "ibm/materials.mhg-ged"
- filename = "mhggnn_pretrained_model_0724_2023.pickle"
+ filename = "pytorch_model.bin" #"mhggnn_pretrained_model_0724_2023.pickle"
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
with open(file_path, "rb") as f:
- model_dict = pickle.load(f)
+ model_dict = torch.load(f)
return PretrainedModelWrapper(model_dict)
+
+
+ """try:
+ if os.path.isfile(model_name):
+ with open(model_name, "rb") as f:
+ model_dict = pickle.load(f)
+ print("MHG Model Loaded")
+ return PretrainedModelWrapper(model_dict)
+
+ except:
+
+ for p in sys.path:
+ file = p + "/" + model_name
+ if os.path.isfile(file):
+ with open(file, "rb") as f:
+ model_dict = pickle.load(f)
+ return PretrainedModelWrapper(model_dict)"""
return None
diff --git a/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf b/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf
index a7dcc1270d1f444f77366013ad2d3d93ebb426ab..4bf1999e79d46e23da49a337a02dd6f189f4086a 100644
Binary files a/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf and b/models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf differ
diff --git a/models/selfies_model/selfies-ted.png b/models/selfies_model/selfies-ted.png
index a71127e0a6baf0110d8074e63a08ba060c931121..d1c3561c3751785b0507d966a01ea1fe3b859fa3 100644
Binary files a/models/selfies_model/selfies-ted.png and b/models/selfies_model/selfies-ted.png differ
diff --git a/models/selfies_ted/README.md b/models/selfies_ted/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..01f70f739727bf443957cdef04353175e0c4f47f
--- /dev/null
+++ b/models/selfies_ted/README.md
@@ -0,0 +1,87 @@
+---
+license: apache-2.0
+library_name: transformers
+pipeline_tag: feature-extraction
+tags:
+- chemistry
+---
+
+# selfies-ted
+
+selfies-ted is a project for encoding SMILES (Simplified Molecular Input Line Entry System) into SELFIES (SELF-referencing Embedded Strings) and generating embeddings for molecular representations.
+
+![selfies-ted](selfies-ted.png)
+## Model Architecture
+
+Configuration details
+
+Encoder and Decoder FFN dimensions: 256
+Number of attention heads: 4
+Number of encoder and decoder layers: 2
+Total number of hidden layers: 6
+Maximum position embeddings: 128
+Model dimension (d_model): 256
+
+## Pretrained Models and Training Logs
+We provide checkpoints of the selfies-ted model pre-trained on a dataset of molecules curated from PubChem. The pre-trained model shows competitive performance on molecular representation tasks. For model weights: "HuggingFace link".
+
+To install and use the pre-trained model:
+
+Download the selfies_ted_model.pkl file from the "HuggingFace link".
+Add the selfies-ted selfies_ted_model.pkl to the models/ directory. The directory structure should look like the following:
+
+```
+models/
+└── selfies_ted_model.pkl
+```
+
+## Installation
+
+To use this project, you'll need to install the required dependencies. We recommend using a virtual environment:
+
+```bash
+python -m venv venv
+source venv/bin/activate # On Windows use `venv\Scripts\activate`
+```
+
+Install the required dependencies
+
+```
+pip install -r requirements.txt
+```
+
+
+## Usage
+
+### Import
+
+```
+import load
+```
+### Training the Model
+
+To train the model, use the train.py script:
+
+```
+python train.py -f
+```
+
+
+Note: The actual usage may depend on the specific implementation in load.py. Please refer to the source code for detailed functionality.
+
+### Load the model and tokenizer
+```
+load.load("path/to/checkpoint.pkl")
+```
+### Encode SMILES strings
+```
+smiles_list = ["COC", "CCO"]
+```
+```
+embeddings = load.encode(smiles_list)
+```
+
+
+## Example Notebook
+
+Example notebook of this project is `selfies-ted-example.ipynb`.
diff --git a/models/selfies_ted/load.py b/models/selfies_ted/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec37d3df32957e79b29e9cf71ee49d3690f9a23
--- /dev/null
+++ b/models/selfies_ted/load.py
@@ -0,0 +1,92 @@
+import os
+import sys
+import torch
+import selfies as sf # selfies>=2.1.1
+import pickle
+import pandas as pd
+import numpy as np
+from datasets import Dataset
+from rdkit import Chem
+from transformers import AutoTokenizer, AutoModel
+
+
+class SELFIES(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.model = None
+ self.tokenizer = None
+ self.invalid = []
+
+ def get_selfies(self, smiles_list):
+ self.invalid = []
+ spaced_selfies_batch = []
+ for i, smiles in enumerate(smiles_list):
+ try:
+ selfies = sf.encoder(smiles.rstrip())
+ except:
+ try:
+ smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.rstrip()))
+ selfies = sf.encoder(smiles)
+ except:
+ selfies = "[]"
+ self.invalid.append(i)
+
+ spaced_selfies_batch.append(selfies.replace('][', '] ['))
+
+ return spaced_selfies_batch
+
+
+ def get_embedding(self, selfies):
+ encoding = self.tokenizer(selfies["selfies"], return_tensors='pt', max_length=128, truncation=True, padding='max_length')
+ input_ids = encoding['input_ids']
+ attention_mask = encoding['attention_mask']
+ outputs = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
+ model_output = outputs.last_hidden_state
+
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
+ sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+ model_output = sum_embeddings / sum_mask
+
+ del encoding['input_ids']
+ del encoding['attention_mask']
+
+ encoding["embedding"] = model_output
+
+ return encoding
+
+
+ def load(self, checkpoint="bart-2908.pickle"):
+ """
+ inputs :
+ checkpoint (pickle object)
+ """
+
+ self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
+ self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
+
+
+
+
+
+ # TODO: remove `use_gpu` argument in validation pipeline
+ def encode(self, smiles_list=[], use_gpu=False, return_tensor=False):
+ """
+ inputs :
+ checkpoint (pickle object)
+ :return: embedding
+ """
+ selfies = self.get_selfies(smiles_list)
+ selfies_df = pd.DataFrame(selfies,columns=["selfies"])
+ data = Dataset.from_pandas(selfies_df)
+ embedding = data.map(self.get_embedding, batched=True, num_proc=1, batch_size=128)
+ emb = np.asarray(embedding["embedding"].copy())
+
+ for idx in self.invalid:
+ emb[idx] = np.nan
+ print("Cannot encode {0} to selfies and embedding replaced by NaN".format(smiles_list[idx]))
+
+ if return_tensor:
+ return torch.tensor(emb)
+ return pd.DataFrame(emb)
diff --git a/models/selfies_ted/requirements.txt b/models/selfies_ted/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9183360cca79111e2e64fe4b65849e4df75c195f
--- /dev/null
+++ b/models/selfies_ted/requirements.txt
@@ -0,0 +1,12 @@
+torch>=2.1.0
+transformers>=4.38
+numpy>=1.26.1
+datasets>=2.13.1
+evaluate>=0.4.0
+selfies>=2.1.0
+scikit-learn>=1.2.1
+pyarrow>=14.0.1
+requests>=2.31.0
+urllib3>=2.0.7
+aiohttp>=3.9.0
+zipp>=3.17.0
\ No newline at end of file
diff --git a/models/selfies_ted/selfies-ted-example.ipynb b/models/selfies_ted/selfies-ted-example.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..856f98cde351f6fa5b3fcfdebd2c5ad6726fc380
--- /dev/null
+++ b/models/selfies_ted/selfies-ted-example.ipynb
@@ -0,0 +1,136 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "9d9b6eb8-9edb-44bd-9e5a-3a6ea67f5117",
+ "metadata": {},
+ "source": [
+ "### Import library"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "c3ac4418",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from load import SELFIES"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "790061cf-5470-4564-987e-aa2e492337db",
+ "metadata": {},
+ "source": [
+ "### Initialize and load"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "85847f26-e2f4-475a-a88e-41fd9cccfc0f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = SELFIES()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "095e864c",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "model.load(checkpoint=\"bart-2908.pickle\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "55f1a68c-c462-4dee-9139-9befb469f176",
+ "metadata": {},
+ "source": [
+ "### Example to get embeddings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "2357ef0a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b494cbf9878a4f5c8f4093e38fb82fd5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/3 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "smiles_list = [\"CCO\", \"O=C=O\", \"OC(=O)c1ccccc1C(=O)O\"]\n",
+ "embeddings = model.encode(smiles_list)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "3871c513-d0a9-4e70-9c18-3f0b491e07b2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(3, 1024)"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "embeddings.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "289a8795-d6d8-4828-b2b2-b4d4a97a4604",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/models/selfies_ted/selfies-ted.png b/models/selfies_ted/selfies-ted.png
new file mode 100644
index 0000000000000000000000000000000000000000..d1c3561c3751785b0507d966a01ea1fe3b859fa3
--- /dev/null
+++ b/models/selfies_ted/selfies-ted.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1229d74cd9473344d9907f5b8b2ae22694bdd77e94d3ae8f1f8dadacf538ee9e
+size 47631
diff --git a/models/smi_ted/.gitignore b/models/smi_ted/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..594f1b8f3000cc613695e1704763213e5697f4f8
--- /dev/null
+++ b/models/smi_ted/.gitignore
@@ -0,0 +1,18 @@
+# Model weights
+inference/smi_ted_light/smi-ted-Light_40.pt
+
+# pyenv
+.python-version
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# editor files
+.vscode/
+.DS_Store
diff --git a/models/smi_ted/README.md b/models/smi_ted/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a4bd49c6aa2586e33c3baa2cd5caf438d0253756
--- /dev/null
+++ b/models/smi_ted/README.md
@@ -0,0 +1,138 @@
+# SMILES-based Transformer Encoder-Decoder (SMI-TED)
+
+This repository provides PyTorch source code associated with our publication, "A Large Encoder-Decoder Family of Foundation Models for Chemical Language".
+
+**Paper:** [Arxiv Link](https://arxiv.org/abs/2407.20267)
+
+**HuggingFace:** [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted)
+
+For more information contact: eduardo.soares@ibm.com or evital@br.ibm.com.
+
+![ted-smi](images/smi-ted.png)
+
+## Introduction
+
+We present a large encoder-decoder chemical foundation model, SMILES-based Transformer Encoder-Decoder (SMI-TED), pre-trained on a curated dataset of 91 million SMILES samples sourced from PubChem, equivalent to 4 billion molecular tokens. SMI-TED supports various complex tasks, including quantum property prediction, with two main variants ($289M$ and $8 \times 289M$). Our experiments across multiple benchmark datasets demonstrate state-of-the-art performance for various tasks. Model weights are available at: [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted).
+
+## Table of Contents
+
+1. [Getting Started](#getting-started)
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
+ 2. [Replicating Conda Environment](#replicating-conda-environment)
+2. [Pretraining](#pretraining)
+3. [Finetuning](#finetuning)
+4. [Feature Extraction](#feature-extraction)
+5. [Citations](#citations)
+
+## Getting Started
+
+**This code and environment have been tested on Nvidia V100s and Nvidia A100s**
+
+### Pretrained Models and Training Logs
+
+We provide checkpoints of the SMI-TED model pre-trained on a dataset of ~91M molecules curated from PubChem. The pre-trained model shows competitive performance on classification and regression benchmarks from MoleculeNet. For model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted)
+
+Add the SMI-TED `pre-trained weights.pt` to the `inference/` or `finetune/` directory according to your needs. The directory structure should look like the following:
+
+```
+inference/
+├── smi_ted_light
+│ ├── smi_ted_light.pt
+│ ├── bert_vocab_curated.txt
+│ └── load.py
+```
+and/or:
+
+```
+finetune/
+├── smi_ted_light
+│ ├── smi_ted_light.pt
+│ ├── bert_vocab_curated.txt
+│ └── load.py
+```
+
+### Replicating Conda Environment
+
+Follow these steps to replicate our Conda environment and install the necessary libraries:
+
+#### Create and Activate Conda Environment
+
+```
+conda create --name smi-ted-env python=3.10
+conda activate smi-ted-env
+```
+
+#### Install Packages with Conda
+
+```
+conda install pytorch=2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+```
+
+#### Install Packages with Pip
+
+```
+pip install -r requirements.txt
+pip install pytorch-fast-transformers
+```
+
+## Pretraining
+
+For pretraining, we use two strategies: the masked language model method to train the encoder part and an encoder-decoder strategy to refine SMILES reconstruction and improve the generated latent space.
+
+SMI-TED is pre-trained on canonicalized and curated 91M SMILES from PubChem with the following constraints:
+
+- Compounds are filtered to a maximum length of 202 tokens during preprocessing.
+- A 95/5/0 split is used for encoder training, with 5% of the data for decoder pretraining.
+- A 100/0/0 split is also used to train the encoder and decoder directly, enhancing model performance.
+
+The pretraining code provides examples of data processing and model training on a smaller dataset, requiring 8 A100 GPUs.
+
+To pre-train the two variants of the SMI-TED model, run:
+
+```
+bash training/run_model_light_training.sh
+```
+or
+```
+bash training/run_model_large_training.sh
+```
+
+Use `train_model_D.py` to train only the decoder or `train_model_ED.py` to train both the encoder and decoder.
+
+## Finetuning
+
+The finetuning datasets and environment can be found in the [finetune](finetune/) directory. After setting up the environment, you can run a finetuning task with:
+
+```
+bash finetune/smi_ted_light/esol/run_finetune_esol.sh
+```
+
+Finetuning training/checkpointing resources will be available in directories named `checkpoint_`.
+
+## Feature Extraction
+
+The example notebook [smi_ted_encoder_decoder_example.ipynb](notebooks/smi_ted_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks. It also includes examples of classification and regression tasks. For model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.smi-ted)
+
+To load smi-ted, you can simply use:
+
+```python
+model = load_smi_ted(
+ folder='../inference/smi_ted_light',
+ ckpt_filename='smi_ted_light.pt'
+)
+```
+
+To encode SMILES into embeddings, you can use:
+
+```python
+with torch.no_grad():
+ encoded_embeddings = model.encode(df['SMILES'], return_torch=True)
+```
+For decoder, you can use the function, so you can return from embeddings to SMILES strings:
+
+```python
+with torch.no_grad():
+ decoded_smiles = model.decode(encoded_embeddings)
+```
+
+
diff --git a/models/smi_ted/finetune/args.py b/models/smi_ted/finetune/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a698274dafef1671059da9456dbdce404caaafc
--- /dev/null
+++ b/models/smi_ted/finetune/args.py
@@ -0,0 +1,337 @@
+import argparse
+
+
+def get_parser(parser=None):
+ if parser is None:
+ parser = argparse.ArgumentParser()
+
+ # Model
+ # model_arg = parser.add_argument_group('Model')
+ parser.add_argument("--n_head", type=int, default=8, help="GPT number of heads")
+ parser.add_argument("--n_layer", type=int, default=12, help="GPT number of layers")
+ parser.add_argument(
+ "--q_dropout", type=float, default=0.5, help="Encoder layers dropout"
+ )
+ parser.add_argument(
+ "--d_dropout", type=float, default=0.1, help="Decoder layers dropout"
+ )
+ parser.add_argument(
+ "--n_embd", type=int, default=768, help="Latent vector dimensionality"
+ )
+ parser.add_argument(
+ "--fc_h", type=int, default=512, help="Fully connected hidden dimensionality"
+ )
+ parser.add_argument("--n_output", type=int, default=1)
+
+ # Train
+ # train_arg = parser.add_argument_group('Train')
+ parser.add_argument("--n_batch", type=int, default=512, help="Batch size")
+ parser.add_argument(
+ "--unlike_alpha", type=float, default=1.0, help="unlikelihood loss alpha weight"
+ )
+ parser.add_argument(
+ "--from_scratch",
+ action="store_true",
+ default=False,
+ help="train on qm9 from scratch",
+ )
+ parser.add_argument(
+ "--unlikelihood",
+ action="store_true",
+ default=False,
+ help="use unlikelihood loss with gpt pretrain",
+ )
+ parser.add_argument(
+ "--grad_acc",
+ type=int,
+ default=1,
+ help="number of batches to accumulate gradients",
+ )
+ parser.add_argument(
+ "--checkpoint_every",
+ type=int,
+ default=1000,
+ help="save checkpoint every x iterations",
+ )
+ parser.add_argument(
+ "--clip_grad", type=int, default=50, help="Clip gradients to this value"
+ )
+ parser.add_argument(
+ "--lr_start", type=float, default=3 * 1e-4, help="Initial lr value"
+ )
+ parser.add_argument(
+ "--lr_end", type=float, default=3 * 1e-4, help="Maximum lr weight value"
+ )
+ parser.add_argument(
+ "--lr_multiplier", type=int, default=1, help="lr weight multiplier"
+ )
+ parser.add_argument(
+ "--n_last", type=int, default=1000, help="Number of iters to smooth loss calc"
+ )
+ parser.add_argument("--n_jobs", type=int, default=1, help="Number of threads")
+ parser.add_argument(
+ "--accelerator",
+ type=str,
+ default="ddp",
+ help="The accelerator backend to use (previously known as distributed_backend)",
+ )
+ parser.add_argument(
+ "--num_nodes",
+ type=int,
+ default=1,
+ help="number of GPU nodes for distributed training",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help='Device to run: "cpu" or "cuda:"',
+ )
+ parser.add_argument("--seed", type=int, default=12345, help="Seed")
+ parser.add_argument(
+ "--init_params_from",
+ type=str,
+ default="",
+ help="Path to a ckpt used to initialize the parameters if no restart_path is provided",
+ )
+ parser.add_argument(
+ "--train_decoder_every",
+ type=int,
+ default=10,
+ help="Optimize decoder params every n batches",
+ )
+ parser.add_argument(
+ "--lr_decoder", type=float, default=1e-4, help="Learning rate for decoder part"
+ )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ default=-1,
+ help="local_rank for distributed training on gpus",
+ )
+ parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
+ parser.add_argument(
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
+ )
+ parser.add_argument(
+ "--tensorboard_path", default="./runs/deepspeed", help="tensorboard log dir"
+ )
+
+ # common_arg = parser.add_argument_group('Common')
+ parser.add_argument(
+ "--vocab_load", type=str, required=False, help="Where to load the vocab"
+ )
+ parser.add_argument(
+ "--n_samples", type=int, required=False, help="Number of samples to sample"
+ )
+ parser.add_argument(
+ "--gen_save", type=str, required=False, help="Where to save the gen molecules"
+ )
+ parser.add_argument(
+ "--max_len", type=int, default=100, help="Max of length of SMILES"
+ )
+ parser.add_argument(
+ "--train_load", type=str, required=False, help="Where to load the model"
+ )
+ parser.add_argument(
+ "--val_load", type=str, required=False, help="Where to load the model"
+ )
+ parser.add_argument(
+ "--n_workers",
+ type=int,
+ required=False,
+ default=1,
+ help="Where to load the model",
+ )
+ # beam search hyper parameters
+ parser.add_argument(
+ "--beam_size", type=int, default=0, help="Number of beams to generate"
+ )
+ parser.add_argument(
+ "--num_seq_returned",
+ type=int,
+ default=0,
+ help="number of beams to be returned (must be <= beam_size",
+ )
+ parser.add_argument(
+ "--min_len", type=int, default=1, help="minimum length to be generated"
+ )
+ parser.add_argument(
+ "--nucleus_thresh", type=float, default=0.9, help="nucleus sampling threshold"
+ )
+ parser.add_argument(
+ "--finetune_path",
+ type=str,
+ default="",
+ help="path to trainer file to continue training",
+ )
+ parser.add_argument(
+ "--restart_path",
+ type=str,
+ default="",
+ help="path to trainer file to continue training",
+ )
+ parser.add_argument(
+ "--data_path", type=str, default="", help="path to pubchem file"
+ )
+ parser.add_argument(
+ "--pretext_size", type=int, default=0, help="number of k-mers to pretext"
+ )
+ parser.add_argument(
+ "--model_save_dir",
+ type=str,
+ required=False,
+ default="./models_dump/",
+ help="Where to save the models/log/config/vocab",
+ )
+ parser.add_argument(
+ "--model_save",
+ type=str,
+ required=False,
+ default="model.pt",
+ help="Where to save the model",
+ )
+ # parser.add_argument('--save_frequency',
+ # type=int, default=20,
+ # help='How often to save the model')
+ parser.add_argument(
+ "--num_epoch", type=int, default=1, help="number of epochs to train"
+ )
+ # parser.add_argument('--num_iter',
+ # type=int, default=-1,
+ # help='how many itersations per epoch (for unlikelihood tuning)')
+ parser.add_argument(
+ "--log_file", type=str, required=False, help="Where to save the log"
+ )
+ parser.add_argument(
+ "--tb_loc",
+ type=str,
+ required=False,
+ help="Where to save the tensorflow location",
+ )
+ parser.add_argument(
+ "--config_save", type=str, required=False, help="Where to save the config"
+ )
+ parser.add_argument("--vocab_save", type=str, help="Where to save the vocab")
+
+ # resume_arg = parser.add_argument_group('Resume')
+ parser.add_argument(
+ "--debug",
+ default=False,
+ action="store_true",
+ help="do not erase cache at end of program",
+ )
+ parser.add_argument(
+ "--fast_dev_run",
+ default=False,
+ help="This flag runs a “unit test” by running n if set to n (int) else 1 if set to True training and validation batch(es).",
+ )
+ parser.add_argument(
+ "--freeze_model",
+ default=False,
+ action="store_true",
+ help="freeze weights of bert model during fine tuning",
+ )
+ parser.add_argument(
+ "--resume", default=False, action="store_true", help="Resume from a saved model"
+ )
+ parser.add_argument(
+ "--rotate",
+ default=False,
+ action="store_true",
+ help="use rotational relative embedding",
+ )
+ parser.add_argument(
+ "--model_load", type=str, required=False, help="Where to load the model"
+ )
+ parser.add_argument(
+ "--root_dir", type=str, required=False, default=".", help="location of root dir"
+ )
+ parser.add_argument(
+ "--config_load", type=str, required=False, help="Where to load the config"
+ )
+ parser.add_argument(
+ "--gpus", type=int, required=False, default=1, help="number of gpus to use"
+ )
+ # parser.add_argument('--start_epoch',
+ # type=int, required=False, default=0,
+ # help='Where to load the config')
+
+ parser.add_argument(
+ "--model_arch",
+ type=str,
+ required=False,
+ help="used to teack model arch in params",
+ )
+ parser.add_argument(
+ "--eval_every",
+ type=int,
+ default=50000,
+ help="run evaluation every x iterations",
+ )
+ parser.add_argument(
+ "--num_feats",
+ type=int,
+ required=False,
+ default=32,
+ help="number of random reatures for FAVOR+",
+ )
+ parser.add_argument(
+ "--max_epochs", type=int, required=False, default=1, help="max number of epochs"
+ )
+
+ # debug() FINE TUNEING
+ # parser.add_argument('--save_dir', type=str, required=True)
+ parser.add_argument(
+ "--mode", type=str, default="cls", help="type of pooling to use"
+ )
+ parser.add_argument("--dataset_length", type=int, default=None, required=False)
+ parser.add_argument("--num_workers", type=int, default=0, required=False)
+ parser.add_argument("--dropout", type=float, default=0.1, required=False)
+ # parser.add_argument("--dims", type=int, nargs="*", default="", required=False)
+ parser.add_argument(
+ "--smiles_embedding",
+ type=str,
+ default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt",
+ )
+ # parser.add_argument("--train_pct", type=str, required=False, default="95")
+ # parser.add_argument("--aug", type=int, required=True)
+ parser.add_argument("--dataset_name", type=str, required=False, default="sol")
+ parser.add_argument("--measure_name", type=str, required=False, default="measure")
+ # parser.add_argument("--emb_type", type=str, required=True)
+ parser.add_argument("--checkpoints_folder", type=str, required=True)
+ # parser.add_argument("--results_dir", type=str, required=True)
+ # parser.add_argument("--patience_epochs", type=int, required=True)
+ parser.add_argument("--model_path", type=str, default="./smi_ted/")
+ parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
+ parser.add_argument("--restart_filename", type=str, default="")
+ # parser.add_argument('--n_output', type=int, default=1)
+ parser.add_argument("--save_every_epoch", type=int, default=0)
+ parser.add_argument("--save_ckpt", type=int, default=1)
+ parser.add_argument("--start_seed", type=int, default=0)
+ parser.add_argument("--smi_ted_version", type=str, default="v1")
+ parser.add_argument("--train_decoder", type=int, default=1)
+ parser.add_argument("--target_metric", type=str, default="rmse")
+ parser.add_argument("--loss_fn", type=str, default="mae")
+
+ parser.add_argument(
+ "--data_root",
+ type=str,
+ required=False,
+ default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity",
+ )
+ # parser.add_argument("--use_bn", type=int, default=0)
+ parser.add_argument("--use_linear", type=int, default=0)
+
+ parser.add_argument("--lr", type=float, default=0.001)
+ # parser.add_argument("--weight_decay", type=float, default=5e-4)
+ # parser.add_argument("--val_check_interval", type=float, default=1.0)
+ parser.add_argument("--batch_size", type=int, default=64)
+
+ return parser
+
+
+def parse_args():
+ parser = get_parser()
+ args = parser.parse_args()
+ return args
diff --git a/models/smi_ted/finetune/finetune_classification.py b/models/smi_ted/finetune/finetune_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e4636f4b45093947838450c3e1a599fdc3d273
--- /dev/null
+++ b/models/smi_ted/finetune/finetune_classification.py
@@ -0,0 +1,68 @@
+# Deep learning
+import torch
+import torch.nn as nn
+from torch import optim
+from trainers import TrainerClassifier
+from utils import get_optim_groups
+
+# Data
+import pandas as pd
+import numpy as np
+
+# Standard library
+import args
+import os
+
+
+def main(config):
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # load dataset
+ df_train = pd.read_csv(f"{config.data_root}/train.csv")
+ df_valid = pd.read_csv(f"{config.data_root}/valid.csv")
+ df_test = pd.read_csv(f"{config.data_root}/test.csv")
+
+ # load model
+ if config.smi_ted_version == 'v1':
+ from smi_ted_light.load import load_smi_ted
+ elif config.smi_ted_version == 'v2':
+ from smi_ted_large.load import load_smi_ted
+
+ model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False)
+ model.net.apply(model._init_weights)
+ print(model.net)
+
+ lr = config.lr_start*config.lr_multiplier
+ optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder))
+ if config.loss_fn == 'crossentropy':
+ loss_function = nn.CrossEntropyLoss()
+
+ # init trainer
+ trainer = TrainerClassifier(
+ raw_data=(df_train, df_valid, df_test),
+ dataset_name=config.dataset_name,
+ target=config.measure_name,
+ batch_size=config.n_batch,
+ hparams=config,
+ target_metric=config.target_metric,
+ seed=config.start_seed,
+ smi_ted_version=config.smi_ted_version,
+ checkpoints_folder=config.checkpoints_folder,
+ restart_filename=config.restart_filename,
+ device=device,
+ save_every_epoch=bool(config.save_every_epoch),
+ save_ckpt=bool(config.save_ckpt)
+ )
+ trainer.compile(
+ model=model,
+ optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)),
+ loss_fn=loss_function
+ )
+ trainer.fit(max_epochs=config.max_epochs)
+ trainer.evaluate()
+
+
+if __name__ == '__main__':
+ parser = args.get_parser()
+ config = parser.parse_args()
+ main(config)
\ No newline at end of file
diff --git a/models/smi_ted/finetune/finetune_classification_multitask.py b/models/smi_ted/finetune/finetune_classification_multitask.py
new file mode 100644
index 0000000000000000000000000000000000000000..d244f3650db1e76a51e8050be0abdb4a92b168cb
--- /dev/null
+++ b/models/smi_ted/finetune/finetune_classification_multitask.py
@@ -0,0 +1,101 @@
+# Deep learning
+import torch
+import torch.nn as nn
+from torch import optim
+from trainers import TrainerClassifierMultitask
+from utils import get_optim_groups
+
+# Data
+import pandas as pd
+import numpy as np
+
+# Standard library
+import args
+import os
+
+
+def main(config):
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Define Target and Causal Features
+ if config.dataset_name == 'tox21':
+ targets = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
+ 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
+ elif config.dataset_name == 'clintox':
+ targets = ['FDA_APPROVED', 'CT_TOX']
+ elif config.dataset_name == 'sider':
+ targets = [
+ 'Hepatobiliary disorders', 'Metabolism and nutrition disorders',
+ 'Product issues', 'Eye disorders', 'Investigations',
+ 'Musculoskeletal and connective tissue disorders',
+ 'Gastrointestinal disorders', 'Social circumstances',
+ 'Immune system disorders', 'Reproductive system and breast disorders',
+ 'Neoplasms benign, malignant and unspecified (incl cysts and polyps)',
+ 'General disorders and administration site conditions',
+ 'Endocrine disorders', 'Surgical and medical procedures',
+ 'Vascular disorders', 'Blood and lymphatic system disorders',
+ 'Skin and subcutaneous tissue disorders',
+ 'Congenital, familial and genetic disorders', 'Infections and infestations',
+ 'Respiratory, thoracic and mediastinal disorders', 'Psychiatric disorders',
+ 'Renal and urinary disorders',
+ 'Pregnancy, puerperium and perinatal conditions',
+ 'Ear and labyrinth disorders', 'Cardiac disorders',
+ 'Nervous system disorders', 'Injury, poisoning and procedural complications'
+ ]
+ elif config.dataset_name == 'muv':
+ targets = [
+ 'MUV-466', 'MUV-548', 'MUV-600', 'MUV-644', 'MUV-652', 'MUV-689',
+ 'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
+ 'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'
+ ]
+ config.n_output = len(targets)
+
+ # load dataset
+ df_train = pd.read_csv(f"{config.data_root}/train.csv")
+ df_valid = pd.read_csv(f"{config.data_root}/valid.csv")
+ df_test = pd.read_csv(f"{config.data_root}/test.csv")
+
+ # load model
+ if config.smi_ted_version == 'v1':
+ from smi_ted_light.load import load_smi_ted
+ elif config.smi_ted_version == 'v2':
+ from smi_ted_large.load import load_smi_ted
+
+ model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=len(targets), eval=False)
+ model.net.apply(model._init_weights)
+ print(model.net)
+
+ lr = config.lr_start*config.lr_multiplier
+ optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder))
+ if config.loss_fn == 'bceloss':
+ loss_function = nn.BCELoss()
+
+ # init trainer
+ trainer = TrainerClassifierMultitask(
+ raw_data=(df_train, df_valid, df_test),
+ dataset_name=config.dataset_name,
+ target=targets,
+ batch_size=config.n_batch,
+ hparams=config,
+ target_metric=config.target_metric,
+ seed=config.start_seed,
+ smi_ted_version=config.smi_ted_version,
+ checkpoints_folder=config.checkpoints_folder,
+ restart_filename=config.restart_filename,
+ device=device,
+ save_every_epoch=bool(config.save_every_epoch),
+ save_ckpt=bool(config.save_ckpt)
+ )
+ trainer.compile(
+ model=model,
+ optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)),
+ loss_fn=loss_function
+ )
+ trainer.fit(max_epochs=config.max_epochs)
+ trainer.evaluate()
+
+
+if __name__ == '__main__':
+ parser = args.get_parser()
+ config = parser.parse_args()
+ main(config)
\ No newline at end of file
diff --git a/models/smi_ted/finetune/finetune_regression.py b/models/smi_ted/finetune/finetune_regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a05d32baa43fd9eb127733f1da00b146d7a5172
--- /dev/null
+++ b/models/smi_ted/finetune/finetune_regression.py
@@ -0,0 +1,70 @@
+# Deep learning
+import torch
+import torch.nn as nn
+from torch import optim
+from trainers import TrainerRegressor
+from utils import RMSELoss, get_optim_groups
+
+# Data
+import pandas as pd
+import numpy as np
+
+# Standard library
+import args
+import os
+
+
+def main(config):
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # load dataset
+ df_train = pd.read_csv(f"{config.data_root}/train.csv")
+ df_valid = pd.read_csv(f"{config.data_root}/valid.csv")
+ df_test = pd.read_csv(f"{config.data_root}/test.csv")
+
+ # load model
+ if config.smi_ted_version == 'v1':
+ from smi_ted_light.load import load_smi_ted
+ elif config.smi_ted_version == 'v2':
+ from smi_ted_large.load import load_smi_ted
+
+ model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False)
+ model.net.apply(model._init_weights)
+ print(model.net)
+
+ lr = config.lr_start*config.lr_multiplier
+ optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder))
+ if config.loss_fn == 'rmse':
+ loss_function = RMSELoss()
+ elif config.loss_fn == 'mae':
+ loss_function = nn.L1Loss()
+
+ # init trainer
+ trainer = TrainerRegressor(
+ raw_data=(df_train, df_valid, df_test),
+ dataset_name=config.dataset_name,
+ target=config.measure_name,
+ batch_size=config.n_batch,
+ hparams=config,
+ target_metric=config.target_metric,
+ seed=config.start_seed,
+ smi_ted_version=config.smi_ted_version,
+ checkpoints_folder=config.checkpoints_folder,
+ restart_filename=config.restart_filename,
+ device=device,
+ save_every_epoch=bool(config.save_every_epoch),
+ save_ckpt=bool(config.save_ckpt)
+ )
+ trainer.compile(
+ model=model,
+ optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)),
+ loss_fn=loss_function
+ )
+ trainer.fit(max_epochs=config.max_epochs)
+ trainer.evaluate()
+
+
+if __name__ == '__main__':
+ parser = args.get_parser()
+ config = parser.parse_args()
+ main(config)
\ No newline at end of file
diff --git a/models/smi_ted/finetune/moleculenet/bace/test.csv b/models/smi_ted/finetune/moleculenet/bace/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..adc1ccdfbfadfc6b9d6c35079ac028d1a748499d
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/bace/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3af97c680375dd09349c63b4779b35166212302e79e4fc7a1752ef5d71cf35b
+size 400436
diff --git a/models/smi_ted/finetune/moleculenet/bace/train.csv b/models/smi_ted/finetune/moleculenet/bace/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..017b9a6079d76ea84dd61b119ffbc374d765cc09
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/bace/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5b3426e84dc7e2f40f2cf9d15d4d38328126c07f49c215cfb4fb657f69200de
+size 3109699
diff --git a/models/smi_ted/finetune/moleculenet/bace/valid.csv b/models/smi_ted/finetune/moleculenet/bace/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..11c8f7fbcbe27d30244a8f8d31dd84f35a270e88
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/bace/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:813c8f2af5a1058568cf60b7021b8b2cd818a17944afd0b09f9d838e36ee985d
+size 397085
diff --git a/models/smi_ted/finetune/moleculenet/bbbp/test.csv b/models/smi_ted/finetune/moleculenet/bbbp/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..21089037ffa94aca5db3083f16b887b79bd74212
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/bbbp/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cca4161c44535fd0f8ff917cc68d26703da7fbce19ddecb7dc5f7ae4b4d241a6
+size 14874
diff --git a/models/smi_ted/finetune/moleculenet/bbbp/train.csv b/models/smi_ted/finetune/moleculenet/bbbp/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..314cc5ea086cecbd3d7c0ab9fb96371619aca018
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/bbbp/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7300807bf21ea1177efd81c218e43275ed00b6c3006b5dae7625f774edb6b1a6
+size 115549
diff --git a/models/smi_ted/finetune/moleculenet/bbbp/valid.csv b/models/smi_ted/finetune/moleculenet/bbbp/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..0255eb26d4d514fd9446bf356938161e3e5d7378
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/bbbp/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af39cc3735a356010a072e1e196a64eca6e0d88f0b2a023d4dc1adba7030ce40
+size 15655
diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/biodeg_example.csv b/models/smi_ted/finetune/moleculenet/biodegradability/biodeg_example.csv
new file mode 100644
index 0000000000000000000000000000000000000000..af1df8f88f1194796428d43b11b8c8442feeac15
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/biodegradability/biodeg_example.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c98992c1c22ae7468a41fb7bc86c775ccc30fa29e50053bb148ffc2f2d95551e
+size 6352
diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/biodegradability.csv b/models/smi_ted/finetune/moleculenet/biodegradability/biodegradability.csv
new file mode 100644
index 0000000000000000000000000000000000000000..667fbf1e87eb9753bffe53851749bf0c0accf8e6
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/biodegradability/biodegradability.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ec61887444a0e8925b16cca48433c3b3bff1ac5cf08f448d6b64bbdbc14a318
+size 416181
diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/test.csv b/models/smi_ted/finetune/moleculenet/biodegradability/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..3f89d1cf7d041cc4048d95328f4135f03a98d4e1
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/biodegradability/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86c2f7f39add0fff77358454c0f1b289a233e4a78d50b7f005ec2dc1c632d473
+size 84488
diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/train.csv b/models/smi_ted/finetune/moleculenet/biodegradability/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..92b5d108ee4ef6abe93e0deec05f9b6bac50bbbd
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/biodegradability/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a4a94ae0f8c134ce10f2d853eced84d031a4e7b394662344a9141e7567b3eb2
+size 252230
diff --git a/models/smi_ted/finetune/moleculenet/biodegradability/valid.csv b/models/smi_ted/finetune/moleculenet/biodegradability/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..301cf2278dee811b8afdaf79f771c650af2b4dba
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/biodegradability/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09e827ee7e55544f5b327d5e2ef2d9fe09e3f62024e1316b6e71d1fc9be275a1
+size 85290
diff --git a/models/smi_ted/finetune/moleculenet/clintox/test.csv b/models/smi_ted/finetune/moleculenet/clintox/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..58483d5b0478e7cbabb603c70177ab8d1ac0157a
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/clintox/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:963a05e8eeaaa38fd3688f448dfc28cd0917ea280b1b9cb5b4297244f7f68fe2
+size 10219
diff --git a/models/smi_ted/finetune/moleculenet/clintox/train.csv b/models/smi_ted/finetune/moleculenet/clintox/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..ff58b106a4761c714aae3c31c11ea210e1534d5b
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/clintox/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04bbee4a0d7fb4942292c9581f318909d06508d529a4a3a76590e6749417c1a7
+size 74357
diff --git a/models/smi_ted/finetune/moleculenet/clintox/valid.csv b/models/smi_ted/finetune/moleculenet/clintox/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..efface840e8b99c44772e26ca67fac655d0e5a8d
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/clintox/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3e2b9ab566ffc184c0590002bfbd6a42e6522209e6d6271968262844dde2905
+size 10255
diff --git a/models/smi_ted/finetune/moleculenet/esol/test.csv b/models/smi_ted/finetune/moleculenet/esol/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..835d35a8d39a5c355db39ae890b81598e5b3bc7b
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/esol/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7da41a7eab447fdfd163292b4a5eb8ef09a747fc82b0f1cc5c468e46b1b2ef5a
+size 9999
diff --git a/models/smi_ted/finetune/moleculenet/esol/train.csv b/models/smi_ted/finetune/moleculenet/esol/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..4c3e49f99bd860f679a8d5006776f44051c2528d
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/esol/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:784ba31de05a43ecab98260c94a47e2c807f4d65c0f93d9a88fbd962515976c5
+size 77154
diff --git a/models/smi_ted/finetune/moleculenet/esol/valid.csv b/models/smi_ted/finetune/moleculenet/esol/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..6aa8495439476ef54e8bb536e7943c697b08f907
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/esol/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc30e7fa1f774e27ed56de7cfd77e21f07a5a2c38fcc6d928c0084a9a99181e5
+size 9892
diff --git a/models/smi_ted/finetune/moleculenet/freesolv/test.csv b/models/smi_ted/finetune/moleculenet/freesolv/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..08c43b5d60425f5e337889df1a07a197052301b4
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/freesolv/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8212c391ccbff3722a11d1bd3752b3a9dd187f2a7b33f8b9d2d594950b188d7
+size 3223
diff --git a/models/smi_ted/finetune/moleculenet/freesolv/train.csv b/models/smi_ted/finetune/moleculenet/freesolv/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..0baf09f23bc16c90fac16b6e45714122c2af568f
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/freesolv/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3b781e5d03dbd7d272347288161f92e8e66c628da50e3e2bc06de12225de22d
+size 25053
diff --git a/models/smi_ted/finetune/moleculenet/freesolv/valid.csv b/models/smi_ted/finetune/moleculenet/freesolv/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..6b384a7841a93e57505088bf2dd643aaba76b091
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/freesolv/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b35d9c13a02291eefe85bd4b048ccc28f5326a3b018beb937aba12067b072d2
+size 3151
diff --git a/models/smi_ted/finetune/moleculenet/hiv/test.csv b/models/smi_ted/finetune/moleculenet/hiv/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..26c91f8f64c0df5fd65a6dc8a3e19990cf0feae6
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/hiv/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e86ca708a331966f6e7b06621a2e221a9f6ce45f0141e6cbe919fd64ec50fc7
+size 213176
diff --git a/models/smi_ted/finetune/moleculenet/hiv/train.csv b/models/smi_ted/finetune/moleculenet/hiv/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..8a61257627b0d1088bb89c2d8c7c75d5c7cd27da
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/hiv/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c289700d093d7ccbe55a583ad5cb3a670df931a19283ea66880413ed398358ff
+size 1685863
diff --git a/models/smi_ted/finetune/moleculenet/hiv/valid.csv b/models/smi_ted/finetune/moleculenet/hiv/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..ff00c5124f70c15ba58ef905521c02e3bfbc8295
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/hiv/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33dd9321f709fb4fbc4545b1bfdc641eaebc410f6f698b9ed331678c5b3c3514
+size 212529
diff --git a/models/smi_ted/finetune/moleculenet/ldtoxdb/test.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..222b1a5641c3d943aeefa1978086979ce88b5e25
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0947b182a1ba6b783fdca9fd01146cbe1e7bdf28d535e75765fda11a6b9a7458
+size 1541270
diff --git a/models/smi_ted/finetune/moleculenet/ldtoxdb/toxicity-prediction.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/toxicity-prediction.csv
new file mode 100644
index 0000000000000000000000000000000000000000..db36447e717b0c1889742e89a7459f149532ea1c
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/toxicity-prediction.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afeaf75aebdb67f18aeab58646ef0a31ae3b2c73f3d621afe3b648ba85990210
+size 7843582
diff --git a/models/smi_ted/finetune/moleculenet/ldtoxdb/train.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..862b852596832a991b593753ab21a19684df13ca
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ee6370ec81777620a59316995e15f259c93bb52511d43756db1bb744d453485
+size 4587490
diff --git a/models/smi_ted/finetune/moleculenet/ldtoxdb/valid.csv b/models/smi_ted/finetune/moleculenet/ldtoxdb/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..10e7d231ddb74b3c42ef3ef604282eb644bc33b7
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/ldtoxdb/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bde2141626dcc6e4d6e5cf78242eca4c1335724b15a297c86ce2ad36fbaf4c4c
+size 1525896
diff --git a/models/smi_ted/finetune/moleculenet/lipophilicity/test.csv b/models/smi_ted/finetune/moleculenet/lipophilicity/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..7313b2bd6d643e0e71c3eae315198e8613f26a01
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/lipophilicity/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82a4f29bc409667a655ea3a7cddcdf74d8066150b15ae5074319ad6747bccfff
+size 28696
diff --git a/models/smi_ted/finetune/moleculenet/lipophilicity/train.csv b/models/smi_ted/finetune/moleculenet/lipophilicity/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..b749481478b42a22eae4fcc2cdc0e09bfa4ddbc5
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/lipophilicity/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebd15220de24d82242b6a0b4ddbd985c9728b8e4797dcf20500655cb17338f36
+size 228704
diff --git a/models/smi_ted/finetune/moleculenet/lipophilicity/valid.csv b/models/smi_ted/finetune/moleculenet/lipophilicity/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..281520d34cf818d3507078f9004e866fe8f7cbf5
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/lipophilicity/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79a7682a54e0c37072dc43a5787cd77a40047011e04be204f1b961501be41613
+size 28318
diff --git a/models/smi_ted/finetune/moleculenet/muv/test.csv b/models/smi_ted/finetune/moleculenet/muv/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..c09b321e7bb7809aea99c1db1a1611dc7487ce55
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/muv/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5167910e2919d94164b3250266f846ef468aea7be1dea43698d08fa91da4933a
+size 721037
diff --git a/models/smi_ted/finetune/moleculenet/muv/train.csv b/models/smi_ted/finetune/moleculenet/muv/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..55db4ad896a618d67ba1044d3d97a42a17570496
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/muv/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdc1c542f08aef281fb6c4b8727a92d1f8bfe94e4a370b9240dde03cc866cead
+size 5781901
diff --git a/models/smi_ted/finetune/moleculenet/muv/valid.csv b/models/smi_ted/finetune/moleculenet/muv/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..a605a98f25564f1d7dbf26936b21d57892e68113
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/muv/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0743918c3610c4bceada0f529f7ffac7a196c50b355d92da3e06bbb9dac56ffe
+size 723580
diff --git a/models/smi_ted/finetune/moleculenet/qm8/qm8.csv b/models/smi_ted/finetune/moleculenet/qm8/qm8.csv
new file mode 100644
index 0000000000000000000000000000000000000000..1b81b402728c3e075c1ba6bdee6734df4085a7ae
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm8/qm8.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e70e56805edb2f0796b64f96af9b53dd8cca408d775612d47123f7d2da7d61d
+size 4719270
diff --git a/models/smi_ted/finetune/moleculenet/qm8/test.csv b/models/smi_ted/finetune/moleculenet/qm8/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..8c6ad38a79f1716a7171c0822d6fca8d5bf6e2c8
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm8/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:418fd6afa3db6a219050bd36df002620bce787a2890debab5e63b0829c879914
+size 471657
diff --git a/models/smi_ted/finetune/moleculenet/qm8/train.csv b/models/smi_ted/finetune/moleculenet/qm8/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..5c21bbb766e815373251d74888daaa61652ce414
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm8/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:079af359ffdcba408646a556ea862ded8f483381af320e36d4981dcbe28b849b
+size 3770636
diff --git a/models/smi_ted/finetune/moleculenet/qm8/valid.csv b/models/smi_ted/finetune/moleculenet/qm8/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..2c061eeadadccf63bc74614ab5bd72917afecb2b
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm8/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e032fee62d0117743a14095fffe07223b6f8c514a1961298388a6c6bd272fd5
+size 470821
diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9.csv
new file mode 100644
index 0000000000000000000000000000000000000000..0c1cf60535385093f5edd4c17a8605429cce67d6
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm9/qm9.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e668f8c34e4bc392a90d417a50a5eed3b64b842a817a633024bdc054c68ccb4
+size 29856825
diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9_small_test.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..d835a0e30540d12c3be00088b45df27ac382cf99
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56872786be70d65eb43e943417a293b5c64efb015da6e1cfa3cdd6bc06d8a057
+size 7255
diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9_small_train.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..d835a0e30540d12c3be00088b45df27ac382cf99
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56872786be70d65eb43e943417a293b5c64efb015da6e1cfa3cdd6bc06d8a057
+size 7255
diff --git a/models/smi_ted/finetune/moleculenet/qm9/qm9_small_valid.csv b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..d835a0e30540d12c3be00088b45df27ac382cf99
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm9/qm9_small_valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56872786be70d65eb43e943417a293b5c64efb015da6e1cfa3cdd6bc06d8a057
+size 7255
diff --git a/models/smi_ted/finetune/moleculenet/qm9/test.csv b/models/smi_ted/finetune/moleculenet/qm9/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..21f1f45d52b056bbc1f18dab228f8528b2e324cc
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm9/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:352e17f5061840e0cdcffdc2e86d5c483ac5aa31a8e8feb1916825247e0ad323
+size 2986085
diff --git a/models/smi_ted/finetune/moleculenet/qm9/train.csv b/models/smi_ted/finetune/moleculenet/qm9/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..a04e7e9390a0040d697ecedc1b6cd54f58b166f1
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm9/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f0d2d1faa91c040ba845dbf8375ab1351d14d292b7840f13675afe50658a2ed
+size 24186523
diff --git a/models/smi_ted/finetune/moleculenet/qm9/valid.csv b/models/smi_ted/finetune/moleculenet/qm9/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..0b68edafa00a14e066d5dee89cee71f82d409a9d
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/qm9/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f701d74d3e19b1bf2daeca0043a8a4403e1ba794831682509d45fa54a54587d1
+size 2687631
diff --git a/models/smi_ted/finetune/moleculenet/sider/test.csv b/models/smi_ted/finetune/moleculenet/sider/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..0e586ffe3cb5e7ecaa62067328195eb33954c1c1
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/sider/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03f44a09ac140628293a36e4ac6d23a719058b9956cfb07f5db7923e527e187f
+size 18568
diff --git a/models/smi_ted/finetune/moleculenet/sider/train.csv b/models/smi_ted/finetune/moleculenet/sider/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..9e89e340142dde4ca670f2e7c760b26081d4a0c9
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/sider/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c71e635cccf17c173fedfca7d91fa35c66c4d95f1c558d92e67c1652b831fb75
+size 147151
diff --git a/models/smi_ted/finetune/moleculenet/sider/valid.csv b/models/smi_ted/finetune/moleculenet/sider/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..bb6f83f3067d535eea3c6cbe1b5fab4818747e8b
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/sider/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf790d7e965f90710e45c4f637352d1349c4fa420c9c3cb8e3bab4c86b38755c
+size 19691
diff --git a/models/smi_ted/finetune/moleculenet/tox21/test.csv b/models/smi_ted/finetune/moleculenet/tox21/test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..105b35dc005c562395f8a80803aa759f88b37d70
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/tox21/test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06e9af48940e5eba55ad20229ba3d0f2c1c5110007aa16790cc86df9b0e5de14
+size 53905
diff --git a/models/smi_ted/finetune/moleculenet/tox21/tox21.csv b/models/smi_ted/finetune/moleculenet/tox21/tox21.csv
new file mode 100644
index 0000000000000000000000000000000000000000..a31a8e23869b39932128dba54c716abca10b47a9
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/tox21/tox21.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1689278aa402ef8da840be1126c547253338d390cc8526714910a3b2a39fa1c9
+size 536070
diff --git a/models/smi_ted/finetune/moleculenet/tox21/train.csv b/models/smi_ted/finetune/moleculenet/tox21/train.csv
new file mode 100644
index 0000000000000000000000000000000000000000..4f9ae089aafddb9bea9701c03114ca64ca722d10
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/tox21/train.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f5c366c2d80a6982fd95bdb37f41631c43a158b7d165e201b74ce8fe68c0a03
+size 416358
diff --git a/models/smi_ted/finetune/moleculenet/tox21/valid.csv b/models/smi_ted/finetune/moleculenet/tox21/valid.csv
new file mode 100644
index 0000000000000000000000000000000000000000..922ee60b8888e464173ac9ede8f4fc6a15e16083
--- /dev/null
+++ b/models/smi_ted/finetune/moleculenet/tox21/valid.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d1e23f3b582e66fdc74c7b89757bd832208f7183c3c5fcbabf6d45e321ffed7
+size 55019
diff --git a/models/smi_ted/finetune/smi_ted_large/bace/run_finetune_bace.sh b/models/smi_ted/finetune/smi_ted_large/bace/run_finetune_bace.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8d7a5cfe4b6f9af9664e99b48bfa117d347ad385
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/bace/run_finetune_bace.sh
@@ -0,0 +1,25 @@
+python ../../finetune_classification.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/bace' \
+ --dataset_name bace \
+ --measure_name 'Class' \
+ --checkpoints_folder './checkpoints_bace' \
+ --loss_fn 'crossentropy' \
+ --target_metric 'roc-auc' \
+ --n_output 2 \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/bbbp/run_finetune_bbbp.sh b/models/smi_ted/finetune/smi_ted_large/bbbp/run_finetune_bbbp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bc1d0b7d09eb2ad005700dc2c9d55c49bd714e1a
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/bbbp/run_finetune_bbbp.sh
@@ -0,0 +1,25 @@
+python ../../finetune_classification.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/bbbp' \
+ --dataset_name bbbp \
+ --measure_name 'p_np' \
+ --checkpoints_folder './checkpoints_bbbp' \
+ --loss_fn 'crossentropy' \
+ --target_metric 'roc-auc' \
+ --n_output 2 \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/bert_vocab_curated.txt b/models/smi_ted/finetune/smi_ted_large/bert_vocab_curated.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/bert_vocab_curated.txt
@@ -0,0 +1,2393 @@
+
+
+
+
+C
+c
+(
+)
+1
+O
+N
+2
+=
+n
+3
+[C@H]
+[C@@H]
+F
+S
+4
+Cl
+-
+o
+s
+[nH]
+#
+/
+Br
+[C@]
+[C@@]
+[N+]
+[O-]
+5
+\
+.
+I
+6
+[S@]
+[S@@]
+P
+[N-]
+[Si]
+7
+[n+]
+[2H]
+8
+[NH+]
+B
+9
+[C-]
+[Na+]
+[Cl-]
+[c-]
+[CH]
+%10
+[NH2+]
+[P+]
+[B]
+[I-]
+%11
+[CH2-]
+[O+]
+[NH3+]
+[C]
+[Br-]
+[IH2]
+[S-]
+[cH-]
+%12
+[nH+]
+[B-]
+[K+]
+[Sn]
+[Se]
+[CH-]
+[HH]
+[Y]
+[n-]
+[CH3-]
+[SiH]
+[S+]
+%13
+[SiH2]
+[Li+]
+[NH-]
+%14
+[Na]
+[CH2]
+[O-2]
+[U+2]
+[W]
+[Al]
+[P@]
+[Fe+2]
+[PH+]
+%15
+[Cl+3]
+[Zn+2]
+[Ir]
+[Mg+2]
+[Pt+2]
+[OH2+]
+[As]
+[Fe]
+[OH+]
+[Zr+2]
+[3H]
+[Ge]
+[SiH3]
+[OH-]
+[NH4+]
+[Cu+2]
+[P@@]
+p
+[Pt]
+%16
+[Ca+2]
+[Zr]
+[F-]
+[C+]
+[Ti]
+[P-]
+[V]
+[se]
+[U]
+[O]
+[Ni+2]
+[Zn]
+[Co]
+[Ni]
+[Pd+2]
+[Cu]
+%17
+[Cu+]
+[Te]
+[H+]
+[CH+]
+[Li]
+[Pd]
+[Mo]
+[Ru+2]
+[o+]
+[Re]
+[SH+]
+%18
+[Ac]
+[Cr]
+[NH2-]
+[K]
+[13CH2]
+[c]
+[Zr+4]
+[Tl]
+[13C]
+[Mn]
+[N@+]
+[Hg]
+[Rh]
+[Ti+4]
+[Sb]
+[Co+2]
+[Ag+]
+[Ru]
+%19
+[N@@+]
+[Ti+2]
+[Al+3]
+[Pb]
+[I+]
+[18F]
+[s+]
+[Rb+]
+[Ba+2]
+[H-]
+[Fe+3]
+[Ir+3]
+[13cH]
+%20
+[AlH2]
+[Au+]
+[13c]
+[SH2+]
+[Sn+2]
+[Mn+2]
+[Si-]
+[Ag]
+[N]
+[Bi]
+%21
+[In]
+[CH2+]
+[Y+3]
+[Ga]
+%22
+[Co+3]
+[Au]
+[13CH3]
+[Mg]
+[Cs+]
+[W+2]
+[Hf]
+[Zn+]
+[Se-]
+[S-2]
+[Ca]
+[pH]
+[ClH+]
+[Ti+3]
+%23
+[Ru+]
+[SH-]
+[13CH]
+[IH+]
+[Hf+4]
+[Rf]
+[OH3+]
+%24
+[Pt+4]
+[Zr+3]
+[PH3+]
+[Sr+2]
+[Cd+2]
+[Cd]
+%25
+[Os]
+[BH-]
+[Sn+4]
+[Cr+3]
+[Ru+3]
+[PH2+]
+[Rh+2]
+[V+2]
+%26
+[Gd+3]
+[Pb+2]
+[PH]
+[Hg+]
+[Mo+2]
+[AlH]
+[Sn+]
+%27
+[Pd+]
+b
+[Rh+3]
+[Hg+2]
+[15NH]
+[14C]
+%28
+[Mn+3]
+[Si+]
+[SeH]
+[13C@H]
+[NH]
+[Ga+3]
+[SiH-]
+[13C@@H]
+[Ce]
+[Au+3]
+[Bi+3]
+[15N]
+%29
+[BH3-]
+[14cH]
+[Ti+]
+[Gd]
+[cH+]
+[Cr+2]
+[Sb-]
+%30
+[Be+2]
+[Al+]
+[te]
+[11CH3]
+[Sm]
+[Pr]
+[La]
+%31
+[Al-]
+[Ta]
+[125I]
+[BH2-]
+[Nb]
+[Si@]
+%32
+[14c]
+[Sb+3]
+[Ba]
+%33
+[Os+2]
+[Si@@]
+[La+3]
+[15n]
+[15NH2]
+[Nd+3]
+%34
+[14CH2]
+[18O]
+[Nd]
+[GeH]
+[Ni+3]
+[Eu]
+[Dy+3]
+[Sc]
+%36
+[Se-2]
+[As+]
+%35
+[AsH]
+[Tb]
+[Sb+5]
+[Se+]
+[Ce+3]
+[c+]
+[In+3]
+[SnH]
+[Mo+4]
+%37
+[V+4]
+[Eu+3]
+[Hf+2]
+%38
+[Pt+]
+[p+]
+[123I]
+[Tl+]
+[Sm+3]
+%39
+[Yb+3]
+%40
+[Yb]
+[Os+]
+%41
+[10B]
+[Sc+3]
+[Al+2]
+%42
+[Sr]
+[Tb+3]
+[Po]
+[Tc]
+[PH-]
+[AlH3]
+[Ar]
+[U+4]
+[SnH2]
+[Cl+2]
+[si]
+[Fe+]
+[14CH3]
+[U+3]
+[Cl+]
+%43
+[GeH2]
+%44
+[Er+3]
+[Mo+3]
+[I+2]
+[Fe+4]
+[99Tc]
+%45
+[11C]
+%46
+[SnH3]
+[S]
+[Te+]
+[Er]
+[Lu+3]
+[11B]
+%47
+%48
+[P]
+[Tm]
+[Th]
+[Dy]
+[Pr+3]
+[Ta+5]
+[Nb+5]
+[Rb]
+[GeH3]
+[Br+2]
+%49
+[131I]
+[Fm]
+[Cs]
+[BH4-]
+[Lu]
+[15nH]
+%50
+[Ru+6]
+[b-]
+[Ho]
+[Th+4]
+[Ru+4]
+%52
+[14CH]
+%51
+[Cr+6]
+[18OH]
+[Ho+3]
+[Ce+4]
+[Bi+2]
+[Co+]
+%53
+[Yb+2]
+[Fe+6]
+[Be]
+%54
+[SH3+]
+[Np]
+[As-]
+%55
+[14C@@H]
+[Ir+2]
+[GaH3]
+[p-]
+[GeH4]
+[Sn+3]
+[Os+4]
+%56
+[14C@H]
+[sH+]
+[19F]
+[Eu+2]
+[TlH]
+%57
+[Cr+4]
+%58
+[B@@-]
+[SiH+]
+[At]
+[Am]
+[Fe+5]
+[AsH2]
+[Si+4]
+[B@-]
+[Pu]
+[SbH]
+[P-2]
+[Tm+3]
+*
+%59
+[se+]
+[IH-]
+%60
+[oH+]
+[1H]
+[15N+]
+[124I]
+[S@@+]
+[P-3]
+[H]
+[IH2+]
+[TeH]
+[Xe]
+[PH4+]
+[Cr+]
+[Cm]
+[I+3]
+%61
+[Nb+2]
+[Ru+5]
+%62
+[Ta+2]
+[Tc+4]
+[CH3+]
+[Pm]
+[Si@H]
+[No]
+%63
+[Cr+5]
+[Th+2]
+[Zn-2]
+[13C@]
+[Lr]
+%64
+[99Tc+3]
+%65
+[13C@@]
+%66
+[Fe-]
+[17O]
+[siH]
+[Sb+]
+[OH]
+[IH]
+[11CH2]
+[Cf]
+[SiH2+]
+[Gd+2]
+[In+]
+[Si@@H]
+[Mn+]
+[99Tc+4]
+[Ga-]
+%67
+[S@+]
+[Ge+4]
+[Tl+3]
+[16OH]
+%68
+[2H-]
+[Ra]
+[si-]
+[NiH2]
+[P@@H]
+[Rh+]
+[12C]
+[35S]
+[32P]
+[SiH2-]
+[AlH2+]
+[16O]
+%69
+[BiH]
+[BiH2]
+[Zn-]
+[BH]
+[Tc+3]
+[Ir+]
+[Ni+]
+%70
+[InH2]
+[InH]
+[Nb+3]
+[PbH]
+[Bi+]
+%71
+[As+3]
+%72
+[18O-]
+[68Ga+3]
+%73
+[Pa]
+[76Br]
+[Tc+5]
+[pH+]
+[64Cu+2]
+[Ru+8]
+%74
+[PH2-]
+[Si+2]
+[17OH]
+[RuH]
+[111In+3]
+[AlH+]
+%75
+%76
+[W+]
+[SbH2]
+[PoH]
+[Ru-]
+[XeH]
+[Tc+2]
+[13C-]
+[Br+]
+[Pt-2]
+[Es]
+[Cu-]
+[Mg+]
+[3HH]
+[P@H]
+[ClH2+]
+%77
+[SH]
+[Au-]
+[2HH]
+%78
+[Sn-]
+[11CH]
+[PdH2]
+0
+[Os+6]
+%79
+[Mo+]
+%80
+[al]
+[PbH2]
+[64Cu]
+[Cl]
+[12CH3]
+%81
+[Tc+7]
+[11c]
+%82
+[Li-]
+[99Tc+5]
+[He]
+[12c]
+[Kr]
+[RuH+2]
+[35Cl]
+[Pd-2]
+[GaH2]
+[4H]
+[Sg]
+[Cu-2]
+[Br+3]
+%83
+[37Cl]
+[211At]
+[IrH+2]
+[Mt]
+[Ir-2]
+[In-]
+[12cH]
+[12CH2]
+[RuH2]
+[99Tc+7]
+%84
+[15n+]
+[ClH2+2]
+[16N]
+[111In]
+[Tc+]
+[Ru-2]
+[12CH]
+[si+]
+[Tc+6]
+%85
+%86
+[90Y]
+[Pd-]
+[188Re]
+[RuH+]
+[NiH]
+[SiH3-]
+[14n]
+[CH3]
+[14N]
+[10BH2]
+%88
+%89
+%90
+[34S]
+[77Br]
+[GaH]
+[Br]
+[Ge@]
+[B@@H-]
+[CuH]
+[SiH4]
+[3H-]
+%87
+%91
+%92
+[67Cu]
+[I]
+[177Lu]
+[ReH]
+[67Ga+3]
+[Db]
+[177Lu+3]
+[AlH2-]
+[Si+3]
+[Ti-2]
+[RuH+3]
+[al+]
+[68Ga]
+[2H+]
+[B@H-]
+[WH2]
+[OsH]
+[Ir-3]
+[AlH-]
+[Bk]
+[75Se]
+[14C@]
+[Pt-]
+[N@@H+]
+[Nb-]
+[13NH2]
+%93
+[186Re]
+[Tb+4]
+[PtH]
+[IrH2]
+[Hg-2]
+[AlH3-]
+[PdH+]
+[Md]
+[RhH+2]
+[11cH]
+[Co-2]
+[15N-]
+[ZrH2]
+%94
+[Hg-]
+[127I]
+[AsH2+]
+[MoH2]
+[Te+4]
+[14C@@]
+[As+5]
+[SnH+3]
+[Ge@@]
+[6Li+]
+[WH]
+[Ne]
+[14NH2]
+[14NH]
+[12C@@H]
+[Os+7]
+[RhH]
+[Al-3]
+[SnH+]
+[15NH3+]
+[Zr+]
+[197Hg+]
+%95
+%96
+[90Y+3]
+[Os-2]
+[98Tc+5]
+[15NH3]
+[bH-]
+[33P]
+[Zr-2]
+[15O]
+[Rh-]
+[PbH3]
+[PH2]
+[Ni-]
+[CuH+]
+%97
+%98
+%99
+[Os+5]
+[PtH+]
+[ReH4]
+[16NH]
+[82Br]
+[W-]
+[18F-]
+[15NH4+]
+[Se+4]
+[SeH-]
+[SH4]
+[67Cu+2]
+[12C@H]
+[AsH3]
+[HgH]
+[10B-]
+[99Tc+6]
+[117Sn+4]
+[Te@]
+[P@+]
+[35SH]
+[SeH+]
+[Ni-2]
+[Al-2]
+[TeH2]
+[Bh]
+[99Tc+2]
+[Os+8]
+[PH-2]
+[7Li+]
+[14nH]
+[AlH+2]
+[18FH]
+[SnH4]
+[18O-2]
+[IrH]
+[13N]
+[Te@@]
+[Rh-3]
+[15NH+]
+[AsH3+]
+[SeH2]
+[AsH+]
+[CoH2]
+[16NH2]
+[AsH-]
+[203Hg+]
+[P@@+]
+[166Ho+3]
+[60Co+3]
+[13CH2-]
+[SeH2+]
+[75Br]
+[TlH2]
+[80Br]
+[siH+]
+[Ca+]
+[153Sm+3]
+[PdH]
+[225Ac]
+[13CH3-]
+[AlH4-]
+[FeH]
+[13CH-]
+[14C-]
+[11C-]
+[153Sm]
+[Re-]
+[te+]
+[13CH4]
+[ClH+2]
+[8CH2]
+[99Mo]
+[ClH3+3]
+[SbH3]
+[25Mg+2]
+[16N+]
+[SnH2+]
+[PH4]
+[11C@H]
+[122I]
+[Re-2]
+[RuH2+2]
+[ZrH]
+[Bi-]
+[Pr+]
+[Rn]
+[Fr]
+[36Cl]
+[18o]
+[YH]
+[79Br]
+[121I]
+[113In+3]
+[InH4-]
+[TaH]
+[RhH2]
+[Ta-]
+[67Ga]
+[ZnH+]
+[SnH2-]
+[OsH2]
+[16F]
+[FeH2]
+[14O]
+[PbH2+2]
+[BH2]
+[6H]
+[125Te]
+[197Hg]
+[TaH2]
+[TaH3]
+[76As]
+[Nb-2]
+[14N+]
+[125I-]
+[33S]
+[IH2+2]
+[NH2]
+[PtH2]
+[MnH]
+[19C]
+[17F]
+[1H-]
+[SnH4+2]
+[Mn-2]
+[15NH2+]
+[TiH2]
+[ReH7]
+[Cd-2]
+[Fe-3]
+[SH2]
+[17O-]
+[siH-]
+[CoH+]
+[VH]
+[10BH]
+[Ru-3]
+[13O]
+[5H]
+[CoH]
+[PH5]
+[15n-]
+[153Gd]
+[12C@]
+[11CH3-]
+[IrH3]
+[RuH3]
+[74Se]
+[Se@]
+[Hf+]
+[77Se]
+[166Ho]
+[59Fe+2]
+[203Hg]
+[18OH-]
+[8CH]
+[12C@@]
+[11CH4]
+[15C]
+[249Cf]
+[PbH4]
+[64Zn]
+[PH3]
+[99Tc+]
+[14c-]
+[149Pm]
+[IrH4]
+[Se@@]
+[13OH]
+[14CH3-]
+[28Si]
+[Rh-2]
+[Fe-2]
+[131I-]
+[51Cr]
+[62Cu+2]
+[81Br]
+[121Sb]
+[7Li]
+[89Zr+4]
+[SbH3+]
+[11C@@H]
+[98Tc]
+[59Fe+3]
+[BiH2+]
+[SbH+]
+[TiH]
+[14NH3]
+[15OH]
+[119Sn]
+[201Hg]
+[MnH+]
+[201Tl]
+[51Cr+3]
+[123I-]
+[MoH]
+[AlH6-3]
+[MnH2]
+[WH3]
+[213Bi+3]
+[SnH2+2]
+[123IH]
+[13CH+]
+[Zr-]
+[74As]
+[13C+]
+[32P+]
+[KrH]
+[SiH+2]
+[ClH3+2]
+[13NH]
+[9CH2]
+[ZrH2+2]
+[87Sr+2]
+[35s]
+[239Pu]
+[198Au]
+[241Am]
+[203Hg+2]
+[V+]
+[YH2]
+[SH5]
+[195Pt]
+[203Pb]
+[RuH4]
+[ThH2]
+[AuH]
+[66Ga+3]
+[11B-]
+[F]
+[24Na+]
+[85Sr+2]
+[201Tl+]
+[14CH4]
+[32S]
+[TeH2+]
+[ClH2+3]
+[AgH]
+[Ge@H]
+[44Ca+2]
+[Os-]
+[31P]
+[15nH+]
+[SbH4]
+[TiH+]
+[Ba+]
+[57Co+2]
+[Ta+]
+[125IH]
+[77As]
+[129I]
+[Fe-4]
+[Ta-2]
+[19O]
+[12O]
+[BiH3]
+[237Np]
+[252Cf]
+[86Y]
+[Cr-2]
+[89Y]
+[195Pt+2]
+[si+2]
+[58Fe+2]
+[Hs]
+[S@@H]
+[OsH6]
+[GdH2]
+[IH3]
+[8CH4]
+[164Dy+3]
+[47Ca+2]
+[57Co]
+[NbH2]
+[ReH2]
+[ZnH2]
+[CrH2]
+[17NH]
+[ZrH3]
+[RhH3]
+[12C-]
+[18O+]
+[Bi-2]
+[ClH4+3]
+[Ni-3]
+[Ag-]
+[111In-]
+[Mo-2]
+[55Fe+3]
+[204Hg+]
+[35Cl-]
+[211Pb]
+[75Ge]
+[8B]
+[TeH3]
+[SnH3+]
+[Zr-3]
+[28F]
+[249Bk]
+[169Yb]
+[34SH]
+[6Li]
+[94Tc]
+[197Au]
+[195Pt+4]
+[169Yb+3]
+[32Cl]
+[82Se]
+[159Gd+3]
+[213Bi]
+[CoH+2]
+[36S]
+[35P]
+[Ru-4]
+[Cr-3]
+[60Co]
+[1H+]
+[18CH2]
+[Cd-]
+[152Sm+3]
+[106Ru]
+[238Pu]
+[220Rn]
+[45Ca+2]
+[89Sr+2]
+[239Np]
+[90Sr+2]
+[137Cs+]
+[165Dy]
+[68GaH3]
+[65Zn+2]
+[89Zr]
+[BiH2+2]
+[62Cu]
+[165Dy+3]
+[238U]
+[105Rh+3]
+[70Zn]
+[12B]
+[12OH]
+[18CH]
+[17CH]
+[OsH3]
+[SbH-]
+[SH6]
+[AlH2-2]
+[42K]
+[76Br-]
+[71As]
+[NbH3]
+[ReH3]
+[OsH-]
+[WH4]
+[MoH3]
+[OsH4]
+[RuH6]
+[PtH3]
+[CuH2]
+[CoH3]
+[TiH4]
+[64Zn+2]
+[Si-2]
+[79BrH]
+[14CH2-]
+[PtH2+2]
+[Os-3]
+[29Si]
+[Ti-]
+[Se+6]
+[22Na+]
+[42K+]
+[131Cs+]
+[86Rb+]
+[134Cs+]
+[209Po]
+[208Po]
+[81Rb+]
+[203Tl+]
+[Zr-4]
+[148Sm]
+[147Sm]
+[37Cl-]
+[12CH4]
+[Ge@@H]
+[63Cu]
+[13CH2+]
+[AsH2-]
+[CeH]
+[SnH-]
+[UH]
+[9c]
+[21CH3]
+[TeH+]
+[57Co+3]
+[8BH2]
+[12BH2]
+[19BH2]
+[9BH2]
+[YbH2]
+[CrH+2]
+[208Bi]
+[152Gd]
+[61Cu]
+[115In]
+[60Co+2]
+[13NH2-]
+[120I]
+[18OH2]
+[75SeH]
+[SbH2+]
+[144Ce]
+[16n]
+[113In]
+[22nH]
+[129I-]
+[InH3]
+[32PH3]
+[234U]
+[235U]
+[59Fe]
+[82Rb+]
+[65Zn]
+[244Cm]
+[147Pm]
+[91Y]
+[237Pu]
+[231Pa]
+[253Cf]
+[127Te]
+[187Re]
+[236Np]
+[235Np]
+[72Zn]
+[253Es]
+[159Dy]
+[62Zn]
+[101Tc]
+[149Tb]
+[124I-]
+[SeH3+]
+[210Pb]
+[40K]
+[210Po]
+[214Pb]
+[218Po]
+[214Po]
+[7Be]
+[212Pb]
+[205Pb]
+[209Pb]
+[123Te]
+[202Pb]
+[72As]
+[201Pb]
+[70As]
+[73Ge]
+[200Pb]
+[198Pb]
+[66Ga]
+[73Se]
+[195Pb]
+[199Pb]
+[144Ce+3]
+[235U+2]
+[90Tc]
+[114In+3]
+[128I]
+[100Tc+]
+[82Br-]
+[191Pt+2]
+[191Pt+4]
+[193Pt+4]
+[31PH3]
+[125I+2]
+[131I+2]
+[125Te+4]
+[82Sr+2]
+[149Sm]
+[81BrH]
+[129Xe]
+[193Pt+2]
+[123I+2]
+[Cr-]
+[Co-]
+[227Th+4]
+[249Cf+3]
+[252Cf+3]
+[187Os]
+[16O-]
+[17O+]
+[16OH-]
+[98Tc+7]
+[58Co+2]
+[69Ga+3]
+[57Fe+2]
+[43K+]
+[16C]
+[52Fe+3]
+[SeH5]
+[194Pb]
+[196Pb]
+[197Pb]
+[213Pb]
+[9B]
+[19B]
+[11CH-]
+[9CH]
+[20OH]
+[25OH]
+[8cH]
+[TiH+3]
+[SnH6+3]
+[N@H+]
+[ZnH]
+[VH3]
+[52Mn+2]
+[64Ga]
+[13B]
+[216Bi]
+[117Sn+2]
+[232Th]
+[SnH+2]
+[BiH5]
+[77Kr]
+[103Cd]
+[62Ni]
+[LaH3]
+[SmH3]
+[EuH3]
+[MoH5]
+[64Ni]
+[66Zn]
+[68Zn]
+[186W]
+[FeH4]
+[MoH4]
+[HgH2]
+[15NH2-]
+[UH2]
+[204Hg]
+[GaH4-]
+[ThH4]
+[WH6]
+[PtH4]
+[VH2]
+[UH3]
+[FeH3]
+[RuH5]
+[BiH4]
+[80Br-]
+[CeH3]
+[37ClH]
+[157Gd+3]
+[205Tl]
+[203Tl]
+[62Cu+]
+[64Cu+]
+[61Cu+]
+[37SH2]
+[30Si]
+[28Al]
+[19OH2]
+[8He]
+[6He]
+[153Pm]
+[209Bi]
+[66Zn+2]
+[10CH4]
+[191Ir]
+[66Cu]
+[16O+]
+[25O]
+[10c]
+[Co-3]
+[Sn@@]
+[17OH-]
+[206Po]
+[204Po]
+[202Po]
+[201Po]
+[200Po]
+[199Po]
+[198Po]
+[197Po]
+[196Po]
+[195Po]
+[194Po]
+[193Po]
+[192Po]
+[191Po]
+[190Po]
+[217Po]
+[BiH4-]
+[TeH4]
+[222Ra]
+[62Ga]
+[39Ar]
+[144Sm]
+[58Fe]
+[153Eu]
+[85Rb]
+[171Yb]
+[172Yb]
+[114Cd]
+[51Fe]
+[142Ce]
+[207Tl]
+[92Mo]
+[115Sn]
+[140Ce]
+[202Hg]
+[180W]
+[182W]
+[183W]
+[184W]
+[96Mo]
+[47Ti]
+[111Cd]
+[143Nd]
+[145Nd]
+[126Te]
+[128Te]
+[130Te]
+[185Re]
+[97Mo]
+[98Mo]
+[183Re]
+[52V]
+[80Se]
+[87Kr]
+[137Xe]
+[196Au]
+[146Ce]
+[88Kr]
+[51Ti]
+[138Xe]
+[112Cd]
+[116Sn]
+[120Sn]
+[28SiH3]
+[35S-]
+[15NH-]
+[13CH3+]
+[34S+]
+[34s]
+[SiH4-]
+[100Tc+5]
+[NiH2+2]
+[239Th]
+[186Lu]
+[AuH3]
+[I@@-]
+[XeH2]
+[B+]
+[16CH2]
+[8C]
+[TaH5]
+[FeH4-]
+[19C@H]
+[10NH]
+[FeH6-3]
+[22CH]
+[25N]
+[25N+]
+[25N-]
+[21CH2]
+[18cH]
+[113I]
+[ScH3]
+[30PH3]
+[43Ca+2]
+[41Ca+2]
+[106Cd]
+[122Sn]
+[18CH3]
+[58Co+3]
+[98Tc+4]
+[70Ge]
+[76Ge]
+[108Cd]
+[116Cd]
+[130Xe]
+[94Mo]
+[124Sn]
+[186Os]
+[188Os]
+[190Os]
+[192Os]
+[106Pd]
+[110Pd]
+[120Te]
+[132Ba]
+[134Ba]
+[136Ba]
+[136Ce]
+[138Ce]
+[156Dy]
+[158Dy]
+[160Dy]
+[163Dy]
+[162Er]
+[164Er]
+[167Er]
+[176Hf]
+[26Mg]
+[144Nd]
+[150Nd]
+[41K]
+[46Ti]
+[48Ti]
+[49Ti]
+[50Ti]
+[170Yb]
+[173Yb]
+[91Zr]
+[92Zr]
+[96Zr]
+[34S-]
+[CuH2-]
+[38Cl]
+[25Mg]
+[51V]
+[93Nb]
+[95Mo]
+[45Sc]
+[123Sb]
+[139La]
+[9Be]
+[99Y+3]
+[99Y]
+[156Ho]
+[67Zn]
+[144Ce+4]
+[210Tl]
+[42Ca]
+[54Fe]
+[193Ir]
+[92Nb]
+[141Cs]
+[52Cr]
+[35ClH]
+[46Ca]
+[139Cs]
+[65Cu]
+[71Ga]
+[60Ni]
+[16NH3]
+[148Nd]
+[72Ge]
+[161Dy]
+[49Ca]
+[43Ca]
+[8Be]
+[48Ca]
+[44Ca]
+[120Xe]
+[80Rb]
+[215At]
+[180Re]
+[146Sm]
+[19Ne]
+[74Kr]
+[134La]
+[76Kr]
+[219Fr]
+[121Xe]
+[220Fr]
+[216At]
+[223Ac]
+[218At]
+[37Ar]
+[135I]
+[110Cd]
+[94Tc+7]
+[86Y+3]
+[135I-]
+[15O-2]
+[151Eu+3]
+[161Tb+3]
+[197Hg+2]
+[109Cd+2]
+[191Os+4]
+[170Tm+3]
+[205Bi+3]
+[233U+4]
+[126Sb+3]
+[127Sb+3]
+[132Cs+]
+[136Eu+3]
+[136Eu]
+[125Sn+4]
+[175Yb+3]
+[100Mo]
+[22Ne]
+[13c-]
+[13NH4+]
+[17C]
+[9C]
+[31S]
+[31SH]
+[133I]
+[126I]
+[36SH]
+[30S]
+[32SH]
+[19CH2]
+[19c]
+[18c]
+[15F]
+[10C]
+[RuH-]
+[62Zn+2]
+[32ClH]
+[33ClH]
+[78BrH]
+[12Li+]
+[12Li]
+[233Ra]
+[68Ge+4]
+[44Sc+3]
+[91Y+3]
+[106Ru+3]
+[PoH2]
+[AtH]
+[55Fe]
+[233U]
+[210PoH2]
+[230Th]
+[228Th]
+[222Rn]
+[35SH2]
+[227Th]
+[192Ir]
+[133Xe]
+[81Kr]
+[95Zr]
+[240Pu]
+[54Mn]
+[103Ru]
+[95Nb]
+[109Cd]
+[141Ce]
+[85Kr]
+[110Ag]
+[58Co]
+[241Pu]
+[234Th]
+[140La]
+[63Ni]
+[152Eu]
+[132IH]
+[226Rn]
+[154Eu]
+[36ClH]
+[228Ac]
+[155Eu]
+[106Rh]
+[243Am]
+[227Ac]
+[243Cm]
+[236U]
+[144Pr]
+[232U]
+[32SH2]
+[88Y]
+[82BrH]
+[135IH]
+[242Cm]
+[115Cd]
+[242Pu]
+[46Sc]
+[56Mn]
+[234Pa]
+[41Ar]
+[147Nd]
+[187W]
+[151Sm]
+[59Ni]
+[233Pa]
+[52Mn]
+[94Nb]
+[219Rn]
+[236Pu]
+[13NH3]
+[93Zr]
+[51Cr+6]
+[TlH3]
+[123Xe]
+[160Tb]
+[170Tm]
+[182Ta]
+[175Yb]
+[93Mo]
+[143Ce]
+[191Os]
+[126IH]
+[48V]
+[113Cd]
+[47Sc]
+[181Hf]
+[185W]
+[143Pr]
+[191Pt]
+[181W]
+[33PH3]
+[97Ru]
+[97Tc]
+[111Ag]
+[169Er]
+[107Pd]
+[103Ru+2]
+[34SH2]
+[137Ce]
+[242Am]
+[117SnH2]
+[57Ni]
+[239U]
+[60Cu]
+[250Cf]
+[193Au]
+[69Zn]
+[55Co]
+[139Ce]
+[127Xe]
+[159Gd]
+[56Co]
+[177Hf]
+[244Pu]
+[38ClH]
+[142Pr]
+[199Hg]
+[179Hf]
+[178Hf]
+[237U]
+[156Eu]
+[157Eu]
+[105Ru]
+[171Tm]
+[199Au]
+[155Sm]
+[80BrH]
+[108Ag]
+[128IH]
+[48Sc]
+[45Ti]
+[176Lu]
+[121SnH2]
+[148Pm]
+[57Fe]
+[10BH3]
+[96Tc]
+[133IH]
+[143Pm]
+[105Rh]
+[130IH]
+[134IH]
+[131IH]
+[71Zn]
+[105Ag]
+[97Zr]
+[235Pu]
+[231Th]
+[109Pd]
+[93Y]
+[190Ir]
+[135Xe]
+[53Mn]
+[134Ce]
+[234Np]
+[240Am]
+[246Cf]
+[240Cm]
+[241Cm]
+[226Th]
+[39ClH]
+[229Th]
+[245Cm]
+[240U]
+[240Np]
+[249Cm]
+[243Pu]
+[145Pm]
+[199Pt]
+[246Bk]
+[193Pt]
+[230U]
+[250Cm]
+[44Ti]
+[175Hf]
+[254Fm]
+[255Fm]
+[257Fm]
+[92Y]
+[188Ir]
+[171Lu]
+[257Md]
+[247Bk]
+[121IH]
+[250Bk]
+[179Lu]
+[224Ac]
+[195Hg]
+[244Am]
+[246Pu]
+[194Au]
+[252Fm]
+[173Hf]
+[246Cm]
+[135Ce]
+[49Cr]
+[248Cf]
+[247Cm]
+[248Cm]
+[174Ta]
+[176Ta]
+[154Tb]
+[172Ta]
+[177Ta]
+[175Ta]
+[180Ta]
+[158Tb]
+[115Ag]
+[189Os]
+[251Cf]
+[145Pr]
+[147Pr]
+[76BrH]
+[102Rh]
+[238Np]
+[185Os]
+[246Am]
+[233Np]
+[166Dy]
+[254Es]
+[244Cf]
+[193Os]
+[245Am]
+[245Bk]
+[239Am]
+[238Am]
+[97Nb]
+[245Pu]
+[254Cf]
+[188W]
+[250Es]
+[251Es]
+[237Am]
+[182Hf]
+[258Md]
+[232Np]
+[238Cm]
+[60Fe]
+[109Pd+2]
+[234Pu]
+[141Ce+3]
+[136Nd]
+[136Pr]
+[173Ta]
+[110Ru]
+[147Tb]
+[253Fm]
+[139Nd]
+[178Re]
+[177Re]
+[200Au]
+[182Re]
+[156Tb]
+[155Tb]
+[157Tb]
+[161Tb]
+[161Ho]
+[167Tm]
+[173Lu]
+[179Ta]
+[171Er]
+[44Sc]
+[49Sc]
+[49V]
+[51Mn]
+[90Nb]
+[88Nb]
+[88Zr]
+[36SH2]
+[174Yb]
+[178Lu]
+[179W]
+[83BrH]
+[107Cd]
+[75BrH]
+[62Co]
+[48Cr]
+[63Zn]
+[102Ag]
+[154Sm]
+[168Er]
+[65Ni]
+[137La]
+[187Ir]
+[144Pm]
+[146Pm]
+[160Gd]
+[166Yb]
+[162Dy]
+[47V]
+[141Nd]
+[141Sm]
+[166Er]
+[150Sm]
+[146Eu]
+[149Eu]
+[174Lu]
+[17NH3]
+[102Ru]
+[170Hf]
+[188Pt]
+[61Ni]
+[56Ni]
+[149Gd]
+[151Gd]
+[141Pm]
+[147Gd]
+[146Gd]
+[161Er]
+[103Ag]
+[145Eu]
+[153Tb]
+[155Dy]
+[184Re]
+[180Os]
+[182Os]
+[186Pt]
+[181Os]
+[181Re]
+[151Tb]
+[178Ta]
+[178W]
+[189Pt]
+[194Hg]
+[145Sm]
+[150Tb]
+[132La]
+[158Gd]
+[104Ag]
+[193Hg]
+[94Ru]
+[137Pr]
+[155Ho]
+[117Cd]
+[99Ru]
+[146Nd]
+[218Rn]
+[95Y]
+[79Kr]
+[120IH]
+[138Pr]
+[100Pd]
+[166Tm]
+[90Mo]
+[151Nd]
+[231U]
+[138Nd]
+[89Nb]
+[98Nb]
+[162Ho]
+[142Sm]
+[186Ta]
+[104Tc]
+[184Ta]
+[185Ta]
+[170Er]
+[107Rh]
+[131La]
+[169Lu]
+[74BrH]
+[150Pm]
+[172Tm]
+[197Pt]
+[230Pu]
+[170Lu]
+[86Zr]
+[176W]
+[177W]
+[101Pd]
+[105Pd]
+[108Pd]
+[149Nd]
+[164Ho]
+[159Ho]
+[167Ho]
+[176Yb]
+[156Sm]
+[77BrH]
+[189Re]
+[99Rh]
+[100Rh]
+[151Pm]
+[232Pa]
+[228Pa]
+[230Pa]
+[66Ni]
+[194Os]
+[135La]
+[138La]
+[141La]
+[142La]
+[195Ir]
+[96Nb]
+[157Ho]
+[183Hf]
+[162Tm]
+[172Er]
+[148Eu]
+[150Eu]
+[15CH4]
+[89Kr]
+[143La]
+[58Ni]
+[61Co]
+[158Eu]
+[165Er]
+[167Yb]
+[173Tm]
+[175Tm]
+[172Hf]
+[172Lu]
+[93Tc]
+[177Yb]
+[124IH]
+[194Ir]
+[147Eu]
+[101Mo]
+[180Hf]
+[189Ir]
+[87Y]
+[43Sc]
+[195Au]
+[112Ag]
+[84BrH]
+[106Ag]
+[109Ag]
+[101Rh]
+[162Yb]
+[228Rn]
+[139Pr]
+[94Y]
+[201Au]
+[40PH3]
+[110Ag+]
+[104Cd]
+[133Ba+2]
+[226Ac]
+[145Gd]
+[186Ir]
+[184Ir]
+[224Rn]
+[185Ir]
+[182Ir]
+[184Hf]
+[200Pt]
+[227Pa]
+[178Yb]
+[72Br-]
+[72BrH]
+[248Am]
+[238Th]
+[161Gd]
+[35S-2]
+[107Ag]
+[FeH6-4]
+[89Sr]
+[SnH3-]
+[SeH3]
+[TeH3+]
+[SbH4+]
+[AsH4+]
+[4He]
+[AsH3-]
+[1HH]
+[3H+]
+[82Rb]
+[85Sr]
+[90Sr]
+[137Cs]
+[133Ba]
+[131Cs]
+[SbH5]
+[224Ra]
+[22Na]
+[210Bi]
+[214Bi]
+[228Ra]
+[127Sb]
+[136Cs]
+[125Sb]
+[134Cs]
+[140Ba]
+[45Ca]
+[206Pb]
+[207Pb]
+[24Na]
+[86Rb]
+[212Bi]
+[208Pb]
+[124Sb]
+[204Pb]
+[44K]
+[129Te]
+[113Sn]
+[204Tl]
+[87Sr]
+[208Tl]
+[87Rb]
+[47Ca]
+[135Cs]
+[216Po]
+[137Ba]
+[207Bi]
+[212Po]
+[79Se]
+[223Ra]
+[86Sr]
+[122Sb]
+[26Al]
+[32Si]
+[126Sn]
+[225Ra]
+[114In]
+[72Ga]
+[132Te]
+[10Be]
+[125Sn]
+[73As]
+[206Bi]
+[117Sn]
+[40Ca]
+[41Ca]
+[89Rb]
+[116In]
+[129Sb]
+[91Sr]
+[71Ge]
+[139Ba]
+[69Ga]
+[120Sb]
+[121Sn]
+[123Sn]
+[131Te]
+[77Ge]
+[135Ba]
+[82Sr]
+[43K]
+[131Ba]
+[92Sr]
+[88Rb]
+[129Cs]
+[144Cs]
+[127Cs]
+[200Tl]
+[202Tl]
+[141Ba]
+[117Sb]
+[116Sb]
+[78As]
+[131Sb]
+[126Sb]
+[128Sb]
+[130Sb]
+[67Ge]
+[68Ge]
+[78Ge]
+[66Ge]
+[223Fr]
+[132Cs]
+[125Cs]
+[138Cs]
+[133Te]
+[84Rb]
+[83Rb]
+[81Rb]
+[142Ba]
+[200Bi]
+[115Sb]
+[194Tl]
+[70Se]
+[112In]
+[118Sb]
+[70Ga]
+[27Mg]
+[202Bi]
+[83Se]
+[9Li]
+[69As]
+[79Rb]
+[81Sr]
+[83Sr]
+[78Se]
+[109In]
+[29Al]
+[118Sn]
+[117In]
+[119Sb]
+[114Sn]
+[138Ba]
+[69Ge]
+[73Ga]
+[74Ge]
+[206Tl]
+[199Tl]
+[130Cs]
+[28Mg]
+[116Te]
+[112Sn]
+[126Ba]
+[211Bi]
+[81Se]
+[127Sn]
+[143Cs]
+[134Te]
+[80Sr]
+[45K]
+[215Po]
+[207Po]
+[111Sn]
+[211Po]
+[128Ba]
+[198Tl]
+[227Ra]
+[213Po]
+[220Ra]
+[128Sn]
+[203Po]
+[205Po]
+[65Ga]
+[197Tl]
+[88Sr]
+[110In]
+[31Si]
+[201Bi]
+[121Te]
+[205Bi]
+[203Bi]
+[195Tl]
+[209Tl]
+[110Sn]
+[222Fr]
+[207At]
+[119In]
+[As@]
+[129IH]
+[157Dy]
+[111IH]
+[230Ra]
+[144Pr+3]
+[SiH3+]
+[3He]
+[AsH5]
+[72Se]
+[95Tc]
+[103Pd]
+[121Sn+2]
+[211Rn]
+[38SH2]
+[127IH]
+[74Br-]
+[133I-]
+[100Tc+4]
+[100Tc]
+[36Cl-]
+[89Y+3]
+[104Rh]
+[152Sm]
+[226Ra]
+[19FH]
+[104Pd]
+[148Gd]
+[157Lu]
+[33SH2]
+[121I-]
+[17FH]
+[71Se]
+[157Sm]
+[148Tb]
+[164Dy]
+[15OH2]
+[15O+]
+[39K]
+[40Ar]
+[50Cr+3]
+[50Cr]
+[52Ti]
+[103Pd+2]
+[130Ba]
+[142Pm]
+[153Gd+3]
+[151Eu]
+[103Rh]
+[124Xe]
+[152Tb]
+[17OH2]
+[20Ne]
+[52Fe]
+[94Zr+4]
+[94Zr]
+[149Pr]
+[16OH2]
+[53Cr+6]
+[53Cr]
+[81Br-]
+[112Pd]
+[125Xe]
+[155Gd]
+[157Gd]
+[168Yb]
+[184Os]
+[166Tb]
+[221Fr]
+[212Ra]
+[75Br-]
+[79Br-]
+[113Ag]
+[23Na]
+[34Cl-]
+[34ClH]
+[38Cl-]
+[56Fe]
+[68Cu]
+[77Br-]
+[90Zr+4]
+[90Zr]
+[102Pd]
+[154Eu+3]
+[57Mn]
+[165Tm]
+[152Dy]
+[217At]
+[77se]
+[13cH-]
+[122Te]
+[156Gd]
+[124Te]
+[53Ni]
+[131Xe]
+[174Hf+4]
+[174Hf]
+[76Se]
+[168Tm]
+[167Dy]
+[154Gd]
+[95Ru]
+[210At]
+[85Br]
+[59Co]
+[122Xe]
+[27Al]
+[54Cr]
+[198Hg]
+[85Rb+]
+[214Tl]
+[229Rn]
+[218Pb]
+[218Bi]
+[167Tm+3]
+[18o+]
+[P@@H+]
+[P@H+]
+[13N+]
+[212Pb+2]
+[217Bi]
+[249Cf+2]
+[18OH3+]
+[90Sr-]
+[Cf+3]
+[200Hg]
+[86Tc]
+[141Pr+3]
+[141Pr]
+[16nH]
+[14NH4+]
+[132Xe]
+[83Kr]
+[70Zn+2]
+[137Ba+2]
+[36Ar]
+[38Ar]
+[21Ne]
+[126Xe]
+[136Xe]
+[128Xe]
+[134Xe]
+[84Kr]
+[86Kr]
+[78Kr]
+[80Kr]
+[82Kr]
+[67Zn+2]
+[65Cu+2]
+[110Te]
+[58Fe+3]
+[142Nd]
+[38K]
+[198Au+3]
+[122IH]
+[38PH3]
+[130I-]
+[40K+]
+[38K+]
+[28Mg+2]
+[208Tl+]
+[13OH2]
+[198Bi]
+[192Bi]
+[194Bi]
+[196Bi]
+[132I-]
+[83Sr+2]
+[169Er+3]
+[122I-]
+[120I-]
+[92Sr+2]
+[126I-]
+[24Mg]
+[84Sr]
+[118Pd+2]
+[118Pd]
+[AsH4]
+[127I-]
+[9C-]
+[11CH3+]
+[17B]
+[7B]
+[4HH]
+[18C-]
+[22CH3-]
+[22CH4]
+[17C-]
+[15CH3]
+[16CH3]
+[11NH3]
+[21NH3]
+[11N-]
+[11NH]
+[16CH]
+[17CH2]
+[99Ru+2]
+[181Ta+2]
+[181Ta]
+[20CH]
+[32PH2]
+[55Fe+2]
+[SH3]
+[S@H]
+[Mn-]
+[IH4]
+[ThH]
+[GaH-]
+[BiH+]
+[EuH2]
+[FeH4-3]
+[FeH6]
+[IH5]
+[NiH+]
+[SrH2]
+[VH4]
+[YH3]
+[seH+]
+
diff --git a/models/smi_ted/finetune/smi_ted_large/clintox/run_finetune_clintox.sh b/models/smi_ted/finetune/smi_ted_large/clintox/run_finetune_clintox.sh
new file mode 100644
index 0000000000000000000000000000000000000000..48fe562011a3e35d4ba56c471589b20716492691
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/clintox/run_finetune_clintox.sh
@@ -0,0 +1,23 @@
+python ../../finetune_classification_multitask.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 100 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/clintox' \
+ --dataset_name clintox \
+ --checkpoints_folder './checkpoints_clintox' \
+ --loss_fn 'bceloss' \
+ --target_metric 'roc-auc' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/esol/run_finetune_esol.sh b/models/smi_ted/finetune/smi_ted_large/esol/run_finetune_esol.sh
new file mode 100644
index 0000000000000000000000000000000000000000..25785cec18f4b11c1e4bc336841509292f5240e1
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/esol/run_finetune_esol.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/esol' \
+ --dataset_name esol \
+ --measure_name 'measured log solubility in mols per litre' \
+ --checkpoints_folder './checkpoints_esol' \
+ --loss_fn 'rmse' \
+ --target_metric 'rmse' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/freesolv/run_finetune_freesolv.sh b/models/smi_ted/finetune/smi_ted_large/freesolv/run_finetune_freesolv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fd930a648830373aa38b973d9418a80ddd62f440
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/freesolv/run_finetune_freesolv.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/freesolv' \
+ --dataset_name freesolv \
+ --measure_name 'expt' \
+ --checkpoints_folder './checkpoints_freesolv' \
+ --loss_fn 'rmse' \
+ --target_metric 'rmse' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/hiv/run_finetune_hiv.sh b/models/smi_ted/finetune/smi_ted_large/hiv/run_finetune_hiv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..978f7fc09ff89b64bca71c726229c3b81fce608d
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/hiv/run_finetune_hiv.sh
@@ -0,0 +1,25 @@
+python ../../finetune_classification.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 1e-7 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/hiv' \
+ --dataset_name hiv \
+ --measure_name 'HIV_active' \
+ --checkpoints_folder './checkpoints_hiv_1e-7' \
+ --loss_fn 'crossentropy' \
+ --target_metric 'roc-auc' \
+ --n_output 2 \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/lipo/run_finetune_lipo.sh b/models/smi_ted/finetune/smi_ted_large/lipo/run_finetune_lipo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a71a544afaa2accea3fd637454e90a294424700d
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/lipo/run_finetune_lipo.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 1e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/lipophilicity' \
+ --dataset_name lipophilicity \
+ --measure_name 'y' \
+ --checkpoints_folder './checkpoints_lipophilicity' \
+ --loss_fn 'rmse' \
+ --target_metric 'rmse' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/load.py b/models/smi_ted/finetune/smi_ted_large/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc1129cfe73b3c9161413e84623bb0ff7294528
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/load.py
@@ -0,0 +1,504 @@
+PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+# Transformers
+from fast_transformers.attention import AttentionLayer
+from fast_transformers.events import QKVEvent
+from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer
+from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder
+from fast_transformers.builders.attention_builders import AttentionBuilder
+from fast_transformers.feature_maps import GeneralizedRandomFeatures
+from fast_transformers.masking import LengthMask
+from transformers import BertTokenizer
+
+# Data
+import numpy as np
+import pandas as pd
+
+# Standard library
+from functools import partial
+import regex as re
+import random
+import os
+import gc
+from tqdm import tqdm
+tqdm.pandas()
+
+
+class MolTranBertTokenizer(BertTokenizer):
+ def __init__(self, vocab_file: str = '',
+ do_lower_case=False,
+ unk_token='',
+ sep_token='',
+ pad_token='',
+ cls_token='',
+ mask_token='',
+ **kwargs):
+ super().__init__(vocab_file,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs)
+
+ self.regex_tokenizer = re.compile(PATTERN)
+ self.wordpiece_tokenizer = None
+ self.basic_tokenizer = None
+ with open(vocab_file) as f:
+ self.padding_idx = f.readlines().index(pad_token+'\n')
+
+ def _tokenize(self, text):
+ split_tokens = self.regex_tokenizer.findall(text)
+ return split_tokens
+
+ def convert_idx_to_tokens(self, idx_tensor):
+ tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
+ return tokens
+
+ def convert_tokens_to_string(self, tokens):
+ stopwords = ['', '']
+ clean_tokens = [word for word in tokens if word not in stopwords]
+ out_string = ''.join(clean_tokens)
+ return out_string
+
+ def get_padding_idx(self):
+ return self.padding_idx
+
+ def idx_to_smiles(self, torch_model, idx):
+ '''Convert tokens idx back to SMILES text'''
+ rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
+ flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
+ decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)
+ return decoded_smiles
+
+
+## Transformer layers
+class RotaryEmbedding(torch.nn.Module):
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = 0
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self.cos_cached = emb.cos()[None,:, None, :]
+ self.sin_cached = emb.sin()[None,:, None, :]
+
+ return self.cos_cached, self.sin_cached
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+@torch.jit.script
+def apply_rotary_pos_emb(q, k, cos, sin):
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+class RotateAttentionLayer(AttentionLayer):
+ """Rotate attention layer inherits from fast_transformer attention layer.
+ The only thing added is an Embedding encoding, for more information
+ on the attention layer see the fast_transformers code
+ """
+ def __init__(self, attention, d_model, n_heads, d_keys=None,
+ d_values=None, event_dispatcher=""):
+ super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys,
+ d_values=d_values, event_dispatcher=event_dispatcher)
+
+ self.rotaryemb = RotaryEmbedding(d_keys)
+ print('Using Rotation Embedding')
+
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
+ key_lengths):
+ """
+ Using the same frame work as the fast_Transformers attention layer
+ but injecting rotary information to the queries and the keys
+ after the keys and queries are projected.
+ In the argument description we make use of the following sizes
+ - N: the batch size
+ - L: The maximum length of the queries
+ - S: The maximum length of the keys (the actual length per sequence
+ is given by the length mask)
+ - D: The input feature dimensionality passed in the constructor as
+ 'd_model'
+ Arguments
+ ---------
+ queries: (N, L, D) The tensor containing the queries
+ keys: (N, S, D) The tensor containing the keys
+ values: (N, S, D) The tensor containing the values
+ attn_mask: An implementation of BaseMask that encodes where each
+ query can attend to
+ query_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ key_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ Returns
+ -------
+ The new value for each query as a tensor of shape (N, L, D).
+ """
+ # Extract the dimensions into local variables
+ N, L, _ = queries.shape
+ _, S, _ = keys.shape
+ H = self.n_heads
+
+ # Project the queries/keys/values
+ queries = self.query_projection(queries).view(N, L, H, -1)
+ keys = self.key_projection(keys).view(N, S, H, -1)
+ cos, sin = self.rotaryemb(queries)
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+ values = self.value_projection(values).view(N, S, H, -1)
+ # Let the world know of the qkv
+ self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+
+ # Compute the attention
+ new_values = self.inner_attention(
+ queries,
+ keys,
+ values,
+ attn_mask,
+ query_lengths,
+ key_lengths
+ ).view(N, L, -1)
+
+ # Project the output and return
+ return self.out_projection(new_values)
+
+class RotateEncoderBuilder(BaseTransformerEncoderBuilder):
+ """Build a batch transformer encoder with Relative Rotary embeddings
+ for training or processing of sequences all elements at a time.
+ Example usage:
+ builder = RotateEncoderBuilder()
+ builder.n_layers = 12
+ builder.n_heads = 8
+ builder.feed_forward_dimensions = 1024
+ builder.query_dimensions = 64
+ builder.value_dimensions = 64
+ builder.dropout = 0.1
+ builder.attention_dropout = 0.1
+ builder.attention_type = "linear"
+ transformer = builder.get()
+ """
+ def _get_attention_builder(self):
+ """Return an instance of the appropriate attention builder."""
+ return AttentionBuilder()
+
+ def _get_attention_layer_class(self):
+ """Return the class for the layer that projects queries keys and
+ values."""
+ return RotateAttentionLayer
+
+ def _get_encoder_class(self):
+ """Return the class for the transformer encoder."""
+ return TransformerEncoder
+
+ def _get_encoder_layer_class(self):
+ """Return the class for the transformer encoder layer."""
+ return TransformerEncoderLayer
+
+
+class AutoEncoderLayer(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.encoder = self.Encoder(feature_size, latent_size)
+ self.decoder = self.Decoder(feature_size, latent_size)
+
+ class Encoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(feature_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.lat = nn.Linear(latent_size, latent_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.lat.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.lat(x)
+ return x # -> (N, D)
+
+ class Decoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(latent_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.rec = nn.Linear(latent_size, feature_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.rec.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.rec(x)
+ return x # -> (N, L*D)
+
+
+class LangLayer(nn.Module):
+
+ def __init__(self, n_embd, n_vocab):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.embed = nn.Linear(n_embd, n_embd)
+ self.ln_f = nn.LayerNorm(n_embd)
+ self.head = nn.Linear(n_embd, n_vocab, bias=False)
+
+ def forward(self, tensor):
+ if self.is_cuda_available:
+ self.embed.cuda()
+ self.ln_f.cuda()
+ self.head.cuda()
+ tensor = tensor.cuda()
+ tensor = self.embed(tensor)
+ tensor = F.gelu(tensor)
+ tensor = self.ln_f(tensor)
+ tensor = self.head(tensor)
+ return tensor
+
+
+class Net(nn.Module):
+
+ def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2):
+ super().__init__()
+ self.desc_skip_connection = True
+ self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.relu1 = nn.GELU()
+ self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout2 = nn.Dropout(dropout)
+ self.relu2 = nn.GELU()
+ self.final = nn.Linear(smiles_embed_dim, n_output)
+
+ def forward(self, smiles_emb, multitask=False):
+ x_out = self.fc1(smiles_emb)
+ x_out = self.dropout1(x_out)
+ x_out = self.relu1(x_out)
+
+ if self.desc_skip_connection is True:
+ x_out = x_out + smiles_emb
+
+ z = self.fc2(x_out)
+ z = self.dropout2(z)
+ z = self.relu2(z)
+ if self.desc_skip_connection is True:
+ z = self.final(z + x_out)
+ else:
+ z = self.final(z)
+
+ if multitask:
+ return F.sigmoid(z)
+ return z
+
+
+class MoLEncoder(nn.Module):
+
+ def __init__(self, config, n_vocab, eval=False):
+ super(MoLEncoder, self).__init__()
+
+ # embeddings
+ self.config = config
+ self.tok_emb = nn.Embedding(n_vocab, config['n_embd'])
+ self.drop = nn.Dropout(config['d_dropout'])
+
+ # transformer
+ builder = RotateEncoderBuilder.from_kwargs(
+ n_layers=config['n_layer'],
+ n_heads=config['n_head'],
+ query_dimensions=config['n_embd']//config['n_head'],
+ value_dimensions=config['n_embd']//config['n_head'],
+ feed_forward_dimensions=None,
+ attention_type='linear',
+ # unless we do deterministic_eval here, we will have random outputs
+ feature_map=partial(GeneralizedRandomFeatures,
+ n_dims=config['num_feats'],
+ deterministic_eval=eval),
+ activation='gelu'
+ )
+ self.blocks = builder.get()
+
+ # classification
+ self.lang_model = LangLayer(config['n_embd'], n_vocab)
+
+
+class MoLDecoder(nn.Module):
+
+ def __init__(self, n_vocab, max_len, n_embd, n_gpu=None):
+ super(MoLDecoder, self).__init__()
+
+ self.max_len = max_len
+ self.n_embd = n_embd
+ self.n_gpu = n_gpu
+ self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd)
+ self.lang_model = LangLayer(n_embd, n_vocab)
+
+
+class Smi_ted(nn.Module):
+ """materials.smi-ted-Large 738M Parameters"""
+
+ def __init__(self, tokenizer, config=None, eval=False):
+ super(Smi_ted, self).__init__()
+
+ # configuration
+ self.config = config
+ self.tokenizer = tokenizer
+ self.padding_idx = tokenizer.get_padding_idx()
+ self.n_vocab = len(self.tokenizer.vocab)
+ self.is_cuda_available = torch.cuda.is_available()
+
+ # instantiate modules
+ if self.config:
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
+ self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
+ self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
+
+ def load_checkpoint(self, ckpt_path, n_output, eval=False):
+ # load checkpoint file
+ checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+
+ # load hyparameters
+ self.config = checkpoint['hparams']
+ self.max_len = self.config['max_len']
+ self.n_embd = self.config['n_embd']
+ self._set_seed(self.config['seed'])
+
+ # instantiate modules
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
+ self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
+ self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
+
+ # load weights
+ if 'state_dict' in checkpoint:
+ if isinstance(checkpoint['state_dict'], list):
+ self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False)
+ self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False)
+ else:
+ self.load_state_dict(checkpoint['state_dict'], strict=False)
+ elif 'MODEL_STATE' in checkpoint:
+ self.load_state_dict(checkpoint['MODEL_STATE'], strict=False)
+
+ # load RNG states each time the model and states are loaded from checkpoint
+ if 'rng' in self.config:
+ rng = self.config['rng']
+ for key, value in rng.items():
+ if key =='torch_state':
+ torch.set_rng_state(value.cpu())
+ elif key =='cuda_state':
+ torch.cuda.set_rng_state(value.cpu())
+ elif key =='numpy_state':
+ np.random.set_state(value)
+ elif key =='python_state':
+ random.setstate(value)
+ else:
+ print('unrecognized state')
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_seed(self, value):
+ print('Random Seed:', value)
+ random.seed(value)
+ torch.manual_seed(value)
+ torch.cuda.manual_seed(value)
+ torch.cuda.manual_seed_all(value)
+ np.random.seed(value)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def tokenize(self, smiles):
+ """Tokenize a string into tokens."""
+ if isinstance(smiles, str):
+ batch = [smiles]
+ else:
+ batch = smiles
+
+ tokens = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ max_length=self.max_len,
+ )
+
+ idx = tokens['input_ids'].clone().detach()
+ mask = tokens['attention_mask'].clone().detach()
+
+ if self.is_cuda_available:
+ return idx.cuda(), mask.cuda()
+
+ return idx, mask
+
+ def extract_embeddings(self, smiles):
+ """Extract token and SMILES embeddings."""
+ if self.is_cuda_available:
+ self.encoder.cuda()
+ self.decoder.cuda()
+
+ # tokenizer
+ idx, mask = self.tokenize(smiles)
+
+ # transformer encoder
+ x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.encoder.drop(x)
+ x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1]))
+
+ # add padding
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ mask_embeddings = (token_embeddings * input_mask_expanded)
+ token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0)
+
+ # aggregate token embeddings (similar to mean pooling)
+ # CAUTION: use the embeddings from the autoencoder.
+ smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd))
+
+ return smiles_embeddings
+
+ def __str__(self):
+ return 'smi-ted-Large'
+
+
+def load_smi_ted(folder="./smi_ted_large",
+ ckpt_filename="smi-ted-Large_30.pt",
+ vocab_filename="bert_vocab_curated.txt",
+ n_output=1,
+ eval=False
+ ):
+ tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
+ model = Smi_ted(tokenizer)
+ model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval)
+ print('Vocab size:', len(tokenizer.vocab))
+ print(f'[FINETUNE MODE - {str(model)}]')
+ return model
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9debb92f3ad1748661a8aef9b69b78cb045345dc
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E1-CAM' \
+ --checkpoints_folder './checkpoints_QM8-E1-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..589e20897b4cd5871c65fea35617a47aa5cbc5df
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E1-CC2' \
+ --checkpoints_folder './checkpoints_QM8-E1-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d5c8c90452ccf0b3069887ad0c06a0ea192e95af
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E1-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E1-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-E1-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..147591f60a1a0d6921c8b9df5e0d895ddd2f7839
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E2-CAM' \
+ --checkpoints_folder './checkpoints_QM8-E2-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c471c2491b1bd08cefc47a503faf66e1ae12a713
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E2-CC2' \
+ --checkpoints_folder './checkpoints_QM8-E2-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e1527232ef9cfe39cf4576e6def0927c1d4b39fc
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_E2-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E2-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-E2-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..64a3297c5f2ff839507614ba94072791ccb60436
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f1-CAM' \
+ --checkpoints_folder './checkpoints_QM8-f1-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a326d8686aa7d12f2162ab48323131947e7d88f5
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f1-CC2' \
+ --checkpoints_folder './checkpoints_QM8-f1-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..210cbf1928ef9b1f356c3a14813c81a0891e349b
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f1-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f1-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-f1-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c3d38f5854afcd95df057464d278e530a50e90bb
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f2-CAM' \
+ --checkpoints_folder './checkpoints_QM8-f2-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..597f9c37937169379d3b8e28cc221d9637fc566a
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f2-CC2' \
+ --checkpoints_folder './checkpoints_QM8-f2-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f4eaaa4cb96bef5b34ee38e508a36a6819016e49
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm8/run_finetune_f2-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f2-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-f2-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_alpha.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_alpha.sh
new file mode 100644
index 0000000000000000000000000000000000000000..54ddaa254018f5090abed6bf930d74beaccbe29d
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_alpha.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'alpha' \
+ --checkpoints_folder './checkpoints_QM9-alpha' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_cv.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_cv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..78accb629f42542fde4c26e39da6e1e1453cee7e
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_cv.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'cv' \
+ --checkpoints_folder './checkpoints_QM9-cv' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_g298.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_g298.sh
new file mode 100644
index 0000000000000000000000000000000000000000..19843c43b83626d61764d152138b3df1d2d62023
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_g298.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'g298' \
+ --checkpoints_folder './checkpoints_QM9-g298' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_gap.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_gap.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d2726bb325164a18a578a4103de314deac704973
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_gap.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'gap' \
+ --checkpoints_folder './checkpoints_QM9-gap' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_h298.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_h298.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d638fe5a899d4b091714c133b2ff3d1ac7e72991
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_h298.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'h298' \
+ --checkpoints_folder './checkpoints_QM9-h298' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_homo.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_homo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..44b39ed0e5fc3130d97246712da0823adcf75b5c
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_homo.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'homo' \
+ --checkpoints_folder './checkpoints_QM9-homo' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_lumo.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_lumo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4619898a4587d23b8f11a52c1e147f2b491bbc94
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_lumo.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'lumo' \
+ --checkpoints_folder './checkpoints_QM9-lumo' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_mu.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_mu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8c5b8aa012c6668bfc7d41781fc5ea66acc68ec
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_mu.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'mu' \
+ --checkpoints_folder './checkpoints_QM9-mu' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_r2.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_r2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d7c630b89a226ec42ce68fa53e4f6f916a2eb7d
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_r2.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'r2' \
+ --checkpoints_folder './checkpoints_QM9-r2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u0.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u0.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9769ec917660bb88444f0e7de63e36c15576146
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u0.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'u0' \
+ --checkpoints_folder './checkpoints_QM9-u0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u298.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u298.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c1eac1faae458ab03ce06c4aa729ac96387f2921
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_u298.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'u298' \
+ --checkpoints_folder './checkpoints_QM9-u298' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_zpve.sh b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_zpve.sh
new file mode 100644
index 0000000000000000000000000000000000000000..aa31850b5a618fbbef203c0f86f6680faec960e3
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/qm9/run_finetune_zpve.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'zpve' \
+ --checkpoints_folder './checkpoints_QM9-zpve' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/sider/run_finetune_sider.sh b/models/smi_ted/finetune/smi_ted_large/sider/run_finetune_sider.sh
new file mode 100644
index 0000000000000000000000000000000000000000..94522a262f20bd736fdabfb9d00f2004ae644bd6
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/sider/run_finetune_sider.sh
@@ -0,0 +1,23 @@
+python ../../finetune_classification_multitask.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/sider' \
+ --dataset_name sider \
+ --checkpoints_folder './checkpoints_sider' \
+ --loss_fn 'bceloss' \
+ --target_metric 'roc-auc' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_large/tox21/run_finetune_tox21.sh b/models/smi_ted/finetune/smi_ted_large/tox21/run_finetune_tox21.sh
new file mode 100644
index 0000000000000000000000000000000000000000..84302a36fcc7babfde0233f5f12b4a03cb63c02c
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_large/tox21/run_finetune_tox21.sh
@@ -0,0 +1,23 @@
+python ../../finetune_classification_multitask.py \
+ --n_batch 32 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 1e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 100 \
+ --num_feats 32 \
+ --smi_ted_version 'v2' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Large_11.pt' \
+ --data_root '../../moleculenet/tox21' \
+ --dataset_name tox21 \
+ --checkpoints_folder './checkpoints_tox21' \
+ --loss_fn 'bceloss' \
+ --target_metric 'roc-auc' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/bace/run_finetune_bace.sh b/models/smi_ted/finetune/smi_ted_light/bace/run_finetune_bace.sh
new file mode 100644
index 0000000000000000000000000000000000000000..da1b97953d4ad65a64c3fd68f495ea8ca91f5bde
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/bace/run_finetune_bace.sh
@@ -0,0 +1,25 @@
+python ../../finetune_classification.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/bace' \
+ --dataset_name bace \
+ --measure_name 'Class' \
+ --checkpoints_folder './checkpoints_bace' \
+ --loss_fn 'crossentropy' \
+ --target_metric 'roc-auc' \
+ --n_output 2 \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/bbbp/run_finetune_bbbp.sh b/models/smi_ted/finetune/smi_ted_light/bbbp/run_finetune_bbbp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..860d657ca01b36ee78fe031dfb5f44880b8eda3f
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/bbbp/run_finetune_bbbp.sh
@@ -0,0 +1,25 @@
+python ../../finetune_classification.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/bbbp' \
+ --dataset_name bbbp \
+ --measure_name 'p_np' \
+ --checkpoints_folder './checkpoints_bbbp' \
+ --loss_fn 'crossentropy' \
+ --target_metric 'roc-auc' \
+ --n_output 2 \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/bert_vocab_curated.txt b/models/smi_ted/finetune/smi_ted_light/bert_vocab_curated.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/bert_vocab_curated.txt
@@ -0,0 +1,2393 @@
+
+
+
+
+C
+c
+(
+)
+1
+O
+N
+2
+=
+n
+3
+[C@H]
+[C@@H]
+F
+S
+4
+Cl
+-
+o
+s
+[nH]
+#
+/
+Br
+[C@]
+[C@@]
+[N+]
+[O-]
+5
+\
+.
+I
+6
+[S@]
+[S@@]
+P
+[N-]
+[Si]
+7
+[n+]
+[2H]
+8
+[NH+]
+B
+9
+[C-]
+[Na+]
+[Cl-]
+[c-]
+[CH]
+%10
+[NH2+]
+[P+]
+[B]
+[I-]
+%11
+[CH2-]
+[O+]
+[NH3+]
+[C]
+[Br-]
+[IH2]
+[S-]
+[cH-]
+%12
+[nH+]
+[B-]
+[K+]
+[Sn]
+[Se]
+[CH-]
+[HH]
+[Y]
+[n-]
+[CH3-]
+[SiH]
+[S+]
+%13
+[SiH2]
+[Li+]
+[NH-]
+%14
+[Na]
+[CH2]
+[O-2]
+[U+2]
+[W]
+[Al]
+[P@]
+[Fe+2]
+[PH+]
+%15
+[Cl+3]
+[Zn+2]
+[Ir]
+[Mg+2]
+[Pt+2]
+[OH2+]
+[As]
+[Fe]
+[OH+]
+[Zr+2]
+[3H]
+[Ge]
+[SiH3]
+[OH-]
+[NH4+]
+[Cu+2]
+[P@@]
+p
+[Pt]
+%16
+[Ca+2]
+[Zr]
+[F-]
+[C+]
+[Ti]
+[P-]
+[V]
+[se]
+[U]
+[O]
+[Ni+2]
+[Zn]
+[Co]
+[Ni]
+[Pd+2]
+[Cu]
+%17
+[Cu+]
+[Te]
+[H+]
+[CH+]
+[Li]
+[Pd]
+[Mo]
+[Ru+2]
+[o+]
+[Re]
+[SH+]
+%18
+[Ac]
+[Cr]
+[NH2-]
+[K]
+[13CH2]
+[c]
+[Zr+4]
+[Tl]
+[13C]
+[Mn]
+[N@+]
+[Hg]
+[Rh]
+[Ti+4]
+[Sb]
+[Co+2]
+[Ag+]
+[Ru]
+%19
+[N@@+]
+[Ti+2]
+[Al+3]
+[Pb]
+[I+]
+[18F]
+[s+]
+[Rb+]
+[Ba+2]
+[H-]
+[Fe+3]
+[Ir+3]
+[13cH]
+%20
+[AlH2]
+[Au+]
+[13c]
+[SH2+]
+[Sn+2]
+[Mn+2]
+[Si-]
+[Ag]
+[N]
+[Bi]
+%21
+[In]
+[CH2+]
+[Y+3]
+[Ga]
+%22
+[Co+3]
+[Au]
+[13CH3]
+[Mg]
+[Cs+]
+[W+2]
+[Hf]
+[Zn+]
+[Se-]
+[S-2]
+[Ca]
+[pH]
+[ClH+]
+[Ti+3]
+%23
+[Ru+]
+[SH-]
+[13CH]
+[IH+]
+[Hf+4]
+[Rf]
+[OH3+]
+%24
+[Pt+4]
+[Zr+3]
+[PH3+]
+[Sr+2]
+[Cd+2]
+[Cd]
+%25
+[Os]
+[BH-]
+[Sn+4]
+[Cr+3]
+[Ru+3]
+[PH2+]
+[Rh+2]
+[V+2]
+%26
+[Gd+3]
+[Pb+2]
+[PH]
+[Hg+]
+[Mo+2]
+[AlH]
+[Sn+]
+%27
+[Pd+]
+b
+[Rh+3]
+[Hg+2]
+[15NH]
+[14C]
+%28
+[Mn+3]
+[Si+]
+[SeH]
+[13C@H]
+[NH]
+[Ga+3]
+[SiH-]
+[13C@@H]
+[Ce]
+[Au+3]
+[Bi+3]
+[15N]
+%29
+[BH3-]
+[14cH]
+[Ti+]
+[Gd]
+[cH+]
+[Cr+2]
+[Sb-]
+%30
+[Be+2]
+[Al+]
+[te]
+[11CH3]
+[Sm]
+[Pr]
+[La]
+%31
+[Al-]
+[Ta]
+[125I]
+[BH2-]
+[Nb]
+[Si@]
+%32
+[14c]
+[Sb+3]
+[Ba]
+%33
+[Os+2]
+[Si@@]
+[La+3]
+[15n]
+[15NH2]
+[Nd+3]
+%34
+[14CH2]
+[18O]
+[Nd]
+[GeH]
+[Ni+3]
+[Eu]
+[Dy+3]
+[Sc]
+%36
+[Se-2]
+[As+]
+%35
+[AsH]
+[Tb]
+[Sb+5]
+[Se+]
+[Ce+3]
+[c+]
+[In+3]
+[SnH]
+[Mo+4]
+%37
+[V+4]
+[Eu+3]
+[Hf+2]
+%38
+[Pt+]
+[p+]
+[123I]
+[Tl+]
+[Sm+3]
+%39
+[Yb+3]
+%40
+[Yb]
+[Os+]
+%41
+[10B]
+[Sc+3]
+[Al+2]
+%42
+[Sr]
+[Tb+3]
+[Po]
+[Tc]
+[PH-]
+[AlH3]
+[Ar]
+[U+4]
+[SnH2]
+[Cl+2]
+[si]
+[Fe+]
+[14CH3]
+[U+3]
+[Cl+]
+%43
+[GeH2]
+%44
+[Er+3]
+[Mo+3]
+[I+2]
+[Fe+4]
+[99Tc]
+%45
+[11C]
+%46
+[SnH3]
+[S]
+[Te+]
+[Er]
+[Lu+3]
+[11B]
+%47
+%48
+[P]
+[Tm]
+[Th]
+[Dy]
+[Pr+3]
+[Ta+5]
+[Nb+5]
+[Rb]
+[GeH3]
+[Br+2]
+%49
+[131I]
+[Fm]
+[Cs]
+[BH4-]
+[Lu]
+[15nH]
+%50
+[Ru+6]
+[b-]
+[Ho]
+[Th+4]
+[Ru+4]
+%52
+[14CH]
+%51
+[Cr+6]
+[18OH]
+[Ho+3]
+[Ce+4]
+[Bi+2]
+[Co+]
+%53
+[Yb+2]
+[Fe+6]
+[Be]
+%54
+[SH3+]
+[Np]
+[As-]
+%55
+[14C@@H]
+[Ir+2]
+[GaH3]
+[p-]
+[GeH4]
+[Sn+3]
+[Os+4]
+%56
+[14C@H]
+[sH+]
+[19F]
+[Eu+2]
+[TlH]
+%57
+[Cr+4]
+%58
+[B@@-]
+[SiH+]
+[At]
+[Am]
+[Fe+5]
+[AsH2]
+[Si+4]
+[B@-]
+[Pu]
+[SbH]
+[P-2]
+[Tm+3]
+*
+%59
+[se+]
+[IH-]
+%60
+[oH+]
+[1H]
+[15N+]
+[124I]
+[S@@+]
+[P-3]
+[H]
+[IH2+]
+[TeH]
+[Xe]
+[PH4+]
+[Cr+]
+[Cm]
+[I+3]
+%61
+[Nb+2]
+[Ru+5]
+%62
+[Ta+2]
+[Tc+4]
+[CH3+]
+[Pm]
+[Si@H]
+[No]
+%63
+[Cr+5]
+[Th+2]
+[Zn-2]
+[13C@]
+[Lr]
+%64
+[99Tc+3]
+%65
+[13C@@]
+%66
+[Fe-]
+[17O]
+[siH]
+[Sb+]
+[OH]
+[IH]
+[11CH2]
+[Cf]
+[SiH2+]
+[Gd+2]
+[In+]
+[Si@@H]
+[Mn+]
+[99Tc+4]
+[Ga-]
+%67
+[S@+]
+[Ge+4]
+[Tl+3]
+[16OH]
+%68
+[2H-]
+[Ra]
+[si-]
+[NiH2]
+[P@@H]
+[Rh+]
+[12C]
+[35S]
+[32P]
+[SiH2-]
+[AlH2+]
+[16O]
+%69
+[BiH]
+[BiH2]
+[Zn-]
+[BH]
+[Tc+3]
+[Ir+]
+[Ni+]
+%70
+[InH2]
+[InH]
+[Nb+3]
+[PbH]
+[Bi+]
+%71
+[As+3]
+%72
+[18O-]
+[68Ga+3]
+%73
+[Pa]
+[76Br]
+[Tc+5]
+[pH+]
+[64Cu+2]
+[Ru+8]
+%74
+[PH2-]
+[Si+2]
+[17OH]
+[RuH]
+[111In+3]
+[AlH+]
+%75
+%76
+[W+]
+[SbH2]
+[PoH]
+[Ru-]
+[XeH]
+[Tc+2]
+[13C-]
+[Br+]
+[Pt-2]
+[Es]
+[Cu-]
+[Mg+]
+[3HH]
+[P@H]
+[ClH2+]
+%77
+[SH]
+[Au-]
+[2HH]
+%78
+[Sn-]
+[11CH]
+[PdH2]
+0
+[Os+6]
+%79
+[Mo+]
+%80
+[al]
+[PbH2]
+[64Cu]
+[Cl]
+[12CH3]
+%81
+[Tc+7]
+[11c]
+%82
+[Li-]
+[99Tc+5]
+[He]
+[12c]
+[Kr]
+[RuH+2]
+[35Cl]
+[Pd-2]
+[GaH2]
+[4H]
+[Sg]
+[Cu-2]
+[Br+3]
+%83
+[37Cl]
+[211At]
+[IrH+2]
+[Mt]
+[Ir-2]
+[In-]
+[12cH]
+[12CH2]
+[RuH2]
+[99Tc+7]
+%84
+[15n+]
+[ClH2+2]
+[16N]
+[111In]
+[Tc+]
+[Ru-2]
+[12CH]
+[si+]
+[Tc+6]
+%85
+%86
+[90Y]
+[Pd-]
+[188Re]
+[RuH+]
+[NiH]
+[SiH3-]
+[14n]
+[CH3]
+[14N]
+[10BH2]
+%88
+%89
+%90
+[34S]
+[77Br]
+[GaH]
+[Br]
+[Ge@]
+[B@@H-]
+[CuH]
+[SiH4]
+[3H-]
+%87
+%91
+%92
+[67Cu]
+[I]
+[177Lu]
+[ReH]
+[67Ga+3]
+[Db]
+[177Lu+3]
+[AlH2-]
+[Si+3]
+[Ti-2]
+[RuH+3]
+[al+]
+[68Ga]
+[2H+]
+[B@H-]
+[WH2]
+[OsH]
+[Ir-3]
+[AlH-]
+[Bk]
+[75Se]
+[14C@]
+[Pt-]
+[N@@H+]
+[Nb-]
+[13NH2]
+%93
+[186Re]
+[Tb+4]
+[PtH]
+[IrH2]
+[Hg-2]
+[AlH3-]
+[PdH+]
+[Md]
+[RhH+2]
+[11cH]
+[Co-2]
+[15N-]
+[ZrH2]
+%94
+[Hg-]
+[127I]
+[AsH2+]
+[MoH2]
+[Te+4]
+[14C@@]
+[As+5]
+[SnH+3]
+[Ge@@]
+[6Li+]
+[WH]
+[Ne]
+[14NH2]
+[14NH]
+[12C@@H]
+[Os+7]
+[RhH]
+[Al-3]
+[SnH+]
+[15NH3+]
+[Zr+]
+[197Hg+]
+%95
+%96
+[90Y+3]
+[Os-2]
+[98Tc+5]
+[15NH3]
+[bH-]
+[33P]
+[Zr-2]
+[15O]
+[Rh-]
+[PbH3]
+[PH2]
+[Ni-]
+[CuH+]
+%97
+%98
+%99
+[Os+5]
+[PtH+]
+[ReH4]
+[16NH]
+[82Br]
+[W-]
+[18F-]
+[15NH4+]
+[Se+4]
+[SeH-]
+[SH4]
+[67Cu+2]
+[12C@H]
+[AsH3]
+[HgH]
+[10B-]
+[99Tc+6]
+[117Sn+4]
+[Te@]
+[P@+]
+[35SH]
+[SeH+]
+[Ni-2]
+[Al-2]
+[TeH2]
+[Bh]
+[99Tc+2]
+[Os+8]
+[PH-2]
+[7Li+]
+[14nH]
+[AlH+2]
+[18FH]
+[SnH4]
+[18O-2]
+[IrH]
+[13N]
+[Te@@]
+[Rh-3]
+[15NH+]
+[AsH3+]
+[SeH2]
+[AsH+]
+[CoH2]
+[16NH2]
+[AsH-]
+[203Hg+]
+[P@@+]
+[166Ho+3]
+[60Co+3]
+[13CH2-]
+[SeH2+]
+[75Br]
+[TlH2]
+[80Br]
+[siH+]
+[Ca+]
+[153Sm+3]
+[PdH]
+[225Ac]
+[13CH3-]
+[AlH4-]
+[FeH]
+[13CH-]
+[14C-]
+[11C-]
+[153Sm]
+[Re-]
+[te+]
+[13CH4]
+[ClH+2]
+[8CH2]
+[99Mo]
+[ClH3+3]
+[SbH3]
+[25Mg+2]
+[16N+]
+[SnH2+]
+[PH4]
+[11C@H]
+[122I]
+[Re-2]
+[RuH2+2]
+[ZrH]
+[Bi-]
+[Pr+]
+[Rn]
+[Fr]
+[36Cl]
+[18o]
+[YH]
+[79Br]
+[121I]
+[113In+3]
+[InH4-]
+[TaH]
+[RhH2]
+[Ta-]
+[67Ga]
+[ZnH+]
+[SnH2-]
+[OsH2]
+[16F]
+[FeH2]
+[14O]
+[PbH2+2]
+[BH2]
+[6H]
+[125Te]
+[197Hg]
+[TaH2]
+[TaH3]
+[76As]
+[Nb-2]
+[14N+]
+[125I-]
+[33S]
+[IH2+2]
+[NH2]
+[PtH2]
+[MnH]
+[19C]
+[17F]
+[1H-]
+[SnH4+2]
+[Mn-2]
+[15NH2+]
+[TiH2]
+[ReH7]
+[Cd-2]
+[Fe-3]
+[SH2]
+[17O-]
+[siH-]
+[CoH+]
+[VH]
+[10BH]
+[Ru-3]
+[13O]
+[5H]
+[CoH]
+[PH5]
+[15n-]
+[153Gd]
+[12C@]
+[11CH3-]
+[IrH3]
+[RuH3]
+[74Se]
+[Se@]
+[Hf+]
+[77Se]
+[166Ho]
+[59Fe+2]
+[203Hg]
+[18OH-]
+[8CH]
+[12C@@]
+[11CH4]
+[15C]
+[249Cf]
+[PbH4]
+[64Zn]
+[PH3]
+[99Tc+]
+[14c-]
+[149Pm]
+[IrH4]
+[Se@@]
+[13OH]
+[14CH3-]
+[28Si]
+[Rh-2]
+[Fe-2]
+[131I-]
+[51Cr]
+[62Cu+2]
+[81Br]
+[121Sb]
+[7Li]
+[89Zr+4]
+[SbH3+]
+[11C@@H]
+[98Tc]
+[59Fe+3]
+[BiH2+]
+[SbH+]
+[TiH]
+[14NH3]
+[15OH]
+[119Sn]
+[201Hg]
+[MnH+]
+[201Tl]
+[51Cr+3]
+[123I-]
+[MoH]
+[AlH6-3]
+[MnH2]
+[WH3]
+[213Bi+3]
+[SnH2+2]
+[123IH]
+[13CH+]
+[Zr-]
+[74As]
+[13C+]
+[32P+]
+[KrH]
+[SiH+2]
+[ClH3+2]
+[13NH]
+[9CH2]
+[ZrH2+2]
+[87Sr+2]
+[35s]
+[239Pu]
+[198Au]
+[241Am]
+[203Hg+2]
+[V+]
+[YH2]
+[SH5]
+[195Pt]
+[203Pb]
+[RuH4]
+[ThH2]
+[AuH]
+[66Ga+3]
+[11B-]
+[F]
+[24Na+]
+[85Sr+2]
+[201Tl+]
+[14CH4]
+[32S]
+[TeH2+]
+[ClH2+3]
+[AgH]
+[Ge@H]
+[44Ca+2]
+[Os-]
+[31P]
+[15nH+]
+[SbH4]
+[TiH+]
+[Ba+]
+[57Co+2]
+[Ta+]
+[125IH]
+[77As]
+[129I]
+[Fe-4]
+[Ta-2]
+[19O]
+[12O]
+[BiH3]
+[237Np]
+[252Cf]
+[86Y]
+[Cr-2]
+[89Y]
+[195Pt+2]
+[si+2]
+[58Fe+2]
+[Hs]
+[S@@H]
+[OsH6]
+[GdH2]
+[IH3]
+[8CH4]
+[164Dy+3]
+[47Ca+2]
+[57Co]
+[NbH2]
+[ReH2]
+[ZnH2]
+[CrH2]
+[17NH]
+[ZrH3]
+[RhH3]
+[12C-]
+[18O+]
+[Bi-2]
+[ClH4+3]
+[Ni-3]
+[Ag-]
+[111In-]
+[Mo-2]
+[55Fe+3]
+[204Hg+]
+[35Cl-]
+[211Pb]
+[75Ge]
+[8B]
+[TeH3]
+[SnH3+]
+[Zr-3]
+[28F]
+[249Bk]
+[169Yb]
+[34SH]
+[6Li]
+[94Tc]
+[197Au]
+[195Pt+4]
+[169Yb+3]
+[32Cl]
+[82Se]
+[159Gd+3]
+[213Bi]
+[CoH+2]
+[36S]
+[35P]
+[Ru-4]
+[Cr-3]
+[60Co]
+[1H+]
+[18CH2]
+[Cd-]
+[152Sm+3]
+[106Ru]
+[238Pu]
+[220Rn]
+[45Ca+2]
+[89Sr+2]
+[239Np]
+[90Sr+2]
+[137Cs+]
+[165Dy]
+[68GaH3]
+[65Zn+2]
+[89Zr]
+[BiH2+2]
+[62Cu]
+[165Dy+3]
+[238U]
+[105Rh+3]
+[70Zn]
+[12B]
+[12OH]
+[18CH]
+[17CH]
+[OsH3]
+[SbH-]
+[SH6]
+[AlH2-2]
+[42K]
+[76Br-]
+[71As]
+[NbH3]
+[ReH3]
+[OsH-]
+[WH4]
+[MoH3]
+[OsH4]
+[RuH6]
+[PtH3]
+[CuH2]
+[CoH3]
+[TiH4]
+[64Zn+2]
+[Si-2]
+[79BrH]
+[14CH2-]
+[PtH2+2]
+[Os-3]
+[29Si]
+[Ti-]
+[Se+6]
+[22Na+]
+[42K+]
+[131Cs+]
+[86Rb+]
+[134Cs+]
+[209Po]
+[208Po]
+[81Rb+]
+[203Tl+]
+[Zr-4]
+[148Sm]
+[147Sm]
+[37Cl-]
+[12CH4]
+[Ge@@H]
+[63Cu]
+[13CH2+]
+[AsH2-]
+[CeH]
+[SnH-]
+[UH]
+[9c]
+[21CH3]
+[TeH+]
+[57Co+3]
+[8BH2]
+[12BH2]
+[19BH2]
+[9BH2]
+[YbH2]
+[CrH+2]
+[208Bi]
+[152Gd]
+[61Cu]
+[115In]
+[60Co+2]
+[13NH2-]
+[120I]
+[18OH2]
+[75SeH]
+[SbH2+]
+[144Ce]
+[16n]
+[113In]
+[22nH]
+[129I-]
+[InH3]
+[32PH3]
+[234U]
+[235U]
+[59Fe]
+[82Rb+]
+[65Zn]
+[244Cm]
+[147Pm]
+[91Y]
+[237Pu]
+[231Pa]
+[253Cf]
+[127Te]
+[187Re]
+[236Np]
+[235Np]
+[72Zn]
+[253Es]
+[159Dy]
+[62Zn]
+[101Tc]
+[149Tb]
+[124I-]
+[SeH3+]
+[210Pb]
+[40K]
+[210Po]
+[214Pb]
+[218Po]
+[214Po]
+[7Be]
+[212Pb]
+[205Pb]
+[209Pb]
+[123Te]
+[202Pb]
+[72As]
+[201Pb]
+[70As]
+[73Ge]
+[200Pb]
+[198Pb]
+[66Ga]
+[73Se]
+[195Pb]
+[199Pb]
+[144Ce+3]
+[235U+2]
+[90Tc]
+[114In+3]
+[128I]
+[100Tc+]
+[82Br-]
+[191Pt+2]
+[191Pt+4]
+[193Pt+4]
+[31PH3]
+[125I+2]
+[131I+2]
+[125Te+4]
+[82Sr+2]
+[149Sm]
+[81BrH]
+[129Xe]
+[193Pt+2]
+[123I+2]
+[Cr-]
+[Co-]
+[227Th+4]
+[249Cf+3]
+[252Cf+3]
+[187Os]
+[16O-]
+[17O+]
+[16OH-]
+[98Tc+7]
+[58Co+2]
+[69Ga+3]
+[57Fe+2]
+[43K+]
+[16C]
+[52Fe+3]
+[SeH5]
+[194Pb]
+[196Pb]
+[197Pb]
+[213Pb]
+[9B]
+[19B]
+[11CH-]
+[9CH]
+[20OH]
+[25OH]
+[8cH]
+[TiH+3]
+[SnH6+3]
+[N@H+]
+[ZnH]
+[VH3]
+[52Mn+2]
+[64Ga]
+[13B]
+[216Bi]
+[117Sn+2]
+[232Th]
+[SnH+2]
+[BiH5]
+[77Kr]
+[103Cd]
+[62Ni]
+[LaH3]
+[SmH3]
+[EuH3]
+[MoH5]
+[64Ni]
+[66Zn]
+[68Zn]
+[186W]
+[FeH4]
+[MoH4]
+[HgH2]
+[15NH2-]
+[UH2]
+[204Hg]
+[GaH4-]
+[ThH4]
+[WH6]
+[PtH4]
+[VH2]
+[UH3]
+[FeH3]
+[RuH5]
+[BiH4]
+[80Br-]
+[CeH3]
+[37ClH]
+[157Gd+3]
+[205Tl]
+[203Tl]
+[62Cu+]
+[64Cu+]
+[61Cu+]
+[37SH2]
+[30Si]
+[28Al]
+[19OH2]
+[8He]
+[6He]
+[153Pm]
+[209Bi]
+[66Zn+2]
+[10CH4]
+[191Ir]
+[66Cu]
+[16O+]
+[25O]
+[10c]
+[Co-3]
+[Sn@@]
+[17OH-]
+[206Po]
+[204Po]
+[202Po]
+[201Po]
+[200Po]
+[199Po]
+[198Po]
+[197Po]
+[196Po]
+[195Po]
+[194Po]
+[193Po]
+[192Po]
+[191Po]
+[190Po]
+[217Po]
+[BiH4-]
+[TeH4]
+[222Ra]
+[62Ga]
+[39Ar]
+[144Sm]
+[58Fe]
+[153Eu]
+[85Rb]
+[171Yb]
+[172Yb]
+[114Cd]
+[51Fe]
+[142Ce]
+[207Tl]
+[92Mo]
+[115Sn]
+[140Ce]
+[202Hg]
+[180W]
+[182W]
+[183W]
+[184W]
+[96Mo]
+[47Ti]
+[111Cd]
+[143Nd]
+[145Nd]
+[126Te]
+[128Te]
+[130Te]
+[185Re]
+[97Mo]
+[98Mo]
+[183Re]
+[52V]
+[80Se]
+[87Kr]
+[137Xe]
+[196Au]
+[146Ce]
+[88Kr]
+[51Ti]
+[138Xe]
+[112Cd]
+[116Sn]
+[120Sn]
+[28SiH3]
+[35S-]
+[15NH-]
+[13CH3+]
+[34S+]
+[34s]
+[SiH4-]
+[100Tc+5]
+[NiH2+2]
+[239Th]
+[186Lu]
+[AuH3]
+[I@@-]
+[XeH2]
+[B+]
+[16CH2]
+[8C]
+[TaH5]
+[FeH4-]
+[19C@H]
+[10NH]
+[FeH6-3]
+[22CH]
+[25N]
+[25N+]
+[25N-]
+[21CH2]
+[18cH]
+[113I]
+[ScH3]
+[30PH3]
+[43Ca+2]
+[41Ca+2]
+[106Cd]
+[122Sn]
+[18CH3]
+[58Co+3]
+[98Tc+4]
+[70Ge]
+[76Ge]
+[108Cd]
+[116Cd]
+[130Xe]
+[94Mo]
+[124Sn]
+[186Os]
+[188Os]
+[190Os]
+[192Os]
+[106Pd]
+[110Pd]
+[120Te]
+[132Ba]
+[134Ba]
+[136Ba]
+[136Ce]
+[138Ce]
+[156Dy]
+[158Dy]
+[160Dy]
+[163Dy]
+[162Er]
+[164Er]
+[167Er]
+[176Hf]
+[26Mg]
+[144Nd]
+[150Nd]
+[41K]
+[46Ti]
+[48Ti]
+[49Ti]
+[50Ti]
+[170Yb]
+[173Yb]
+[91Zr]
+[92Zr]
+[96Zr]
+[34S-]
+[CuH2-]
+[38Cl]
+[25Mg]
+[51V]
+[93Nb]
+[95Mo]
+[45Sc]
+[123Sb]
+[139La]
+[9Be]
+[99Y+3]
+[99Y]
+[156Ho]
+[67Zn]
+[144Ce+4]
+[210Tl]
+[42Ca]
+[54Fe]
+[193Ir]
+[92Nb]
+[141Cs]
+[52Cr]
+[35ClH]
+[46Ca]
+[139Cs]
+[65Cu]
+[71Ga]
+[60Ni]
+[16NH3]
+[148Nd]
+[72Ge]
+[161Dy]
+[49Ca]
+[43Ca]
+[8Be]
+[48Ca]
+[44Ca]
+[120Xe]
+[80Rb]
+[215At]
+[180Re]
+[146Sm]
+[19Ne]
+[74Kr]
+[134La]
+[76Kr]
+[219Fr]
+[121Xe]
+[220Fr]
+[216At]
+[223Ac]
+[218At]
+[37Ar]
+[135I]
+[110Cd]
+[94Tc+7]
+[86Y+3]
+[135I-]
+[15O-2]
+[151Eu+3]
+[161Tb+3]
+[197Hg+2]
+[109Cd+2]
+[191Os+4]
+[170Tm+3]
+[205Bi+3]
+[233U+4]
+[126Sb+3]
+[127Sb+3]
+[132Cs+]
+[136Eu+3]
+[136Eu]
+[125Sn+4]
+[175Yb+3]
+[100Mo]
+[22Ne]
+[13c-]
+[13NH4+]
+[17C]
+[9C]
+[31S]
+[31SH]
+[133I]
+[126I]
+[36SH]
+[30S]
+[32SH]
+[19CH2]
+[19c]
+[18c]
+[15F]
+[10C]
+[RuH-]
+[62Zn+2]
+[32ClH]
+[33ClH]
+[78BrH]
+[12Li+]
+[12Li]
+[233Ra]
+[68Ge+4]
+[44Sc+3]
+[91Y+3]
+[106Ru+3]
+[PoH2]
+[AtH]
+[55Fe]
+[233U]
+[210PoH2]
+[230Th]
+[228Th]
+[222Rn]
+[35SH2]
+[227Th]
+[192Ir]
+[133Xe]
+[81Kr]
+[95Zr]
+[240Pu]
+[54Mn]
+[103Ru]
+[95Nb]
+[109Cd]
+[141Ce]
+[85Kr]
+[110Ag]
+[58Co]
+[241Pu]
+[234Th]
+[140La]
+[63Ni]
+[152Eu]
+[132IH]
+[226Rn]
+[154Eu]
+[36ClH]
+[228Ac]
+[155Eu]
+[106Rh]
+[243Am]
+[227Ac]
+[243Cm]
+[236U]
+[144Pr]
+[232U]
+[32SH2]
+[88Y]
+[82BrH]
+[135IH]
+[242Cm]
+[115Cd]
+[242Pu]
+[46Sc]
+[56Mn]
+[234Pa]
+[41Ar]
+[147Nd]
+[187W]
+[151Sm]
+[59Ni]
+[233Pa]
+[52Mn]
+[94Nb]
+[219Rn]
+[236Pu]
+[13NH3]
+[93Zr]
+[51Cr+6]
+[TlH3]
+[123Xe]
+[160Tb]
+[170Tm]
+[182Ta]
+[175Yb]
+[93Mo]
+[143Ce]
+[191Os]
+[126IH]
+[48V]
+[113Cd]
+[47Sc]
+[181Hf]
+[185W]
+[143Pr]
+[191Pt]
+[181W]
+[33PH3]
+[97Ru]
+[97Tc]
+[111Ag]
+[169Er]
+[107Pd]
+[103Ru+2]
+[34SH2]
+[137Ce]
+[242Am]
+[117SnH2]
+[57Ni]
+[239U]
+[60Cu]
+[250Cf]
+[193Au]
+[69Zn]
+[55Co]
+[139Ce]
+[127Xe]
+[159Gd]
+[56Co]
+[177Hf]
+[244Pu]
+[38ClH]
+[142Pr]
+[199Hg]
+[179Hf]
+[178Hf]
+[237U]
+[156Eu]
+[157Eu]
+[105Ru]
+[171Tm]
+[199Au]
+[155Sm]
+[80BrH]
+[108Ag]
+[128IH]
+[48Sc]
+[45Ti]
+[176Lu]
+[121SnH2]
+[148Pm]
+[57Fe]
+[10BH3]
+[96Tc]
+[133IH]
+[143Pm]
+[105Rh]
+[130IH]
+[134IH]
+[131IH]
+[71Zn]
+[105Ag]
+[97Zr]
+[235Pu]
+[231Th]
+[109Pd]
+[93Y]
+[190Ir]
+[135Xe]
+[53Mn]
+[134Ce]
+[234Np]
+[240Am]
+[246Cf]
+[240Cm]
+[241Cm]
+[226Th]
+[39ClH]
+[229Th]
+[245Cm]
+[240U]
+[240Np]
+[249Cm]
+[243Pu]
+[145Pm]
+[199Pt]
+[246Bk]
+[193Pt]
+[230U]
+[250Cm]
+[44Ti]
+[175Hf]
+[254Fm]
+[255Fm]
+[257Fm]
+[92Y]
+[188Ir]
+[171Lu]
+[257Md]
+[247Bk]
+[121IH]
+[250Bk]
+[179Lu]
+[224Ac]
+[195Hg]
+[244Am]
+[246Pu]
+[194Au]
+[252Fm]
+[173Hf]
+[246Cm]
+[135Ce]
+[49Cr]
+[248Cf]
+[247Cm]
+[248Cm]
+[174Ta]
+[176Ta]
+[154Tb]
+[172Ta]
+[177Ta]
+[175Ta]
+[180Ta]
+[158Tb]
+[115Ag]
+[189Os]
+[251Cf]
+[145Pr]
+[147Pr]
+[76BrH]
+[102Rh]
+[238Np]
+[185Os]
+[246Am]
+[233Np]
+[166Dy]
+[254Es]
+[244Cf]
+[193Os]
+[245Am]
+[245Bk]
+[239Am]
+[238Am]
+[97Nb]
+[245Pu]
+[254Cf]
+[188W]
+[250Es]
+[251Es]
+[237Am]
+[182Hf]
+[258Md]
+[232Np]
+[238Cm]
+[60Fe]
+[109Pd+2]
+[234Pu]
+[141Ce+3]
+[136Nd]
+[136Pr]
+[173Ta]
+[110Ru]
+[147Tb]
+[253Fm]
+[139Nd]
+[178Re]
+[177Re]
+[200Au]
+[182Re]
+[156Tb]
+[155Tb]
+[157Tb]
+[161Tb]
+[161Ho]
+[167Tm]
+[173Lu]
+[179Ta]
+[171Er]
+[44Sc]
+[49Sc]
+[49V]
+[51Mn]
+[90Nb]
+[88Nb]
+[88Zr]
+[36SH2]
+[174Yb]
+[178Lu]
+[179W]
+[83BrH]
+[107Cd]
+[75BrH]
+[62Co]
+[48Cr]
+[63Zn]
+[102Ag]
+[154Sm]
+[168Er]
+[65Ni]
+[137La]
+[187Ir]
+[144Pm]
+[146Pm]
+[160Gd]
+[166Yb]
+[162Dy]
+[47V]
+[141Nd]
+[141Sm]
+[166Er]
+[150Sm]
+[146Eu]
+[149Eu]
+[174Lu]
+[17NH3]
+[102Ru]
+[170Hf]
+[188Pt]
+[61Ni]
+[56Ni]
+[149Gd]
+[151Gd]
+[141Pm]
+[147Gd]
+[146Gd]
+[161Er]
+[103Ag]
+[145Eu]
+[153Tb]
+[155Dy]
+[184Re]
+[180Os]
+[182Os]
+[186Pt]
+[181Os]
+[181Re]
+[151Tb]
+[178Ta]
+[178W]
+[189Pt]
+[194Hg]
+[145Sm]
+[150Tb]
+[132La]
+[158Gd]
+[104Ag]
+[193Hg]
+[94Ru]
+[137Pr]
+[155Ho]
+[117Cd]
+[99Ru]
+[146Nd]
+[218Rn]
+[95Y]
+[79Kr]
+[120IH]
+[138Pr]
+[100Pd]
+[166Tm]
+[90Mo]
+[151Nd]
+[231U]
+[138Nd]
+[89Nb]
+[98Nb]
+[162Ho]
+[142Sm]
+[186Ta]
+[104Tc]
+[184Ta]
+[185Ta]
+[170Er]
+[107Rh]
+[131La]
+[169Lu]
+[74BrH]
+[150Pm]
+[172Tm]
+[197Pt]
+[230Pu]
+[170Lu]
+[86Zr]
+[176W]
+[177W]
+[101Pd]
+[105Pd]
+[108Pd]
+[149Nd]
+[164Ho]
+[159Ho]
+[167Ho]
+[176Yb]
+[156Sm]
+[77BrH]
+[189Re]
+[99Rh]
+[100Rh]
+[151Pm]
+[232Pa]
+[228Pa]
+[230Pa]
+[66Ni]
+[194Os]
+[135La]
+[138La]
+[141La]
+[142La]
+[195Ir]
+[96Nb]
+[157Ho]
+[183Hf]
+[162Tm]
+[172Er]
+[148Eu]
+[150Eu]
+[15CH4]
+[89Kr]
+[143La]
+[58Ni]
+[61Co]
+[158Eu]
+[165Er]
+[167Yb]
+[173Tm]
+[175Tm]
+[172Hf]
+[172Lu]
+[93Tc]
+[177Yb]
+[124IH]
+[194Ir]
+[147Eu]
+[101Mo]
+[180Hf]
+[189Ir]
+[87Y]
+[43Sc]
+[195Au]
+[112Ag]
+[84BrH]
+[106Ag]
+[109Ag]
+[101Rh]
+[162Yb]
+[228Rn]
+[139Pr]
+[94Y]
+[201Au]
+[40PH3]
+[110Ag+]
+[104Cd]
+[133Ba+2]
+[226Ac]
+[145Gd]
+[186Ir]
+[184Ir]
+[224Rn]
+[185Ir]
+[182Ir]
+[184Hf]
+[200Pt]
+[227Pa]
+[178Yb]
+[72Br-]
+[72BrH]
+[248Am]
+[238Th]
+[161Gd]
+[35S-2]
+[107Ag]
+[FeH6-4]
+[89Sr]
+[SnH3-]
+[SeH3]
+[TeH3+]
+[SbH4+]
+[AsH4+]
+[4He]
+[AsH3-]
+[1HH]
+[3H+]
+[82Rb]
+[85Sr]
+[90Sr]
+[137Cs]
+[133Ba]
+[131Cs]
+[SbH5]
+[224Ra]
+[22Na]
+[210Bi]
+[214Bi]
+[228Ra]
+[127Sb]
+[136Cs]
+[125Sb]
+[134Cs]
+[140Ba]
+[45Ca]
+[206Pb]
+[207Pb]
+[24Na]
+[86Rb]
+[212Bi]
+[208Pb]
+[124Sb]
+[204Pb]
+[44K]
+[129Te]
+[113Sn]
+[204Tl]
+[87Sr]
+[208Tl]
+[87Rb]
+[47Ca]
+[135Cs]
+[216Po]
+[137Ba]
+[207Bi]
+[212Po]
+[79Se]
+[223Ra]
+[86Sr]
+[122Sb]
+[26Al]
+[32Si]
+[126Sn]
+[225Ra]
+[114In]
+[72Ga]
+[132Te]
+[10Be]
+[125Sn]
+[73As]
+[206Bi]
+[117Sn]
+[40Ca]
+[41Ca]
+[89Rb]
+[116In]
+[129Sb]
+[91Sr]
+[71Ge]
+[139Ba]
+[69Ga]
+[120Sb]
+[121Sn]
+[123Sn]
+[131Te]
+[77Ge]
+[135Ba]
+[82Sr]
+[43K]
+[131Ba]
+[92Sr]
+[88Rb]
+[129Cs]
+[144Cs]
+[127Cs]
+[200Tl]
+[202Tl]
+[141Ba]
+[117Sb]
+[116Sb]
+[78As]
+[131Sb]
+[126Sb]
+[128Sb]
+[130Sb]
+[67Ge]
+[68Ge]
+[78Ge]
+[66Ge]
+[223Fr]
+[132Cs]
+[125Cs]
+[138Cs]
+[133Te]
+[84Rb]
+[83Rb]
+[81Rb]
+[142Ba]
+[200Bi]
+[115Sb]
+[194Tl]
+[70Se]
+[112In]
+[118Sb]
+[70Ga]
+[27Mg]
+[202Bi]
+[83Se]
+[9Li]
+[69As]
+[79Rb]
+[81Sr]
+[83Sr]
+[78Se]
+[109In]
+[29Al]
+[118Sn]
+[117In]
+[119Sb]
+[114Sn]
+[138Ba]
+[69Ge]
+[73Ga]
+[74Ge]
+[206Tl]
+[199Tl]
+[130Cs]
+[28Mg]
+[116Te]
+[112Sn]
+[126Ba]
+[211Bi]
+[81Se]
+[127Sn]
+[143Cs]
+[134Te]
+[80Sr]
+[45K]
+[215Po]
+[207Po]
+[111Sn]
+[211Po]
+[128Ba]
+[198Tl]
+[227Ra]
+[213Po]
+[220Ra]
+[128Sn]
+[203Po]
+[205Po]
+[65Ga]
+[197Tl]
+[88Sr]
+[110In]
+[31Si]
+[201Bi]
+[121Te]
+[205Bi]
+[203Bi]
+[195Tl]
+[209Tl]
+[110Sn]
+[222Fr]
+[207At]
+[119In]
+[As@]
+[129IH]
+[157Dy]
+[111IH]
+[230Ra]
+[144Pr+3]
+[SiH3+]
+[3He]
+[AsH5]
+[72Se]
+[95Tc]
+[103Pd]
+[121Sn+2]
+[211Rn]
+[38SH2]
+[127IH]
+[74Br-]
+[133I-]
+[100Tc+4]
+[100Tc]
+[36Cl-]
+[89Y+3]
+[104Rh]
+[152Sm]
+[226Ra]
+[19FH]
+[104Pd]
+[148Gd]
+[157Lu]
+[33SH2]
+[121I-]
+[17FH]
+[71Se]
+[157Sm]
+[148Tb]
+[164Dy]
+[15OH2]
+[15O+]
+[39K]
+[40Ar]
+[50Cr+3]
+[50Cr]
+[52Ti]
+[103Pd+2]
+[130Ba]
+[142Pm]
+[153Gd+3]
+[151Eu]
+[103Rh]
+[124Xe]
+[152Tb]
+[17OH2]
+[20Ne]
+[52Fe]
+[94Zr+4]
+[94Zr]
+[149Pr]
+[16OH2]
+[53Cr+6]
+[53Cr]
+[81Br-]
+[112Pd]
+[125Xe]
+[155Gd]
+[157Gd]
+[168Yb]
+[184Os]
+[166Tb]
+[221Fr]
+[212Ra]
+[75Br-]
+[79Br-]
+[113Ag]
+[23Na]
+[34Cl-]
+[34ClH]
+[38Cl-]
+[56Fe]
+[68Cu]
+[77Br-]
+[90Zr+4]
+[90Zr]
+[102Pd]
+[154Eu+3]
+[57Mn]
+[165Tm]
+[152Dy]
+[217At]
+[77se]
+[13cH-]
+[122Te]
+[156Gd]
+[124Te]
+[53Ni]
+[131Xe]
+[174Hf+4]
+[174Hf]
+[76Se]
+[168Tm]
+[167Dy]
+[154Gd]
+[95Ru]
+[210At]
+[85Br]
+[59Co]
+[122Xe]
+[27Al]
+[54Cr]
+[198Hg]
+[85Rb+]
+[214Tl]
+[229Rn]
+[218Pb]
+[218Bi]
+[167Tm+3]
+[18o+]
+[P@@H+]
+[P@H+]
+[13N+]
+[212Pb+2]
+[217Bi]
+[249Cf+2]
+[18OH3+]
+[90Sr-]
+[Cf+3]
+[200Hg]
+[86Tc]
+[141Pr+3]
+[141Pr]
+[16nH]
+[14NH4+]
+[132Xe]
+[83Kr]
+[70Zn+2]
+[137Ba+2]
+[36Ar]
+[38Ar]
+[21Ne]
+[126Xe]
+[136Xe]
+[128Xe]
+[134Xe]
+[84Kr]
+[86Kr]
+[78Kr]
+[80Kr]
+[82Kr]
+[67Zn+2]
+[65Cu+2]
+[110Te]
+[58Fe+3]
+[142Nd]
+[38K]
+[198Au+3]
+[122IH]
+[38PH3]
+[130I-]
+[40K+]
+[38K+]
+[28Mg+2]
+[208Tl+]
+[13OH2]
+[198Bi]
+[192Bi]
+[194Bi]
+[196Bi]
+[132I-]
+[83Sr+2]
+[169Er+3]
+[122I-]
+[120I-]
+[92Sr+2]
+[126I-]
+[24Mg]
+[84Sr]
+[118Pd+2]
+[118Pd]
+[AsH4]
+[127I-]
+[9C-]
+[11CH3+]
+[17B]
+[7B]
+[4HH]
+[18C-]
+[22CH3-]
+[22CH4]
+[17C-]
+[15CH3]
+[16CH3]
+[11NH3]
+[21NH3]
+[11N-]
+[11NH]
+[16CH]
+[17CH2]
+[99Ru+2]
+[181Ta+2]
+[181Ta]
+[20CH]
+[32PH2]
+[55Fe+2]
+[SH3]
+[S@H]
+[Mn-]
+[IH4]
+[ThH]
+[GaH-]
+[BiH+]
+[EuH2]
+[FeH4-3]
+[FeH6]
+[IH5]
+[NiH+]
+[SrH2]
+[VH4]
+[YH3]
+[seH+]
+
diff --git a/models/smi_ted/finetune/smi_ted_light/clintox/run_finetune_clintox.sh b/models/smi_ted/finetune/smi_ted_light/clintox/run_finetune_clintox.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c6ffd279c957bef21f52a009c2e363c26c630ea6
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/clintox/run_finetune_clintox.sh
@@ -0,0 +1,23 @@
+python ../../finetune_classification_multitask.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 100 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/clintox' \
+ --dataset_name clintox \
+ --checkpoints_folder './checkpoints_clintox' \
+ --loss_fn 'bceloss' \
+ --target_metric 'roc-auc' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/esol/run_finetune_esol.sh b/models/smi_ted/finetune/smi_ted_light/esol/run_finetune_esol.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dd573cf4bac8c917b001eb157ce160ef4fca9d72
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/esol/run_finetune_esol.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/esol' \
+ --dataset_name esol \
+ --measure_name 'measured log solubility in mols per litre' \
+ --checkpoints_folder './checkpoints_esol' \
+ --loss_fn 'rmse' \
+ --target_metric 'rmse' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/freesolv/run_finetune_freesolv.sh b/models/smi_ted/finetune/smi_ted_light/freesolv/run_finetune_freesolv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..43c1f321e357f873f51460e428e605ca30f2c3e9
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/freesolv/run_finetune_freesolv.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/freesolv' \
+ --dataset_name freesolv \
+ --measure_name 'expt' \
+ --checkpoints_folder './checkpoints_freesolv' \
+ --loss_fn 'rmse' \
+ --target_metric 'rmse' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/hiv/run_finetune_hiv.sh b/models/smi_ted/finetune/smi_ted_light/hiv/run_finetune_hiv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8191fb6960d793e184c5237bb87db0223ee6c888
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/hiv/run_finetune_hiv.sh
@@ -0,0 +1,25 @@
+python ../../finetune_classification.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 1e-7 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/hiv' \
+ --dataset_name hiv \
+ --measure_name 'HIV_active' \
+ --checkpoints_folder './checkpoints_hiv_1e-7' \
+ --loss_fn 'crossentropy' \
+ --target_metric 'roc-auc' \
+ --n_output 2 \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/lipo/run_finetune_lipo.sh b/models/smi_ted/finetune/smi_ted_light/lipo/run_finetune_lipo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fe4932283f31494e03808982800a8d205a4485ec
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/lipo/run_finetune_lipo.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/lipophilicity' \
+ --dataset_name lipophilicity \
+ --measure_name 'y' \
+ --checkpoints_folder './checkpoints_lipophilicity' \
+ --loss_fn 'rmse' \
+ --target_metric 'rmse' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/load.py b/models/smi_ted/finetune/smi_ted_light/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f3aeea2cfe802ae32706cc283f980f2e74ec6a0
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/load.py
@@ -0,0 +1,504 @@
+PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+# Transformers
+from fast_transformers.attention import AttentionLayer
+from fast_transformers.events import QKVEvent
+from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer
+from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder
+from fast_transformers.builders.attention_builders import AttentionBuilder
+from fast_transformers.feature_maps import GeneralizedRandomFeatures
+from fast_transformers.masking import LengthMask
+from transformers import BertTokenizer
+
+# Data
+import numpy as np
+import pandas as pd
+
+# Standard library
+from functools import partial
+import regex as re
+import random
+import os
+import gc
+from tqdm import tqdm
+tqdm.pandas()
+
+
+class MolTranBertTokenizer(BertTokenizer):
+ def __init__(self, vocab_file: str = '',
+ do_lower_case=False,
+ unk_token='',
+ sep_token='',
+ pad_token='',
+ cls_token='',
+ mask_token='',
+ **kwargs):
+ super().__init__(vocab_file,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs)
+
+ self.regex_tokenizer = re.compile(PATTERN)
+ self.wordpiece_tokenizer = None
+ self.basic_tokenizer = None
+ with open(vocab_file) as f:
+ self.padding_idx = f.readlines().index(pad_token+'\n')
+
+ def _tokenize(self, text):
+ split_tokens = self.regex_tokenizer.findall(text)
+ return split_tokens
+
+ def convert_idx_to_tokens(self, idx_tensor):
+ tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
+ return tokens
+
+ def convert_tokens_to_string(self, tokens):
+ stopwords = ['', '']
+ clean_tokens = [word for word in tokens if word not in stopwords]
+ out_string = ''.join(clean_tokens)
+ return out_string
+
+ def get_padding_idx(self):
+ return self.padding_idx
+
+ def idx_to_smiles(self, torch_model, idx):
+ '''Convert tokens idx back to SMILES text'''
+ rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
+ flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
+ decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)
+ return decoded_smiles
+
+
+## Transformer layers
+class RotaryEmbedding(torch.nn.Module):
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = 0
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self.cos_cached = emb.cos()[None,:, None, :]
+ self.sin_cached = emb.sin()[None,:, None, :]
+
+ return self.cos_cached, self.sin_cached
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+@torch.jit.script
+def apply_rotary_pos_emb(q, k, cos, sin):
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+class RotateAttentionLayer(AttentionLayer):
+ """Rotate attention layer inherits from fast_transformer attention layer.
+ The only thing added is an Embedding encoding, for more information
+ on the attention layer see the fast_transformers code
+ """
+ def __init__(self, attention, d_model, n_heads, d_keys=None,
+ d_values=None, event_dispatcher=""):
+ super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys,
+ d_values=d_values, event_dispatcher=event_dispatcher)
+
+ self.rotaryemb = RotaryEmbedding(d_keys)
+ print('Using Rotation Embedding')
+
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
+ key_lengths):
+ """
+ Using the same frame work as the fast_Transformers attention layer
+ but injecting rotary information to the queries and the keys
+ after the keys and queries are projected.
+ In the argument description we make use of the following sizes
+ - N: the batch size
+ - L: The maximum length of the queries
+ - S: The maximum length of the keys (the actual length per sequence
+ is given by the length mask)
+ - D: The input feature dimensionality passed in the constructor as
+ 'd_model'
+ Arguments
+ ---------
+ queries: (N, L, D) The tensor containing the queries
+ keys: (N, S, D) The tensor containing the keys
+ values: (N, S, D) The tensor containing the values
+ attn_mask: An implementation of BaseMask that encodes where each
+ query can attend to
+ query_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ key_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ Returns
+ -------
+ The new value for each query as a tensor of shape (N, L, D).
+ """
+ # Extract the dimensions into local variables
+ N, L, _ = queries.shape
+ _, S, _ = keys.shape
+ H = self.n_heads
+
+ # Project the queries/keys/values
+ queries = self.query_projection(queries).view(N, L, H, -1)
+ keys = self.key_projection(keys).view(N, S, H, -1)
+ cos, sin = self.rotaryemb(queries)
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+ values = self.value_projection(values).view(N, S, H, -1)
+ # Let the world know of the qkv
+ self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+
+ # Compute the attention
+ new_values = self.inner_attention(
+ queries,
+ keys,
+ values,
+ attn_mask,
+ query_lengths,
+ key_lengths
+ ).view(N, L, -1)
+
+ # Project the output and return
+ return self.out_projection(new_values)
+
+class RotateEncoderBuilder(BaseTransformerEncoderBuilder):
+ """Build a batch transformer encoder with Relative Rotary embeddings
+ for training or processing of sequences all elements at a time.
+ Example usage:
+ builder = RotateEncoderBuilder()
+ builder.n_layers = 12
+ builder.n_heads = 8
+ builder.feed_forward_dimensions = 1024
+ builder.query_dimensions = 64
+ builder.value_dimensions = 64
+ builder.dropout = 0.1
+ builder.attention_dropout = 0.1
+ builder.attention_type = "linear"
+ transformer = builder.get()
+ """
+ def _get_attention_builder(self):
+ """Return an instance of the appropriate attention builder."""
+ return AttentionBuilder()
+
+ def _get_attention_layer_class(self):
+ """Return the class for the layer that projects queries keys and
+ values."""
+ return RotateAttentionLayer
+
+ def _get_encoder_class(self):
+ """Return the class for the transformer encoder."""
+ return TransformerEncoder
+
+ def _get_encoder_layer_class(self):
+ """Return the class for the transformer encoder layer."""
+ return TransformerEncoderLayer
+
+
+class AutoEncoderLayer(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.encoder = self.Encoder(feature_size, latent_size)
+ self.decoder = self.Decoder(feature_size, latent_size)
+
+ class Encoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(feature_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.lat = nn.Linear(latent_size, latent_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.lat.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.lat(x)
+ return x # -> (N, D)
+
+ class Decoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(latent_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.rec = nn.Linear(latent_size, feature_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.rec.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.rec(x)
+ return x # -> (N, L*D)
+
+
+class LangLayer(nn.Module):
+
+ def __init__(self, n_embd, n_vocab):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.embed = nn.Linear(n_embd, n_embd)
+ self.ln_f = nn.LayerNorm(n_embd)
+ self.head = nn.Linear(n_embd, n_vocab, bias=False)
+
+ def forward(self, tensor):
+ if self.is_cuda_available:
+ self.embed.cuda()
+ self.ln_f.cuda()
+ self.head.cuda()
+ tensor = tensor.cuda()
+ tensor = self.embed(tensor)
+ tensor = F.gelu(tensor)
+ tensor = self.ln_f(tensor)
+ tensor = self.head(tensor)
+ return tensor
+
+
+class Net(nn.Module):
+
+ def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2):
+ super().__init__()
+ self.desc_skip_connection = True
+ self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.relu1 = nn.GELU()
+ self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout2 = nn.Dropout(dropout)
+ self.relu2 = nn.GELU()
+ self.final = nn.Linear(smiles_embed_dim, n_output)
+
+ def forward(self, smiles_emb, multitask=False):
+ x_out = self.fc1(smiles_emb)
+ x_out = self.dropout1(x_out)
+ x_out = self.relu1(x_out)
+
+ if self.desc_skip_connection is True:
+ x_out = x_out + smiles_emb
+
+ z = self.fc2(x_out)
+ z = self.dropout2(z)
+ z = self.relu2(z)
+ if self.desc_skip_connection is True:
+ z = self.final(z + x_out)
+ else:
+ z = self.final(z)
+
+ if multitask:
+ return F.sigmoid(z)
+ return z
+
+
+class MoLEncoder(nn.Module):
+
+ def __init__(self, config, n_vocab, eval=False):
+ super(MoLEncoder, self).__init__()
+
+ # embeddings
+ self.config = config
+ self.tok_emb = nn.Embedding(n_vocab, config['n_embd'])
+ self.drop = nn.Dropout(config['d_dropout'])
+
+ # transformer
+ builder = RotateEncoderBuilder.from_kwargs(
+ n_layers=config['n_layer'],
+ n_heads=config['n_head'],
+ query_dimensions=config['n_embd']//config['n_head'],
+ value_dimensions=config['n_embd']//config['n_head'],
+ feed_forward_dimensions=config['n_embd'],
+ attention_type='linear',
+ # unless we do deterministic_eval here, we will have random outputs
+ feature_map=partial(GeneralizedRandomFeatures,
+ n_dims=config['num_feats'],
+ deterministic_eval=eval),
+ activation='gelu'
+ )
+ self.blocks = builder.get()
+
+ # classification
+ self.lang_model = LangLayer(config['n_embd'], n_vocab)
+
+
+class MoLDecoder(nn.Module):
+
+ def __init__(self, n_vocab, max_len, n_embd, n_gpu=None):
+ super(MoLDecoder, self).__init__()
+
+ self.max_len = max_len
+ self.n_embd = n_embd
+ self.n_gpu = n_gpu
+ self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd)
+ self.lang_model = LangLayer(n_embd, n_vocab)
+
+
+class Smi_ted(nn.Module):
+ """materials.smi-ted-Light 289M Parameters"""
+
+ def __init__(self, tokenizer, config=None, eval=False):
+ super(Smi_ted, self).__init__()
+
+ # configuration
+ self.config = config
+ self.tokenizer = tokenizer
+ self.padding_idx = tokenizer.get_padding_idx()
+ self.n_vocab = len(self.tokenizer.vocab)
+ self.is_cuda_available = torch.cuda.is_available()
+
+ # instantiate modules
+ if self.config:
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
+ self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
+ self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
+
+ def load_checkpoint(self, ckpt_path, n_output, eval=False):
+ # load checkpoint file
+ checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+
+ # load hyparameters
+ self.config = checkpoint['hparams']
+ self.max_len = self.config['max_len']
+ self.n_embd = self.config['n_embd']
+ self._set_seed(self.config['seed'])
+
+ # instantiate modules
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
+ self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
+ self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
+
+ # load weights
+ if 'state_dict' in checkpoint:
+ if isinstance(checkpoint['state_dict'], list):
+ self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False)
+ self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False)
+ else:
+ self.load_state_dict(checkpoint['state_dict'], strict=False)
+ elif 'MODEL_STATE' in checkpoint:
+ self.load_state_dict(checkpoint['MODEL_STATE'], strict=False)
+
+ # load RNG states each time the model and states are loaded from checkpoint
+ if 'rng' in self.config:
+ rng = self.config['rng']
+ for key, value in rng.items():
+ if key =='torch_state':
+ torch.set_rng_state(value.cpu())
+ elif key =='cuda_state':
+ torch.cuda.set_rng_state(value.cpu())
+ elif key =='numpy_state':
+ np.random.set_state(value)
+ elif key =='python_state':
+ random.setstate(value)
+ else:
+ print('unrecognized state')
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_seed(self, value):
+ print('Random Seed:', value)
+ random.seed(value)
+ torch.manual_seed(value)
+ torch.cuda.manual_seed(value)
+ torch.cuda.manual_seed_all(value)
+ np.random.seed(value)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def tokenize(self, smiles):
+ """Tokenize a string into tokens."""
+ if isinstance(smiles, str):
+ batch = [smiles]
+ else:
+ batch = smiles
+
+ tokens = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ max_length=self.max_len,
+ )
+
+ idx = tokens['input_ids'].clone().detach()
+ mask = tokens['attention_mask'].clone().detach()
+
+ if self.is_cuda_available:
+ return idx.cuda(), mask.cuda()
+
+ return idx, mask
+
+ def extract_embeddings(self, smiles):
+ """Extract token and SMILES embeddings."""
+ if self.is_cuda_available:
+ self.encoder.cuda()
+ self.decoder.cuda()
+
+ # tokenizer
+ idx, mask = self.tokenize(smiles)
+
+ # transformer encoder
+ x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.encoder.drop(x)
+ x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1]))
+
+ # add padding
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ mask_embeddings = (token_embeddings * input_mask_expanded)
+ token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0)
+
+ # aggregate token embeddings (similar to mean pooling)
+ # CAUTION: use the embeddings from the autoencoder.
+ smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd))
+
+ return smiles_embeddings
+
+ def __str__(self):
+ return 'smi-ted-Light'
+
+
+def load_smi_ted(folder="./smi_ted_light",
+ ckpt_filename="smi-ted-Light_40.pt",
+ vocab_filename="bert_vocab_curated.txt",
+ n_output=1,
+ eval=False
+ ):
+ tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
+ model = Smi_ted(tokenizer)
+ model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval)
+ print('Vocab size:', len(tokenizer.vocab))
+ print(f'[FINETUNE MODE - {str(model)}]')
+ return model
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..21d42f953b398fe05a6edb592d4d1da9275ec844
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E1-CAM' \
+ --checkpoints_folder './checkpoints_QM8-E1-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..79f7fbade5c91e7d61bd94121423a76b3db600c8
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E1-CC2' \
+ --checkpoints_folder './checkpoints_QM8-E1-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0128628560ad2b08b0959deaeff7ddb1d4d01239
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E1-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E1-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-E1-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..db7eec5338116d1e213a9f81a4704df525105701
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E2-CAM' \
+ --checkpoints_folder './checkpoints_QM8-E2-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f769c4b78585d0f6a3b64e56e4b634ee1a40c3db
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E2-CC2' \
+ --checkpoints_folder './checkpoints_QM8-E2-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..39abacf5d6b2e917b82b45c357363f045dea54c1
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_E2-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'E2-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-E2-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..91ee475d98f5b0dc6719000e1977635f2636e22e
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f1-CAM' \
+ --checkpoints_folder './checkpoints_QM8-f1-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9603905036f6614a5de8a8a71be259d099c50a6e
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f1-CC2' \
+ --checkpoints_folder './checkpoints_QM8-f1-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..874b9d0a357d058a94796c29dc10adf02befc568
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f1-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f1-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-f1-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CAM_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CAM_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..40832c450e3e3fc9b98c9715de8a5a15e5509ded
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CAM_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f2-CAM' \
+ --checkpoints_folder './checkpoints_QM8-f2-CAM' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CC2_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CC2_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e7356e62ec55f32d4902bb8a4a3896ac7849e748
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-CC2_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f2-CC2' \
+ --checkpoints_folder './checkpoints_QM8-f2-CC2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-PBE0_set.sh b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-PBE0_set.sh
new file mode 100644
index 0000000000000000000000000000000000000000..855c5223471ddde0940f05ac6471e9481c727a8d
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm8/run_finetune_f2-PBE0_set.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 16 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-6 \
+ --lr_multiplier 1 \
+ --max_epochs 720 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm8' \
+ --dataset_name qm8 \
+ --measure_name 'f2-PBE0' \
+ --checkpoints_folder './checkpoints_QM8-f2-PBE0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_alpha.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_alpha.sh
new file mode 100644
index 0000000000000000000000000000000000000000..62111cac55825827603340cc5d5bc45218339ad5
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_alpha.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'alpha' \
+ --checkpoints_folder './checkpoints_QM9-alpha' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_cv.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_cv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f499840a8fea8821d7703efe79bb07661b5767df
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_cv.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'cv' \
+ --checkpoints_folder './checkpoints_QM9-cv' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_g298.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_g298.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fd001dda63f72fc301b11072a06d8dc62e54b5e2
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_g298.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'g298' \
+ --checkpoints_folder './checkpoints_QM9-g298' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_gap.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_gap.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8031170234e71cc55af842b6231fd448bdf34b99
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_gap.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'gap' \
+ --checkpoints_folder './checkpoints_QM9-gap' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_h298.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_h298.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c1945ad80e05d6916ec35c264361b938cd1333f0
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_h298.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'h298' \
+ --checkpoints_folder './checkpoints_QM9-h298' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_homo.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_homo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cb9f4a6fab256465e60ee21530957198e8160f58
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_homo.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'homo' \
+ --checkpoints_folder './checkpoints_QM9-homo' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_lumo.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_lumo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d012bd7167b4997290a0ee0659f988748e9f83e7
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_lumo.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'lumo' \
+ --checkpoints_folder './checkpoints_QM9-lumo' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_mu.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_mu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ac604c0c050401deace608447fb1f7089a4af4b6
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_mu.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'mu' \
+ --checkpoints_folder './checkpoints_QM9-mu' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_r2.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_r2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d688eb79d433dcda120bda929fbfd410d6c193bb
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_r2.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'r2' \
+ --checkpoints_folder './checkpoints_QM9-r2' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u0.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u0.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1ff6506190997a57fc0ce9655e9e358984204ff5
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u0.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'u0' \
+ --checkpoints_folder './checkpoints_QM9-u0' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u298.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u298.sh
new file mode 100644
index 0000000000000000000000000000000000000000..880c6a8f6c0a351218efe17bbdcd2581bd0dd6f8
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_u298.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'u298' \
+ --checkpoints_folder './checkpoints_QM9-u298' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_zpve.sh b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_zpve.sh
new file mode 100644
index 0000000000000000000000000000000000000000..45adaf0769fddb8e8761d4ffe24c26a47b3940a0
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/qm9/run_finetune_zpve.sh
@@ -0,0 +1,24 @@
+python ../../finetune_regression.py \
+ --n_batch 128 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/qm9' \
+ --dataset_name qm9 \
+ --measure_name 'zpve' \
+ --checkpoints_folder './checkpoints_QM9-zpve' \
+ --loss_fn 'mae' \
+ --target_metric 'mae' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 1 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/sider/run_finetune_sider.sh b/models/smi_ted/finetune/smi_ted_light/sider/run_finetune_sider.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bb9d03d7d920bd3c54f97343f868b1859cea1e7c
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/sider/run_finetune_sider.sh
@@ -0,0 +1,23 @@
+python ../../finetune_classification_multitask.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 500 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/sider' \
+ --dataset_name sider \
+ --checkpoints_folder './checkpoints_sider' \
+ --loss_fn 'bceloss' \
+ --target_metric 'roc-auc' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/smi_ted_light/tox21/run_finetune_tox21.sh b/models/smi_ted/finetune/smi_ted_light/tox21/run_finetune_tox21.sh
new file mode 100644
index 0000000000000000000000000000000000000000..46a37d65e3d05ae0c95faed776c5cb90829726a5
--- /dev/null
+++ b/models/smi_ted/finetune/smi_ted_light/tox21/run_finetune_tox21.sh
@@ -0,0 +1,23 @@
+python ../../finetune_classification_multitask.py \
+ --n_batch 32 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.1 \
+ --dropout 0.1 \
+ --lr_start 3e-5 \
+ --lr_multiplier 1 \
+ --max_epochs 100 \
+ --num_feats 32 \
+ --smi_ted_version 'v1' \
+ --model_path '../' \
+ --ckpt_filename 'smi-ted-Light_40.pt' \
+ --data_root '../../moleculenet/tox21' \
+ --dataset_name tox21 \
+ --checkpoints_folder './checkpoints_tox21' \
+ --loss_fn 'bceloss' \
+ --target_metric 'roc-auc' \
+ --save_ckpt 1 \
+ --start_seed 0 \
+ --train_decoder 0 \
\ No newline at end of file
diff --git a/models/smi_ted/finetune/trainers.py b/models/smi_ted/finetune/trainers.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db4917c03d77b5af472b3727d209bf701336ca5
--- /dev/null
+++ b/models/smi_ted/finetune/trainers.py
@@ -0,0 +1,591 @@
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from utils import CustomDataset, CustomDatasetMultitask, RMSELoss, normalize_smiles
+
+# Data
+import pandas as pd
+import numpy as np
+
+# Standard library
+import random
+import args
+import os
+import shutil
+from tqdm import tqdm
+
+# Machine Learning
+from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score, roc_auc_score, roc_curve, auc, precision_recall_curve
+from scipy import stats
+from utils import RMSE, sensitivity, specificity
+
+
+class Trainer:
+
+ def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
+ target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
+ # data
+ self.df_train = raw_data[0]
+ self.df_valid = raw_data[1]
+ self.df_test = raw_data[2]
+ self.dataset_name = dataset_name
+ self.target = target
+ self.batch_size = batch_size
+ self.hparams = hparams
+ self._prepare_data()
+
+ # config
+ self.target_metric = target_metric
+ self.seed = seed
+ self.smi_ted_version = smi_ted_version
+ self.checkpoints_folder = checkpoints_folder
+ self.restart_filename = restart_filename
+ self.start_epoch = 1
+ self.save_every_epoch = save_every_epoch
+ self.save_ckpt = save_ckpt
+ self.device = device
+ self.best_vloss = float('inf')
+ self.last_filename = None
+ self._set_seed(seed)
+
+ def _prepare_data(self):
+ # normalize dataset
+ self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles)
+ self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles)
+ self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles)
+
+ self.df_train = self.df_train.dropna(subset=['canon_smiles'])
+ self.df_valid = self.df_valid.dropna(subset=['canon_smiles'])
+ self.df_test = self.df_test.dropna(subset=['canon_smiles'])
+
+ # create dataloader
+ self.train_loader = DataLoader(
+ CustomDataset(self.df_train, self.target),
+ batch_size=self.batch_size,
+ shuffle=True,
+ pin_memory=True
+ )
+ self.valid_loader = DataLoader(
+ CustomDataset(self.df_valid, self.target),
+ batch_size=self.batch_size,
+ shuffle=False,
+ pin_memory=True
+ )
+ self.test_loader = DataLoader(
+ CustomDataset(self.df_test, self.target),
+ batch_size=self.batch_size,
+ shuffle=False,
+ pin_memory=True
+ )
+
+ def compile(self, model, optimizer, loss_fn):
+ self.model = model
+ self.optimizer = optimizer
+ self.loss_fn = loss_fn
+ self._print_configuration()
+ if self.restart_filename:
+ self._load_checkpoint(self.restart_filename)
+ print('Checkpoint restored!')
+
+ def fit(self, max_epochs=500):
+ for epoch in range(self.start_epoch, max_epochs+1):
+ print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
+
+ # training
+ self.model.to(self.device)
+ self.model.train()
+ train_loss = self._train_one_epoch()
+
+ # validation
+ self.model.eval()
+ val_preds, val_loss, val_metrics = self._validate_one_epoch(self.valid_loader)
+ for m in val_metrics.keys():
+ print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
+
+ ############################### Save Finetune checkpoint #######################################
+ if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt:
+ # remove old checkpoint
+ if (self.last_filename != None) and (not self.save_every_epoch):
+ os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
+
+ # filename
+ model_name = f'{str(self.model)}-Finetune'
+ self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"
+
+ # update best loss
+ self.best_vloss = val_loss
+
+ # save checkpoint
+ print('Saving checkpoint...')
+ self._save_checkpoint(epoch, self.last_filename)
+
+ def evaluate(self, verbose=True):
+ if verbose:
+ print("\n=====Test Evaluation=====")
+
+ if self.smi_ted_version == 'v1':
+ import smi_ted_light.load as load
+ elif self.smi_ted_version == 'v2':
+ import smi_ted_large.load as load
+ else:
+ raise Exception('Please, specify the SMI-TED version: `v1` or `v2`.')
+
+ # copy vocabulary to checkpoint folder
+ if not os.path.exists(os.path.join(self.checkpoints_folder, 'bert_vocab_curated.txt')):
+ smi_ted_path = os.path.dirname(load.__file__)
+ shutil.copy(os.path.join(smi_ted_path, 'bert_vocab_curated.txt'), self.checkpoints_folder)
+
+ # load model for inference
+ model_inf = load.load_smi_ted(
+ folder=self.checkpoints_folder,
+ ckpt_filename=self.last_filename,
+ eval=True,
+ ).to(self.device)
+
+ # set model evaluation mode
+ model_inf.eval()
+
+ # evaluate on test set
+ tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf)
+
+ if verbose:
+ # show metrics
+ for m in tst_metrics.keys():
+ print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
+
+ # save predictions
+ pd.DataFrame(tst_preds).to_csv(
+ os.path.join(
+ self.checkpoints_folder,
+ f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
+ index=False
+ )
+
+ def _train_one_epoch(self):
+ raise NotImplementedError
+
+ def _validate_one_epoch(self, data_loader, model=None):
+ raise NotImplementedError
+
+ def _print_configuration(self):
+ print('----Finetune information----')
+ print('Dataset:\t', self.dataset_name)
+ print('Target:\t\t', self.target)
+ print('Batch size:\t', self.batch_size)
+ print('LR:\t\t', self._get_lr())
+ print('Device:\t\t', self.device)
+ print('Optimizer:\t', self.optimizer.__class__.__name__)
+ print('Loss function:\t', self.loss_fn.__class__.__name__)
+ print('Seed:\t\t', self.seed)
+ print('Train size:\t', self.df_train.shape[0])
+ print('Valid size:\t', self.df_valid.shape[0])
+ print('Test size:\t', self.df_test.shape[0])
+
+ def _load_checkpoint(self, filename):
+ ckpt_path = os.path.join(self.checkpoints_folder, filename)
+ ckpt_dict = torch.load(ckpt_path, map_location='cpu')
+ self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
+ self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
+ self.best_vloss = ckpt_dict['finetune_info']['best_vloss']
+
+ def _save_checkpoint(self, current_epoch, filename):
+ if not os.path.exists(self.checkpoints_folder):
+ os.makedirs(self.checkpoints_folder)
+
+ ckpt_dict = {
+ 'MODEL_STATE': self.model.state_dict(),
+ 'EPOCHS_RUN': current_epoch,
+ 'hparams': vars(self.hparams),
+ 'finetune_info': {
+ 'dataset': self.dataset_name,
+ 'target`': self.target,
+ 'batch_size': self.batch_size,
+ 'lr': self._get_lr(),
+ 'device': self.device,
+ 'optim': self.optimizer.__class__.__name__,
+ 'loss_fn': self.loss_fn.__class__.__name__,
+ 'train_size': self.df_train.shape[0],
+ 'valid_size': self.df_valid.shape[0],
+ 'test_size': self.df_test.shape[0],
+ 'best_vloss': self.best_vloss,
+ },
+ 'seed': self.seed,
+ }
+
+ assert list(ckpt_dict.keys()) == ['MODEL_STATE', 'EPOCHS_RUN', 'hparams', 'finetune_info', 'seed']
+
+ torch.save(ckpt_dict, os.path.join(self.checkpoints_folder, filename))
+
+ def _set_seed(self, value):
+ random.seed(value)
+ torch.manual_seed(value)
+ np.random.seed(value)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(value)
+ torch.cuda.manual_seed_all(value)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def _get_lr(self):
+ for param_group in self.optimizer.param_groups:
+ return param_group['lr']
+
+
+class TrainerRegressor(Trainer):
+
+ def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
+ target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
+ super().__init__(raw_data, dataset_name, target, batch_size, hparams,
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
+
+ def _train_one_epoch(self):
+ running_loss = 0.0
+
+ for idx, data in enumerate(pbar := tqdm(self.train_loader)):
+ # Every data instance is an input + label pair
+ smiles, targets = data
+ targets = targets.clone().detach().to(self.device)
+
+ # zero the parameter gradients (otherwise they are accumulated)
+ self.optimizer.zero_grad()
+
+ # Make predictions for this batch
+ embeddings = self.model.extract_embeddings(smiles).to(self.device)
+ outputs = self.model.net(embeddings).squeeze()
+
+ # Compute the loss and its gradients
+ loss = self.loss_fn(outputs, targets)
+ loss.backward()
+
+ # Adjust learning weights
+ self.optimizer.step()
+
+ # print statistics
+ running_loss += loss.item()
+
+ # progress bar
+ pbar.set_description('[TRAINING]')
+ pbar.set_postfix(loss=running_loss/(idx+1))
+ pbar.refresh()
+
+ return running_loss / len(self.train_loader)
+
+ def _validate_one_epoch(self, data_loader, model=None):
+ data_targets = []
+ data_preds = []
+ running_loss = 0.0
+
+ model = self.model if model is None else model
+
+ with torch.no_grad():
+ for idx, data in enumerate(pbar := tqdm(data_loader)):
+ # Every data instance is an input + label pair
+ smiles, targets = data
+ targets = targets.clone().detach().to(self.device)
+
+ # Make predictions for this batch
+ embeddings = model.extract_embeddings(smiles).to(self.device)
+ predictions = model.net(embeddings).squeeze()
+
+ # Compute the loss
+ loss = self.loss_fn(predictions, targets)
+
+ data_targets.append(targets.view(-1))
+ data_preds.append(predictions.view(-1))
+
+ # print statistics
+ running_loss += loss.item()
+
+ # progress bar
+ pbar.set_description('[EVALUATION]')
+ pbar.set_postfix(loss=running_loss/(idx+1))
+ pbar.refresh()
+
+ # Put together predictions and labels from batches
+ preds = torch.cat(data_preds, dim=0).cpu().numpy()
+ tgts = torch.cat(data_targets, dim=0).cpu().numpy()
+
+ # Compute metrics
+ mae = mean_absolute_error(tgts, preds)
+ r2 = r2_score(tgts, preds)
+ rmse = RMSE(preds, tgts)
+ spearman = stats.spearmanr(tgts, preds).statistic # scipy 1.12.0
+
+ # Rearange metrics
+ metrics = {
+ 'mae': mae,
+ 'r2': r2,
+ 'rmse': rmse,
+ 'spearman': spearman,
+ }
+
+ return preds, running_loss / len(data_loader), metrics
+
+
+class TrainerClassifier(Trainer):
+
+ def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
+ target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
+ super().__init__(raw_data, dataset_name, target, batch_size, hparams,
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
+
+ def _train_one_epoch(self):
+ running_loss = 0.0
+
+ for idx, data in enumerate(pbar := tqdm(self.train_loader)):
+ # Every data instance is an input + label pair
+ smiles, targets = data
+ targets = targets.clone().detach().to(self.device)
+
+ # zero the parameter gradients (otherwise they are accumulated)
+ self.optimizer.zero_grad()
+
+ # Make predictions for this batch
+ embeddings = self.model.extract_embeddings(smiles).to(self.device)
+ outputs = self.model.net(embeddings).squeeze()
+
+ # Compute the loss and its gradients
+ loss = self.loss_fn(outputs, targets.long())
+ loss.backward()
+
+ # Adjust learning weights
+ self.optimizer.step()
+
+ # print statistics
+ running_loss += loss.item()
+
+ # progress bar
+ pbar.set_description('[TRAINING]')
+ pbar.set_postfix(loss=running_loss/(idx+1))
+ pbar.refresh()
+
+ return running_loss / len(self.train_loader)
+
+ def _validate_one_epoch(self, data_loader, model=None):
+ data_targets = []
+ data_preds = []
+ running_loss = 0.0
+
+ model = self.model if model is None else model
+
+ with torch.no_grad():
+ for idx, data in enumerate(pbar := tqdm(data_loader)):
+ # Every data instance is an input + label pair
+ smiles, targets = data
+ targets = targets.clone().detach().to(self.device)
+
+ # Make predictions for this batch
+ embeddings = model.extract_embeddings(smiles).to(self.device)
+ predictions = model.net(embeddings).squeeze()
+
+ # Compute the loss
+ loss = self.loss_fn(predictions, targets.long())
+
+ data_targets.append(targets.view(-1))
+ data_preds.append(predictions)
+
+ # print statistics
+ running_loss += loss.item()
+
+ # progress bar
+ pbar.set_description('[EVALUATION]')
+ pbar.set_postfix(loss=running_loss/(idx+1))
+ pbar.refresh()
+
+ # Put together predictions and labels from batches
+ preds = torch.cat(data_preds, dim=0).cpu().numpy()
+ tgts = torch.cat(data_targets, dim=0).cpu().numpy()
+
+ # Compute metrics
+ preds_cpu = F.softmax(torch.tensor(preds), dim=1).cpu().numpy()[:, 1]
+
+ # accuracy
+ y_pred = np.where(preds_cpu >= 0.5, 1, 0)
+ accuracy = accuracy_score(tgts, y_pred)
+
+ # sensitivity
+ sn = sensitivity(tgts, y_pred)
+
+ # specificity
+ sp = specificity(tgts, y_pred)
+
+ # roc-auc
+ fpr, tpr, _ = roc_curve(tgts, preds_cpu)
+ roc_auc = auc(fpr, tpr)
+
+ # prc-auc
+ precision, recall, _ = precision_recall_curve(tgts, preds_cpu)
+ prc_auc = auc(recall, precision)
+
+ # Rearange metrics
+ metrics = {
+ 'acc': accuracy,
+ 'roc-auc': roc_auc,
+ 'prc-auc': prc_auc,
+ 'sensitivity': sn,
+ 'specificity': sp,
+ }
+
+ return preds, running_loss / len(data_loader), metrics
+
+
+class TrainerClassifierMultitask(Trainer):
+
+ def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
+ target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
+ super().__init__(raw_data, dataset_name, target, batch_size, hparams,
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
+
+ def _prepare_data(self):
+ # normalize dataset
+ self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles)
+ self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles)
+ self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles)
+
+ self.df_train = self.df_train.dropna(subset=['canon_smiles'])
+ self.df_valid = self.df_valid.dropna(subset=['canon_smiles'])
+ self.df_test = self.df_test.dropna(subset=['canon_smiles'])
+
+ # create dataloader
+ self.train_loader = DataLoader(
+ CustomDatasetMultitask(self.df_train, self.target),
+ batch_size=self.batch_size,
+ shuffle=True,
+ pin_memory=True
+ )
+ self.valid_loader = DataLoader(
+ CustomDatasetMultitask(self.df_valid, self.target),
+ batch_size=self.batch_size,
+ shuffle=False,
+ pin_memory=True
+ )
+ self.test_loader = DataLoader(
+ CustomDatasetMultitask(self.df_test, self.target),
+ batch_size=self.batch_size,
+ shuffle=False,
+ pin_memory=True
+ )
+
+ def _train_one_epoch(self):
+ running_loss = 0.0
+
+ for idx, data in enumerate(pbar := tqdm(self.train_loader)):
+ # Every data instance is an input + label pair + mask
+ smiles, targets, target_masks = data
+ targets = targets.clone().detach().to(self.device)
+
+ # zero the parameter gradients (otherwise they are accumulated)
+ self.optimizer.zero_grad()
+
+ # Make predictions for this batch
+ embeddings = self.model.extract_embeddings(smiles).to(self.device)
+ outputs = self.model.net(embeddings, multitask=True).squeeze()
+ outputs = outputs * target_masks.to(self.device)
+
+ # Compute the loss and its gradients
+ loss = self.loss_fn(outputs, targets)
+ loss.backward()
+
+ # Adjust learning weights
+ self.optimizer.step()
+
+ # print statistics
+ running_loss += loss.item()
+
+ # progress bar
+ pbar.set_description('[TRAINING]')
+ pbar.set_postfix(loss=running_loss/(idx+1))
+ pbar.refresh()
+
+ return running_loss / len(self.train_loader)
+
+ def _validate_one_epoch(self, data_loader, model=None):
+ data_targets = []
+ data_preds = []
+ data_masks = []
+ running_loss = 0.0
+
+ model = self.model if model is None else model
+
+ with torch.no_grad():
+ for idx, data in enumerate(pbar := tqdm(data_loader)):
+ # Every data instance is an input + label pair + mask
+ smiles, targets, target_masks = data
+ targets = targets.clone().detach().to(self.device)
+
+ # Make predictions for this batch
+ embeddings = model.extract_embeddings(smiles).to(self.device)
+ predictions = model.net(embeddings, multitask=True).squeeze()
+ predictions = predictions * target_masks.to(self.device)
+
+ # Compute the loss
+ loss = self.loss_fn(predictions, targets)
+
+ data_targets.append(targets)
+ data_preds.append(predictions)
+ data_masks.append(target_masks)
+
+ # print statistics
+ running_loss += loss.item()
+
+ # progress bar
+ pbar.set_description('[EVALUATION]')
+ pbar.set_postfix(loss=running_loss/(idx+1))
+ pbar.refresh()
+
+ # Put together predictions and labels from batches
+ preds = torch.cat(data_preds, dim=0)
+ tgts = torch.cat(data_targets, dim=0)
+ mask = torch.cat(data_masks, dim=0)
+ mask = mask > 0
+
+ # Compute metrics
+ roc_aucs = []
+ prc_aucs = []
+ sns = []
+ sps = []
+ num_tasks = len(self.target)
+ for idx in range(num_tasks):
+ actuals_task = torch.masked_select(tgts[:, idx], mask[:, idx].to(self.device))
+ preds_task = torch.masked_select(preds[:, idx], mask[:, idx].to(self.device))
+
+ # accuracy
+ y_pred = np.where(preds_task.cpu().detach() >= 0.5, 1, 0)
+ accuracy = accuracy_score(actuals_task.cpu().numpy(), y_pred)
+
+ # sensitivity
+ sn = sensitivity(actuals_task.cpu().numpy(), y_pred)
+
+ # specificity
+ sp = specificity(actuals_task.cpu().numpy(), y_pred)
+
+ # roc-auc
+ roc_auc = roc_auc_score(actuals_task.cpu().numpy(), preds_task.cpu().numpy())
+
+ # prc-auc
+ precision, recall, thresholds = precision_recall_curve(actuals_task.cpu().numpy(), preds_task.cpu().numpy())
+ prc_auc = auc(recall, precision)
+
+ # append
+ sns.append(sn)
+ sps.append(sp)
+ roc_aucs.append(roc_auc)
+ prc_aucs.append(prc_auc)
+ average_sn = torch.mean(torch.tensor(sns))
+ average_sp = torch.mean(torch.tensor(sps))
+ average_roc_auc = torch.mean(torch.tensor(roc_aucs))
+ average_prc_auc = torch.mean(torch.tensor(prc_aucs))
+
+ # Rearange metrics
+ metrics = {
+ 'acc': accuracy,
+ 'roc-auc': average_roc_auc.item(),
+ 'prc-auc': average_prc_auc.item(),
+ 'sensitivity': average_sn.item(),
+ 'specificity': average_sp.item(),
+ }
+
+ return preds.cpu().numpy(), running_loss / len(data_loader), metrics
\ No newline at end of file
diff --git a/models/smi_ted/finetune/utils.py b/models/smi_ted/finetune/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b0fd79adf4edd67417138cb258c91ab0201b5de
--- /dev/null
+++ b/models/smi_ted/finetune/utils.py
@@ -0,0 +1,115 @@
+# Deep learning
+import torch
+from torch.utils.data import Dataset
+from sklearn.metrics import confusion_matrix
+
+# Data
+import pandas as pd
+import numpy as np
+
+# Standard library
+import os
+
+# Chemistry
+from rdkit import Chem
+from rdkit.Chem import PandasTools
+from rdkit.Chem import Descriptors
+PandasTools.RenderImagesInAllDataFrames(True)
+
+
+def normalize_smiles(smi, canonical=True, isomeric=False):
+ try:
+ normalized = Chem.MolToSmiles(
+ Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
+ )
+ except:
+ normalized = None
+ return normalized
+
+
+class RMSELoss:
+ def __init__(self):
+ pass
+
+ def __call__(self, yhat, y):
+ return torch.sqrt(torch.mean((yhat-y)**2))
+
+
+def RMSE(predictions, targets):
+ return np.sqrt(((predictions - targets) ** 2).mean())
+
+
+def sensitivity(y_true, y_pred):
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+ return (tp/(tp+fn))
+
+
+def specificity(y_true, y_pred):
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+ return (tn/(tn+fp))
+
+
+def get_optim_groups(module, keep_decoder=False):
+ # setup optimizer
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear,)
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in module.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+ if not keep_decoder and 'decoder' in fpn: # exclude decoder components
+ continue
+
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in module.named_parameters()}
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+ ]
+
+ return optim_groups
+
+
+class CustomDataset(Dataset):
+ def __init__(self, dataset, target):
+ self.dataset = dataset
+ self.target = target
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ smiles = self.dataset['canon_smiles'].iloc[idx]
+ labels = self.dataset[self.target].iloc[idx]
+ return smiles, labels
+
+
+class CustomDatasetMultitask(Dataset):
+ def __init__(self, dataset, targets):
+ self.dataset = dataset
+ self.targets = targets
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ smiles = self.dataset['canon_smiles'].iloc[idx]
+ labels = self.dataset[self.targets].iloc[idx].to_numpy()
+ mask = [0.0 if np.isnan(x) else 1.0 for x in labels]
+ labels = [0.0 if np.isnan(x) else x for x in labels]
+ return smiles, torch.tensor(labels, dtype=torch.float32), torch.tensor(mask)
\ No newline at end of file
diff --git a/models/smi_ted/images/smi-ted.png b/models/smi_ted/images/smi-ted.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f1688456dc60e5adef22533de0fc0b1f0e3b561
--- /dev/null
+++ b/models/smi_ted/images/smi-ted.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa41339c9e8c14412f05dcaa5f42d4d185e101dff4d97b82749cedf671678a71
+size 1891667
diff --git a/models/smi_ted/inference/smi_ted_large/bert_vocab_curated.txt b/models/smi_ted/inference/smi_ted_large/bert_vocab_curated.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd
--- /dev/null
+++ b/models/smi_ted/inference/smi_ted_large/bert_vocab_curated.txt
@@ -0,0 +1,2393 @@
+
+
+
+
+C
+c
+(
+)
+1
+O
+N
+2
+=
+n
+3
+[C@H]
+[C@@H]
+F
+S
+4
+Cl
+-
+o
+s
+[nH]
+#
+/
+Br
+[C@]
+[C@@]
+[N+]
+[O-]
+5
+\
+.
+I
+6
+[S@]
+[S@@]
+P
+[N-]
+[Si]
+7
+[n+]
+[2H]
+8
+[NH+]
+B
+9
+[C-]
+[Na+]
+[Cl-]
+[c-]
+[CH]
+%10
+[NH2+]
+[P+]
+[B]
+[I-]
+%11
+[CH2-]
+[O+]
+[NH3+]
+[C]
+[Br-]
+[IH2]
+[S-]
+[cH-]
+%12
+[nH+]
+[B-]
+[K+]
+[Sn]
+[Se]
+[CH-]
+[HH]
+[Y]
+[n-]
+[CH3-]
+[SiH]
+[S+]
+%13
+[SiH2]
+[Li+]
+[NH-]
+%14
+[Na]
+[CH2]
+[O-2]
+[U+2]
+[W]
+[Al]
+[P@]
+[Fe+2]
+[PH+]
+%15
+[Cl+3]
+[Zn+2]
+[Ir]
+[Mg+2]
+[Pt+2]
+[OH2+]
+[As]
+[Fe]
+[OH+]
+[Zr+2]
+[3H]
+[Ge]
+[SiH3]
+[OH-]
+[NH4+]
+[Cu+2]
+[P@@]
+p
+[Pt]
+%16
+[Ca+2]
+[Zr]
+[F-]
+[C+]
+[Ti]
+[P-]
+[V]
+[se]
+[U]
+[O]
+[Ni+2]
+[Zn]
+[Co]
+[Ni]
+[Pd+2]
+[Cu]
+%17
+[Cu+]
+[Te]
+[H+]
+[CH+]
+[Li]
+[Pd]
+[Mo]
+[Ru+2]
+[o+]
+[Re]
+[SH+]
+%18
+[Ac]
+[Cr]
+[NH2-]
+[K]
+[13CH2]
+[c]
+[Zr+4]
+[Tl]
+[13C]
+[Mn]
+[N@+]
+[Hg]
+[Rh]
+[Ti+4]
+[Sb]
+[Co+2]
+[Ag+]
+[Ru]
+%19
+[N@@+]
+[Ti+2]
+[Al+3]
+[Pb]
+[I+]
+[18F]
+[s+]
+[Rb+]
+[Ba+2]
+[H-]
+[Fe+3]
+[Ir+3]
+[13cH]
+%20
+[AlH2]
+[Au+]
+[13c]
+[SH2+]
+[Sn+2]
+[Mn+2]
+[Si-]
+[Ag]
+[N]
+[Bi]
+%21
+[In]
+[CH2+]
+[Y+3]
+[Ga]
+%22
+[Co+3]
+[Au]
+[13CH3]
+[Mg]
+[Cs+]
+[W+2]
+[Hf]
+[Zn+]
+[Se-]
+[S-2]
+[Ca]
+[pH]
+[ClH+]
+[Ti+3]
+%23
+[Ru+]
+[SH-]
+[13CH]
+[IH+]
+[Hf+4]
+[Rf]
+[OH3+]
+%24
+[Pt+4]
+[Zr+3]
+[PH3+]
+[Sr+2]
+[Cd+2]
+[Cd]
+%25
+[Os]
+[BH-]
+[Sn+4]
+[Cr+3]
+[Ru+3]
+[PH2+]
+[Rh+2]
+[V+2]
+%26
+[Gd+3]
+[Pb+2]
+[PH]
+[Hg+]
+[Mo+2]
+[AlH]
+[Sn+]
+%27
+[Pd+]
+b
+[Rh+3]
+[Hg+2]
+[15NH]
+[14C]
+%28
+[Mn+3]
+[Si+]
+[SeH]
+[13C@H]
+[NH]
+[Ga+3]
+[SiH-]
+[13C@@H]
+[Ce]
+[Au+3]
+[Bi+3]
+[15N]
+%29
+[BH3-]
+[14cH]
+[Ti+]
+[Gd]
+[cH+]
+[Cr+2]
+[Sb-]
+%30
+[Be+2]
+[Al+]
+[te]
+[11CH3]
+[Sm]
+[Pr]
+[La]
+%31
+[Al-]
+[Ta]
+[125I]
+[BH2-]
+[Nb]
+[Si@]
+%32
+[14c]
+[Sb+3]
+[Ba]
+%33
+[Os+2]
+[Si@@]
+[La+3]
+[15n]
+[15NH2]
+[Nd+3]
+%34
+[14CH2]
+[18O]
+[Nd]
+[GeH]
+[Ni+3]
+[Eu]
+[Dy+3]
+[Sc]
+%36
+[Se-2]
+[As+]
+%35
+[AsH]
+[Tb]
+[Sb+5]
+[Se+]
+[Ce+3]
+[c+]
+[In+3]
+[SnH]
+[Mo+4]
+%37
+[V+4]
+[Eu+3]
+[Hf+2]
+%38
+[Pt+]
+[p+]
+[123I]
+[Tl+]
+[Sm+3]
+%39
+[Yb+3]
+%40
+[Yb]
+[Os+]
+%41
+[10B]
+[Sc+3]
+[Al+2]
+%42
+[Sr]
+[Tb+3]
+[Po]
+[Tc]
+[PH-]
+[AlH3]
+[Ar]
+[U+4]
+[SnH2]
+[Cl+2]
+[si]
+[Fe+]
+[14CH3]
+[U+3]
+[Cl+]
+%43
+[GeH2]
+%44
+[Er+3]
+[Mo+3]
+[I+2]
+[Fe+4]
+[99Tc]
+%45
+[11C]
+%46
+[SnH3]
+[S]
+[Te+]
+[Er]
+[Lu+3]
+[11B]
+%47
+%48
+[P]
+[Tm]
+[Th]
+[Dy]
+[Pr+3]
+[Ta+5]
+[Nb+5]
+[Rb]
+[GeH3]
+[Br+2]
+%49
+[131I]
+[Fm]
+[Cs]
+[BH4-]
+[Lu]
+[15nH]
+%50
+[Ru+6]
+[b-]
+[Ho]
+[Th+4]
+[Ru+4]
+%52
+[14CH]
+%51
+[Cr+6]
+[18OH]
+[Ho+3]
+[Ce+4]
+[Bi+2]
+[Co+]
+%53
+[Yb+2]
+[Fe+6]
+[Be]
+%54
+[SH3+]
+[Np]
+[As-]
+%55
+[14C@@H]
+[Ir+2]
+[GaH3]
+[p-]
+[GeH4]
+[Sn+3]
+[Os+4]
+%56
+[14C@H]
+[sH+]
+[19F]
+[Eu+2]
+[TlH]
+%57
+[Cr+4]
+%58
+[B@@-]
+[SiH+]
+[At]
+[Am]
+[Fe+5]
+[AsH2]
+[Si+4]
+[B@-]
+[Pu]
+[SbH]
+[P-2]
+[Tm+3]
+*
+%59
+[se+]
+[IH-]
+%60
+[oH+]
+[1H]
+[15N+]
+[124I]
+[S@@+]
+[P-3]
+[H]
+[IH2+]
+[TeH]
+[Xe]
+[PH4+]
+[Cr+]
+[Cm]
+[I+3]
+%61
+[Nb+2]
+[Ru+5]
+%62
+[Ta+2]
+[Tc+4]
+[CH3+]
+[Pm]
+[Si@H]
+[No]
+%63
+[Cr+5]
+[Th+2]
+[Zn-2]
+[13C@]
+[Lr]
+%64
+[99Tc+3]
+%65
+[13C@@]
+%66
+[Fe-]
+[17O]
+[siH]
+[Sb+]
+[OH]
+[IH]
+[11CH2]
+[Cf]
+[SiH2+]
+[Gd+2]
+[In+]
+[Si@@H]
+[Mn+]
+[99Tc+4]
+[Ga-]
+%67
+[S@+]
+[Ge+4]
+[Tl+3]
+[16OH]
+%68
+[2H-]
+[Ra]
+[si-]
+[NiH2]
+[P@@H]
+[Rh+]
+[12C]
+[35S]
+[32P]
+[SiH2-]
+[AlH2+]
+[16O]
+%69
+[BiH]
+[BiH2]
+[Zn-]
+[BH]
+[Tc+3]
+[Ir+]
+[Ni+]
+%70
+[InH2]
+[InH]
+[Nb+3]
+[PbH]
+[Bi+]
+%71
+[As+3]
+%72
+[18O-]
+[68Ga+3]
+%73
+[Pa]
+[76Br]
+[Tc+5]
+[pH+]
+[64Cu+2]
+[Ru+8]
+%74
+[PH2-]
+[Si+2]
+[17OH]
+[RuH]
+[111In+3]
+[AlH+]
+%75
+%76
+[W+]
+[SbH2]
+[PoH]
+[Ru-]
+[XeH]
+[Tc+2]
+[13C-]
+[Br+]
+[Pt-2]
+[Es]
+[Cu-]
+[Mg+]
+[3HH]
+[P@H]
+[ClH2+]
+%77
+[SH]
+[Au-]
+[2HH]
+%78
+[Sn-]
+[11CH]
+[PdH2]
+0
+[Os+6]
+%79
+[Mo+]
+%80
+[al]
+[PbH2]
+[64Cu]
+[Cl]
+[12CH3]
+%81
+[Tc+7]
+[11c]
+%82
+[Li-]
+[99Tc+5]
+[He]
+[12c]
+[Kr]
+[RuH+2]
+[35Cl]
+[Pd-2]
+[GaH2]
+[4H]
+[Sg]
+[Cu-2]
+[Br+3]
+%83
+[37Cl]
+[211At]
+[IrH+2]
+[Mt]
+[Ir-2]
+[In-]
+[12cH]
+[12CH2]
+[RuH2]
+[99Tc+7]
+%84
+[15n+]
+[ClH2+2]
+[16N]
+[111In]
+[Tc+]
+[Ru-2]
+[12CH]
+[si+]
+[Tc+6]
+%85
+%86
+[90Y]
+[Pd-]
+[188Re]
+[RuH+]
+[NiH]
+[SiH3-]
+[14n]
+[CH3]
+[14N]
+[10BH2]
+%88
+%89
+%90
+[34S]
+[77Br]
+[GaH]
+[Br]
+[Ge@]
+[B@@H-]
+[CuH]
+[SiH4]
+[3H-]
+%87
+%91
+%92
+[67Cu]
+[I]
+[177Lu]
+[ReH]
+[67Ga+3]
+[Db]
+[177Lu+3]
+[AlH2-]
+[Si+3]
+[Ti-2]
+[RuH+3]
+[al+]
+[68Ga]
+[2H+]
+[B@H-]
+[WH2]
+[OsH]
+[Ir-3]
+[AlH-]
+[Bk]
+[75Se]
+[14C@]
+[Pt-]
+[N@@H+]
+[Nb-]
+[13NH2]
+%93
+[186Re]
+[Tb+4]
+[PtH]
+[IrH2]
+[Hg-2]
+[AlH3-]
+[PdH+]
+[Md]
+[RhH+2]
+[11cH]
+[Co-2]
+[15N-]
+[ZrH2]
+%94
+[Hg-]
+[127I]
+[AsH2+]
+[MoH2]
+[Te+4]
+[14C@@]
+[As+5]
+[SnH+3]
+[Ge@@]
+[6Li+]
+[WH]
+[Ne]
+[14NH2]
+[14NH]
+[12C@@H]
+[Os+7]
+[RhH]
+[Al-3]
+[SnH+]
+[15NH3+]
+[Zr+]
+[197Hg+]
+%95
+%96
+[90Y+3]
+[Os-2]
+[98Tc+5]
+[15NH3]
+[bH-]
+[33P]
+[Zr-2]
+[15O]
+[Rh-]
+[PbH3]
+[PH2]
+[Ni-]
+[CuH+]
+%97
+%98
+%99
+[Os+5]
+[PtH+]
+[ReH4]
+[16NH]
+[82Br]
+[W-]
+[18F-]
+[15NH4+]
+[Se+4]
+[SeH-]
+[SH4]
+[67Cu+2]
+[12C@H]
+[AsH3]
+[HgH]
+[10B-]
+[99Tc+6]
+[117Sn+4]
+[Te@]
+[P@+]
+[35SH]
+[SeH+]
+[Ni-2]
+[Al-2]
+[TeH2]
+[Bh]
+[99Tc+2]
+[Os+8]
+[PH-2]
+[7Li+]
+[14nH]
+[AlH+2]
+[18FH]
+[SnH4]
+[18O-2]
+[IrH]
+[13N]
+[Te@@]
+[Rh-3]
+[15NH+]
+[AsH3+]
+[SeH2]
+[AsH+]
+[CoH2]
+[16NH2]
+[AsH-]
+[203Hg+]
+[P@@+]
+[166Ho+3]
+[60Co+3]
+[13CH2-]
+[SeH2+]
+[75Br]
+[TlH2]
+[80Br]
+[siH+]
+[Ca+]
+[153Sm+3]
+[PdH]
+[225Ac]
+[13CH3-]
+[AlH4-]
+[FeH]
+[13CH-]
+[14C-]
+[11C-]
+[153Sm]
+[Re-]
+[te+]
+[13CH4]
+[ClH+2]
+[8CH2]
+[99Mo]
+[ClH3+3]
+[SbH3]
+[25Mg+2]
+[16N+]
+[SnH2+]
+[PH4]
+[11C@H]
+[122I]
+[Re-2]
+[RuH2+2]
+[ZrH]
+[Bi-]
+[Pr+]
+[Rn]
+[Fr]
+[36Cl]
+[18o]
+[YH]
+[79Br]
+[121I]
+[113In+3]
+[InH4-]
+[TaH]
+[RhH2]
+[Ta-]
+[67Ga]
+[ZnH+]
+[SnH2-]
+[OsH2]
+[16F]
+[FeH2]
+[14O]
+[PbH2+2]
+[BH2]
+[6H]
+[125Te]
+[197Hg]
+[TaH2]
+[TaH3]
+[76As]
+[Nb-2]
+[14N+]
+[125I-]
+[33S]
+[IH2+2]
+[NH2]
+[PtH2]
+[MnH]
+[19C]
+[17F]
+[1H-]
+[SnH4+2]
+[Mn-2]
+[15NH2+]
+[TiH2]
+[ReH7]
+[Cd-2]
+[Fe-3]
+[SH2]
+[17O-]
+[siH-]
+[CoH+]
+[VH]
+[10BH]
+[Ru-3]
+[13O]
+[5H]
+[CoH]
+[PH5]
+[15n-]
+[153Gd]
+[12C@]
+[11CH3-]
+[IrH3]
+[RuH3]
+[74Se]
+[Se@]
+[Hf+]
+[77Se]
+[166Ho]
+[59Fe+2]
+[203Hg]
+[18OH-]
+[8CH]
+[12C@@]
+[11CH4]
+[15C]
+[249Cf]
+[PbH4]
+[64Zn]
+[PH3]
+[99Tc+]
+[14c-]
+[149Pm]
+[IrH4]
+[Se@@]
+[13OH]
+[14CH3-]
+[28Si]
+[Rh-2]
+[Fe-2]
+[131I-]
+[51Cr]
+[62Cu+2]
+[81Br]
+[121Sb]
+[7Li]
+[89Zr+4]
+[SbH3+]
+[11C@@H]
+[98Tc]
+[59Fe+3]
+[BiH2+]
+[SbH+]
+[TiH]
+[14NH3]
+[15OH]
+[119Sn]
+[201Hg]
+[MnH+]
+[201Tl]
+[51Cr+3]
+[123I-]
+[MoH]
+[AlH6-3]
+[MnH2]
+[WH3]
+[213Bi+3]
+[SnH2+2]
+[123IH]
+[13CH+]
+[Zr-]
+[74As]
+[13C+]
+[32P+]
+[KrH]
+[SiH+2]
+[ClH3+2]
+[13NH]
+[9CH2]
+[ZrH2+2]
+[87Sr+2]
+[35s]
+[239Pu]
+[198Au]
+[241Am]
+[203Hg+2]
+[V+]
+[YH2]
+[SH5]
+[195Pt]
+[203Pb]
+[RuH4]
+[ThH2]
+[AuH]
+[66Ga+3]
+[11B-]
+[F]
+[24Na+]
+[85Sr+2]
+[201Tl+]
+[14CH4]
+[32S]
+[TeH2+]
+[ClH2+3]
+[AgH]
+[Ge@H]
+[44Ca+2]
+[Os-]
+[31P]
+[15nH+]
+[SbH4]
+[TiH+]
+[Ba+]
+[57Co+2]
+[Ta+]
+[125IH]
+[77As]
+[129I]
+[Fe-4]
+[Ta-2]
+[19O]
+[12O]
+[BiH3]
+[237Np]
+[252Cf]
+[86Y]
+[Cr-2]
+[89Y]
+[195Pt+2]
+[si+2]
+[58Fe+2]
+[Hs]
+[S@@H]
+[OsH6]
+[GdH2]
+[IH3]
+[8CH4]
+[164Dy+3]
+[47Ca+2]
+[57Co]
+[NbH2]
+[ReH2]
+[ZnH2]
+[CrH2]
+[17NH]
+[ZrH3]
+[RhH3]
+[12C-]
+[18O+]
+[Bi-2]
+[ClH4+3]
+[Ni-3]
+[Ag-]
+[111In-]
+[Mo-2]
+[55Fe+3]
+[204Hg+]
+[35Cl-]
+[211Pb]
+[75Ge]
+[8B]
+[TeH3]
+[SnH3+]
+[Zr-3]
+[28F]
+[249Bk]
+[169Yb]
+[34SH]
+[6Li]
+[94Tc]
+[197Au]
+[195Pt+4]
+[169Yb+3]
+[32Cl]
+[82Se]
+[159Gd+3]
+[213Bi]
+[CoH+2]
+[36S]
+[35P]
+[Ru-4]
+[Cr-3]
+[60Co]
+[1H+]
+[18CH2]
+[Cd-]
+[152Sm+3]
+[106Ru]
+[238Pu]
+[220Rn]
+[45Ca+2]
+[89Sr+2]
+[239Np]
+[90Sr+2]
+[137Cs+]
+[165Dy]
+[68GaH3]
+[65Zn+2]
+[89Zr]
+[BiH2+2]
+[62Cu]
+[165Dy+3]
+[238U]
+[105Rh+3]
+[70Zn]
+[12B]
+[12OH]
+[18CH]
+[17CH]
+[OsH3]
+[SbH-]
+[SH6]
+[AlH2-2]
+[42K]
+[76Br-]
+[71As]
+[NbH3]
+[ReH3]
+[OsH-]
+[WH4]
+[MoH3]
+[OsH4]
+[RuH6]
+[PtH3]
+[CuH2]
+[CoH3]
+[TiH4]
+[64Zn+2]
+[Si-2]
+[79BrH]
+[14CH2-]
+[PtH2+2]
+[Os-3]
+[29Si]
+[Ti-]
+[Se+6]
+[22Na+]
+[42K+]
+[131Cs+]
+[86Rb+]
+[134Cs+]
+[209Po]
+[208Po]
+[81Rb+]
+[203Tl+]
+[Zr-4]
+[148Sm]
+[147Sm]
+[37Cl-]
+[12CH4]
+[Ge@@H]
+[63Cu]
+[13CH2+]
+[AsH2-]
+[CeH]
+[SnH-]
+[UH]
+[9c]
+[21CH3]
+[TeH+]
+[57Co+3]
+[8BH2]
+[12BH2]
+[19BH2]
+[9BH2]
+[YbH2]
+[CrH+2]
+[208Bi]
+[152Gd]
+[61Cu]
+[115In]
+[60Co+2]
+[13NH2-]
+[120I]
+[18OH2]
+[75SeH]
+[SbH2+]
+[144Ce]
+[16n]
+[113In]
+[22nH]
+[129I-]
+[InH3]
+[32PH3]
+[234U]
+[235U]
+[59Fe]
+[82Rb+]
+[65Zn]
+[244Cm]
+[147Pm]
+[91Y]
+[237Pu]
+[231Pa]
+[253Cf]
+[127Te]
+[187Re]
+[236Np]
+[235Np]
+[72Zn]
+[253Es]
+[159Dy]
+[62Zn]
+[101Tc]
+[149Tb]
+[124I-]
+[SeH3+]
+[210Pb]
+[40K]
+[210Po]
+[214Pb]
+[218Po]
+[214Po]
+[7Be]
+[212Pb]
+[205Pb]
+[209Pb]
+[123Te]
+[202Pb]
+[72As]
+[201Pb]
+[70As]
+[73Ge]
+[200Pb]
+[198Pb]
+[66Ga]
+[73Se]
+[195Pb]
+[199Pb]
+[144Ce+3]
+[235U+2]
+[90Tc]
+[114In+3]
+[128I]
+[100Tc+]
+[82Br-]
+[191Pt+2]
+[191Pt+4]
+[193Pt+4]
+[31PH3]
+[125I+2]
+[131I+2]
+[125Te+4]
+[82Sr+2]
+[149Sm]
+[81BrH]
+[129Xe]
+[193Pt+2]
+[123I+2]
+[Cr-]
+[Co-]
+[227Th+4]
+[249Cf+3]
+[252Cf+3]
+[187Os]
+[16O-]
+[17O+]
+[16OH-]
+[98Tc+7]
+[58Co+2]
+[69Ga+3]
+[57Fe+2]
+[43K+]
+[16C]
+[52Fe+3]
+[SeH5]
+[194Pb]
+[196Pb]
+[197Pb]
+[213Pb]
+[9B]
+[19B]
+[11CH-]
+[9CH]
+[20OH]
+[25OH]
+[8cH]
+[TiH+3]
+[SnH6+3]
+[N@H+]
+[ZnH]
+[VH3]
+[52Mn+2]
+[64Ga]
+[13B]
+[216Bi]
+[117Sn+2]
+[232Th]
+[SnH+2]
+[BiH5]
+[77Kr]
+[103Cd]
+[62Ni]
+[LaH3]
+[SmH3]
+[EuH3]
+[MoH5]
+[64Ni]
+[66Zn]
+[68Zn]
+[186W]
+[FeH4]
+[MoH4]
+[HgH2]
+[15NH2-]
+[UH2]
+[204Hg]
+[GaH4-]
+[ThH4]
+[WH6]
+[PtH4]
+[VH2]
+[UH3]
+[FeH3]
+[RuH5]
+[BiH4]
+[80Br-]
+[CeH3]
+[37ClH]
+[157Gd+3]
+[205Tl]
+[203Tl]
+[62Cu+]
+[64Cu+]
+[61Cu+]
+[37SH2]
+[30Si]
+[28Al]
+[19OH2]
+[8He]
+[6He]
+[153Pm]
+[209Bi]
+[66Zn+2]
+[10CH4]
+[191Ir]
+[66Cu]
+[16O+]
+[25O]
+[10c]
+[Co-3]
+[Sn@@]
+[17OH-]
+[206Po]
+[204Po]
+[202Po]
+[201Po]
+[200Po]
+[199Po]
+[198Po]
+[197Po]
+[196Po]
+[195Po]
+[194Po]
+[193Po]
+[192Po]
+[191Po]
+[190Po]
+[217Po]
+[BiH4-]
+[TeH4]
+[222Ra]
+[62Ga]
+[39Ar]
+[144Sm]
+[58Fe]
+[153Eu]
+[85Rb]
+[171Yb]
+[172Yb]
+[114Cd]
+[51Fe]
+[142Ce]
+[207Tl]
+[92Mo]
+[115Sn]
+[140Ce]
+[202Hg]
+[180W]
+[182W]
+[183W]
+[184W]
+[96Mo]
+[47Ti]
+[111Cd]
+[143Nd]
+[145Nd]
+[126Te]
+[128Te]
+[130Te]
+[185Re]
+[97Mo]
+[98Mo]
+[183Re]
+[52V]
+[80Se]
+[87Kr]
+[137Xe]
+[196Au]
+[146Ce]
+[88Kr]
+[51Ti]
+[138Xe]
+[112Cd]
+[116Sn]
+[120Sn]
+[28SiH3]
+[35S-]
+[15NH-]
+[13CH3+]
+[34S+]
+[34s]
+[SiH4-]
+[100Tc+5]
+[NiH2+2]
+[239Th]
+[186Lu]
+[AuH3]
+[I@@-]
+[XeH2]
+[B+]
+[16CH2]
+[8C]
+[TaH5]
+[FeH4-]
+[19C@H]
+[10NH]
+[FeH6-3]
+[22CH]
+[25N]
+[25N+]
+[25N-]
+[21CH2]
+[18cH]
+[113I]
+[ScH3]
+[30PH3]
+[43Ca+2]
+[41Ca+2]
+[106Cd]
+[122Sn]
+[18CH3]
+[58Co+3]
+[98Tc+4]
+[70Ge]
+[76Ge]
+[108Cd]
+[116Cd]
+[130Xe]
+[94Mo]
+[124Sn]
+[186Os]
+[188Os]
+[190Os]
+[192Os]
+[106Pd]
+[110Pd]
+[120Te]
+[132Ba]
+[134Ba]
+[136Ba]
+[136Ce]
+[138Ce]
+[156Dy]
+[158Dy]
+[160Dy]
+[163Dy]
+[162Er]
+[164Er]
+[167Er]
+[176Hf]
+[26Mg]
+[144Nd]
+[150Nd]
+[41K]
+[46Ti]
+[48Ti]
+[49Ti]
+[50Ti]
+[170Yb]
+[173Yb]
+[91Zr]
+[92Zr]
+[96Zr]
+[34S-]
+[CuH2-]
+[38Cl]
+[25Mg]
+[51V]
+[93Nb]
+[95Mo]
+[45Sc]
+[123Sb]
+[139La]
+[9Be]
+[99Y+3]
+[99Y]
+[156Ho]
+[67Zn]
+[144Ce+4]
+[210Tl]
+[42Ca]
+[54Fe]
+[193Ir]
+[92Nb]
+[141Cs]
+[52Cr]
+[35ClH]
+[46Ca]
+[139Cs]
+[65Cu]
+[71Ga]
+[60Ni]
+[16NH3]
+[148Nd]
+[72Ge]
+[161Dy]
+[49Ca]
+[43Ca]
+[8Be]
+[48Ca]
+[44Ca]
+[120Xe]
+[80Rb]
+[215At]
+[180Re]
+[146Sm]
+[19Ne]
+[74Kr]
+[134La]
+[76Kr]
+[219Fr]
+[121Xe]
+[220Fr]
+[216At]
+[223Ac]
+[218At]
+[37Ar]
+[135I]
+[110Cd]
+[94Tc+7]
+[86Y+3]
+[135I-]
+[15O-2]
+[151Eu+3]
+[161Tb+3]
+[197Hg+2]
+[109Cd+2]
+[191Os+4]
+[170Tm+3]
+[205Bi+3]
+[233U+4]
+[126Sb+3]
+[127Sb+3]
+[132Cs+]
+[136Eu+3]
+[136Eu]
+[125Sn+4]
+[175Yb+3]
+[100Mo]
+[22Ne]
+[13c-]
+[13NH4+]
+[17C]
+[9C]
+[31S]
+[31SH]
+[133I]
+[126I]
+[36SH]
+[30S]
+[32SH]
+[19CH2]
+[19c]
+[18c]
+[15F]
+[10C]
+[RuH-]
+[62Zn+2]
+[32ClH]
+[33ClH]
+[78BrH]
+[12Li+]
+[12Li]
+[233Ra]
+[68Ge+4]
+[44Sc+3]
+[91Y+3]
+[106Ru+3]
+[PoH2]
+[AtH]
+[55Fe]
+[233U]
+[210PoH2]
+[230Th]
+[228Th]
+[222Rn]
+[35SH2]
+[227Th]
+[192Ir]
+[133Xe]
+[81Kr]
+[95Zr]
+[240Pu]
+[54Mn]
+[103Ru]
+[95Nb]
+[109Cd]
+[141Ce]
+[85Kr]
+[110Ag]
+[58Co]
+[241Pu]
+[234Th]
+[140La]
+[63Ni]
+[152Eu]
+[132IH]
+[226Rn]
+[154Eu]
+[36ClH]
+[228Ac]
+[155Eu]
+[106Rh]
+[243Am]
+[227Ac]
+[243Cm]
+[236U]
+[144Pr]
+[232U]
+[32SH2]
+[88Y]
+[82BrH]
+[135IH]
+[242Cm]
+[115Cd]
+[242Pu]
+[46Sc]
+[56Mn]
+[234Pa]
+[41Ar]
+[147Nd]
+[187W]
+[151Sm]
+[59Ni]
+[233Pa]
+[52Mn]
+[94Nb]
+[219Rn]
+[236Pu]
+[13NH3]
+[93Zr]
+[51Cr+6]
+[TlH3]
+[123Xe]
+[160Tb]
+[170Tm]
+[182Ta]
+[175Yb]
+[93Mo]
+[143Ce]
+[191Os]
+[126IH]
+[48V]
+[113Cd]
+[47Sc]
+[181Hf]
+[185W]
+[143Pr]
+[191Pt]
+[181W]
+[33PH3]
+[97Ru]
+[97Tc]
+[111Ag]
+[169Er]
+[107Pd]
+[103Ru+2]
+[34SH2]
+[137Ce]
+[242Am]
+[117SnH2]
+[57Ni]
+[239U]
+[60Cu]
+[250Cf]
+[193Au]
+[69Zn]
+[55Co]
+[139Ce]
+[127Xe]
+[159Gd]
+[56Co]
+[177Hf]
+[244Pu]
+[38ClH]
+[142Pr]
+[199Hg]
+[179Hf]
+[178Hf]
+[237U]
+[156Eu]
+[157Eu]
+[105Ru]
+[171Tm]
+[199Au]
+[155Sm]
+[80BrH]
+[108Ag]
+[128IH]
+[48Sc]
+[45Ti]
+[176Lu]
+[121SnH2]
+[148Pm]
+[57Fe]
+[10BH3]
+[96Tc]
+[133IH]
+[143Pm]
+[105Rh]
+[130IH]
+[134IH]
+[131IH]
+[71Zn]
+[105Ag]
+[97Zr]
+[235Pu]
+[231Th]
+[109Pd]
+[93Y]
+[190Ir]
+[135Xe]
+[53Mn]
+[134Ce]
+[234Np]
+[240Am]
+[246Cf]
+[240Cm]
+[241Cm]
+[226Th]
+[39ClH]
+[229Th]
+[245Cm]
+[240U]
+[240Np]
+[249Cm]
+[243Pu]
+[145Pm]
+[199Pt]
+[246Bk]
+[193Pt]
+[230U]
+[250Cm]
+[44Ti]
+[175Hf]
+[254Fm]
+[255Fm]
+[257Fm]
+[92Y]
+[188Ir]
+[171Lu]
+[257Md]
+[247Bk]
+[121IH]
+[250Bk]
+[179Lu]
+[224Ac]
+[195Hg]
+[244Am]
+[246Pu]
+[194Au]
+[252Fm]
+[173Hf]
+[246Cm]
+[135Ce]
+[49Cr]
+[248Cf]
+[247Cm]
+[248Cm]
+[174Ta]
+[176Ta]
+[154Tb]
+[172Ta]
+[177Ta]
+[175Ta]
+[180Ta]
+[158Tb]
+[115Ag]
+[189Os]
+[251Cf]
+[145Pr]
+[147Pr]
+[76BrH]
+[102Rh]
+[238Np]
+[185Os]
+[246Am]
+[233Np]
+[166Dy]
+[254Es]
+[244Cf]
+[193Os]
+[245Am]
+[245Bk]
+[239Am]
+[238Am]
+[97Nb]
+[245Pu]
+[254Cf]
+[188W]
+[250Es]
+[251Es]
+[237Am]
+[182Hf]
+[258Md]
+[232Np]
+[238Cm]
+[60Fe]
+[109Pd+2]
+[234Pu]
+[141Ce+3]
+[136Nd]
+[136Pr]
+[173Ta]
+[110Ru]
+[147Tb]
+[253Fm]
+[139Nd]
+[178Re]
+[177Re]
+[200Au]
+[182Re]
+[156Tb]
+[155Tb]
+[157Tb]
+[161Tb]
+[161Ho]
+[167Tm]
+[173Lu]
+[179Ta]
+[171Er]
+[44Sc]
+[49Sc]
+[49V]
+[51Mn]
+[90Nb]
+[88Nb]
+[88Zr]
+[36SH2]
+[174Yb]
+[178Lu]
+[179W]
+[83BrH]
+[107Cd]
+[75BrH]
+[62Co]
+[48Cr]
+[63Zn]
+[102Ag]
+[154Sm]
+[168Er]
+[65Ni]
+[137La]
+[187Ir]
+[144Pm]
+[146Pm]
+[160Gd]
+[166Yb]
+[162Dy]
+[47V]
+[141Nd]
+[141Sm]
+[166Er]
+[150Sm]
+[146Eu]
+[149Eu]
+[174Lu]
+[17NH3]
+[102Ru]
+[170Hf]
+[188Pt]
+[61Ni]
+[56Ni]
+[149Gd]
+[151Gd]
+[141Pm]
+[147Gd]
+[146Gd]
+[161Er]
+[103Ag]
+[145Eu]
+[153Tb]
+[155Dy]
+[184Re]
+[180Os]
+[182Os]
+[186Pt]
+[181Os]
+[181Re]
+[151Tb]
+[178Ta]
+[178W]
+[189Pt]
+[194Hg]
+[145Sm]
+[150Tb]
+[132La]
+[158Gd]
+[104Ag]
+[193Hg]
+[94Ru]
+[137Pr]
+[155Ho]
+[117Cd]
+[99Ru]
+[146Nd]
+[218Rn]
+[95Y]
+[79Kr]
+[120IH]
+[138Pr]
+[100Pd]
+[166Tm]
+[90Mo]
+[151Nd]
+[231U]
+[138Nd]
+[89Nb]
+[98Nb]
+[162Ho]
+[142Sm]
+[186Ta]
+[104Tc]
+[184Ta]
+[185Ta]
+[170Er]
+[107Rh]
+[131La]
+[169Lu]
+[74BrH]
+[150Pm]
+[172Tm]
+[197Pt]
+[230Pu]
+[170Lu]
+[86Zr]
+[176W]
+[177W]
+[101Pd]
+[105Pd]
+[108Pd]
+[149Nd]
+[164Ho]
+[159Ho]
+[167Ho]
+[176Yb]
+[156Sm]
+[77BrH]
+[189Re]
+[99Rh]
+[100Rh]
+[151Pm]
+[232Pa]
+[228Pa]
+[230Pa]
+[66Ni]
+[194Os]
+[135La]
+[138La]
+[141La]
+[142La]
+[195Ir]
+[96Nb]
+[157Ho]
+[183Hf]
+[162Tm]
+[172Er]
+[148Eu]
+[150Eu]
+[15CH4]
+[89Kr]
+[143La]
+[58Ni]
+[61Co]
+[158Eu]
+[165Er]
+[167Yb]
+[173Tm]
+[175Tm]
+[172Hf]
+[172Lu]
+[93Tc]
+[177Yb]
+[124IH]
+[194Ir]
+[147Eu]
+[101Mo]
+[180Hf]
+[189Ir]
+[87Y]
+[43Sc]
+[195Au]
+[112Ag]
+[84BrH]
+[106Ag]
+[109Ag]
+[101Rh]
+[162Yb]
+[228Rn]
+[139Pr]
+[94Y]
+[201Au]
+[40PH3]
+[110Ag+]
+[104Cd]
+[133Ba+2]
+[226Ac]
+[145Gd]
+[186Ir]
+[184Ir]
+[224Rn]
+[185Ir]
+[182Ir]
+[184Hf]
+[200Pt]
+[227Pa]
+[178Yb]
+[72Br-]
+[72BrH]
+[248Am]
+[238Th]
+[161Gd]
+[35S-2]
+[107Ag]
+[FeH6-4]
+[89Sr]
+[SnH3-]
+[SeH3]
+[TeH3+]
+[SbH4+]
+[AsH4+]
+[4He]
+[AsH3-]
+[1HH]
+[3H+]
+[82Rb]
+[85Sr]
+[90Sr]
+[137Cs]
+[133Ba]
+[131Cs]
+[SbH5]
+[224Ra]
+[22Na]
+[210Bi]
+[214Bi]
+[228Ra]
+[127Sb]
+[136Cs]
+[125Sb]
+[134Cs]
+[140Ba]
+[45Ca]
+[206Pb]
+[207Pb]
+[24Na]
+[86Rb]
+[212Bi]
+[208Pb]
+[124Sb]
+[204Pb]
+[44K]
+[129Te]
+[113Sn]
+[204Tl]
+[87Sr]
+[208Tl]
+[87Rb]
+[47Ca]
+[135Cs]
+[216Po]
+[137Ba]
+[207Bi]
+[212Po]
+[79Se]
+[223Ra]
+[86Sr]
+[122Sb]
+[26Al]
+[32Si]
+[126Sn]
+[225Ra]
+[114In]
+[72Ga]
+[132Te]
+[10Be]
+[125Sn]
+[73As]
+[206Bi]
+[117Sn]
+[40Ca]
+[41Ca]
+[89Rb]
+[116In]
+[129Sb]
+[91Sr]
+[71Ge]
+[139Ba]
+[69Ga]
+[120Sb]
+[121Sn]
+[123Sn]
+[131Te]
+[77Ge]
+[135Ba]
+[82Sr]
+[43K]
+[131Ba]
+[92Sr]
+[88Rb]
+[129Cs]
+[144Cs]
+[127Cs]
+[200Tl]
+[202Tl]
+[141Ba]
+[117Sb]
+[116Sb]
+[78As]
+[131Sb]
+[126Sb]
+[128Sb]
+[130Sb]
+[67Ge]
+[68Ge]
+[78Ge]
+[66Ge]
+[223Fr]
+[132Cs]
+[125Cs]
+[138Cs]
+[133Te]
+[84Rb]
+[83Rb]
+[81Rb]
+[142Ba]
+[200Bi]
+[115Sb]
+[194Tl]
+[70Se]
+[112In]
+[118Sb]
+[70Ga]
+[27Mg]
+[202Bi]
+[83Se]
+[9Li]
+[69As]
+[79Rb]
+[81Sr]
+[83Sr]
+[78Se]
+[109In]
+[29Al]
+[118Sn]
+[117In]
+[119Sb]
+[114Sn]
+[138Ba]
+[69Ge]
+[73Ga]
+[74Ge]
+[206Tl]
+[199Tl]
+[130Cs]
+[28Mg]
+[116Te]
+[112Sn]
+[126Ba]
+[211Bi]
+[81Se]
+[127Sn]
+[143Cs]
+[134Te]
+[80Sr]
+[45K]
+[215Po]
+[207Po]
+[111Sn]
+[211Po]
+[128Ba]
+[198Tl]
+[227Ra]
+[213Po]
+[220Ra]
+[128Sn]
+[203Po]
+[205Po]
+[65Ga]
+[197Tl]
+[88Sr]
+[110In]
+[31Si]
+[201Bi]
+[121Te]
+[205Bi]
+[203Bi]
+[195Tl]
+[209Tl]
+[110Sn]
+[222Fr]
+[207At]
+[119In]
+[As@]
+[129IH]
+[157Dy]
+[111IH]
+[230Ra]
+[144Pr+3]
+[SiH3+]
+[3He]
+[AsH5]
+[72Se]
+[95Tc]
+[103Pd]
+[121Sn+2]
+[211Rn]
+[38SH2]
+[127IH]
+[74Br-]
+[133I-]
+[100Tc+4]
+[100Tc]
+[36Cl-]
+[89Y+3]
+[104Rh]
+[152Sm]
+[226Ra]
+[19FH]
+[104Pd]
+[148Gd]
+[157Lu]
+[33SH2]
+[121I-]
+[17FH]
+[71Se]
+[157Sm]
+[148Tb]
+[164Dy]
+[15OH2]
+[15O+]
+[39K]
+[40Ar]
+[50Cr+3]
+[50Cr]
+[52Ti]
+[103Pd+2]
+[130Ba]
+[142Pm]
+[153Gd+3]
+[151Eu]
+[103Rh]
+[124Xe]
+[152Tb]
+[17OH2]
+[20Ne]
+[52Fe]
+[94Zr+4]
+[94Zr]
+[149Pr]
+[16OH2]
+[53Cr+6]
+[53Cr]
+[81Br-]
+[112Pd]
+[125Xe]
+[155Gd]
+[157Gd]
+[168Yb]
+[184Os]
+[166Tb]
+[221Fr]
+[212Ra]
+[75Br-]
+[79Br-]
+[113Ag]
+[23Na]
+[34Cl-]
+[34ClH]
+[38Cl-]
+[56Fe]
+[68Cu]
+[77Br-]
+[90Zr+4]
+[90Zr]
+[102Pd]
+[154Eu+3]
+[57Mn]
+[165Tm]
+[152Dy]
+[217At]
+[77se]
+[13cH-]
+[122Te]
+[156Gd]
+[124Te]
+[53Ni]
+[131Xe]
+[174Hf+4]
+[174Hf]
+[76Se]
+[168Tm]
+[167Dy]
+[154Gd]
+[95Ru]
+[210At]
+[85Br]
+[59Co]
+[122Xe]
+[27Al]
+[54Cr]
+[198Hg]
+[85Rb+]
+[214Tl]
+[229Rn]
+[218Pb]
+[218Bi]
+[167Tm+3]
+[18o+]
+[P@@H+]
+[P@H+]
+[13N+]
+[212Pb+2]
+[217Bi]
+[249Cf+2]
+[18OH3+]
+[90Sr-]
+[Cf+3]
+[200Hg]
+[86Tc]
+[141Pr+3]
+[141Pr]
+[16nH]
+[14NH4+]
+[132Xe]
+[83Kr]
+[70Zn+2]
+[137Ba+2]
+[36Ar]
+[38Ar]
+[21Ne]
+[126Xe]
+[136Xe]
+[128Xe]
+[134Xe]
+[84Kr]
+[86Kr]
+[78Kr]
+[80Kr]
+[82Kr]
+[67Zn+2]
+[65Cu+2]
+[110Te]
+[58Fe+3]
+[142Nd]
+[38K]
+[198Au+3]
+[122IH]
+[38PH3]
+[130I-]
+[40K+]
+[38K+]
+[28Mg+2]
+[208Tl+]
+[13OH2]
+[198Bi]
+[192Bi]
+[194Bi]
+[196Bi]
+[132I-]
+[83Sr+2]
+[169Er+3]
+[122I-]
+[120I-]
+[92Sr+2]
+[126I-]
+[24Mg]
+[84Sr]
+[118Pd+2]
+[118Pd]
+[AsH4]
+[127I-]
+[9C-]
+[11CH3+]
+[17B]
+[7B]
+[4HH]
+[18C-]
+[22CH3-]
+[22CH4]
+[17C-]
+[15CH3]
+[16CH3]
+[11NH3]
+[21NH3]
+[11N-]
+[11NH]
+[16CH]
+[17CH2]
+[99Ru+2]
+[181Ta+2]
+[181Ta]
+[20CH]
+[32PH2]
+[55Fe+2]
+[SH3]
+[S@H]
+[Mn-]
+[IH4]
+[ThH]
+[GaH-]
+[BiH+]
+[EuH2]
+[FeH4-3]
+[FeH6]
+[IH5]
+[NiH+]
+[SrH2]
+[VH4]
+[YH3]
+[seH+]
+
diff --git a/models/smi_ted/inference/smi_ted_large/load.py b/models/smi_ted/inference/smi_ted_large/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ebe464b24cb7ac04223ee61528ddfea4a216f54
--- /dev/null
+++ b/models/smi_ted/inference/smi_ted_large/load.py
@@ -0,0 +1,672 @@
+PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+# Transformers
+from fast_transformers.attention import AttentionLayer
+from fast_transformers.events import QKVEvent
+from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer
+from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder
+from fast_transformers.builders.attention_builders import AttentionBuilder
+from fast_transformers.feature_maps import GeneralizedRandomFeatures
+from fast_transformers.masking import LengthMask
+from transformers import BertTokenizer
+
+# Data
+import numpy as np
+import pandas as pd
+
+# Chemistry
+from rdkit import Chem
+from rdkit.Chem import PandasTools
+from rdkit.Chem import Descriptors
+PandasTools.RenderImagesInAllDataFrames(True)
+
+# Standard library
+from functools import partial
+import regex as re
+import random
+import os
+import gc
+from tqdm import tqdm
+tqdm.pandas()
+
+
+# function to canonicalize SMILES
+def normalize_smiles(smi, canonical=True, isomeric=False):
+ try:
+ normalized = Chem.MolToSmiles(
+ Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
+ )
+ except:
+ normalized = None
+ return normalized
+
+
+class MolTranBertTokenizer(BertTokenizer):
+ def __init__(self, vocab_file: str = '',
+ do_lower_case=False,
+ unk_token='',
+ sep_token='',
+ pad_token='',
+ cls_token='',
+ mask_token='',
+ **kwargs):
+ super().__init__(vocab_file,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs)
+
+ self.regex_tokenizer = re.compile(PATTERN)
+ self.wordpiece_tokenizer = None
+ self.basic_tokenizer = None
+ with open(vocab_file) as f:
+ self.padding_idx = f.readlines().index(pad_token+'\n')
+
+ def _tokenize(self, text):
+ split_tokens = self.regex_tokenizer.findall(text)
+ return split_tokens
+
+ def convert_idx_to_tokens(self, idx_tensor):
+ tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
+ return tokens
+
+ def convert_tokens_to_string(self, tokens):
+ stopwords = ['', '']
+ clean_tokens = [word for word in tokens if word not in stopwords]
+ out_string = ''.join(clean_tokens)
+ return out_string
+
+ def get_padding_idx(self):
+ return self.padding_idx
+
+ def idx_to_smiles(self, torch_model, idx):
+ '''Convert tokens idx back to SMILES text'''
+ rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
+ flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
+ decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)
+ return decoded_smiles
+
+
+## Transformer layers
+class RotaryEmbedding(torch.nn.Module):
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = 0
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self.cos_cached = emb.cos()[None,:, None, :]
+ self.sin_cached = emb.sin()[None,:, None, :]
+
+ return self.cos_cached, self.sin_cached
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+@torch.jit.script
+def apply_rotary_pos_emb(q, k, cos, sin):
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+class RotateAttentionLayer(AttentionLayer):
+ """Rotate attention layer inherits from fast_transformer attention layer.
+ The only thing added is an Embedding encoding, for more information
+ on the attention layer see the fast_transformers code
+ """
+ def __init__(self, attention, d_model, n_heads, d_keys=None,
+ d_values=None, event_dispatcher=""):
+ super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys,
+ d_values=d_values, event_dispatcher=event_dispatcher)
+
+ self.rotaryemb = RotaryEmbedding(d_keys)
+ print('Using Rotation Embedding')
+
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
+ key_lengths):
+ """
+ Using the same frame work as the fast_Transformers attention layer
+ but injecting rotary information to the queries and the keys
+ after the keys and queries are projected.
+ In the argument description we make use of the following sizes
+ - N: the batch size
+ - L: The maximum length of the queries
+ - S: The maximum length of the keys (the actual length per sequence
+ is given by the length mask)
+ - D: The input feature dimensionality passed in the constructor as
+ 'd_model'
+ Arguments
+ ---------
+ queries: (N, L, D) The tensor containing the queries
+ keys: (N, S, D) The tensor containing the keys
+ values: (N, S, D) The tensor containing the values
+ attn_mask: An implementation of BaseMask that encodes where each
+ query can attend to
+ query_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ key_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ Returns
+ -------
+ The new value for each query as a tensor of shape (N, L, D).
+ """
+ # Extract the dimensions into local variables
+ N, L, _ = queries.shape
+ _, S, _ = keys.shape
+ H = self.n_heads
+
+ # Project the queries/keys/values
+ queries = self.query_projection(queries).view(N, L, H, -1)
+ keys = self.key_projection(keys).view(N, S, H, -1)
+ cos, sin = self.rotaryemb(queries)
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+ values = self.value_projection(values).view(N, S, H, -1)
+ # Let the world know of the qkv
+ self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+
+ # Compute the attention
+ new_values = self.inner_attention(
+ queries,
+ keys,
+ values,
+ attn_mask,
+ query_lengths,
+ key_lengths
+ ).view(N, L, -1)
+
+ # Project the output and return
+ return self.out_projection(new_values)
+
+class RotateEncoderBuilder(BaseTransformerEncoderBuilder):
+ """Build a batch transformer encoder with Relative Rotary embeddings
+ for training or processing of sequences all elements at a time.
+ Example usage:
+ builder = RotateEncoderBuilder()
+ builder.n_layers = 12
+ builder.n_heads = 8
+ builder.feed_forward_dimensions = 1024
+ builder.query_dimensions = 64
+ builder.value_dimensions = 64
+ builder.dropout = 0.1
+ builder.attention_dropout = 0.1
+ builder.attention_type = "linear"
+ transformer = builder.get()
+ """
+ def _get_attention_builder(self):
+ """Return an instance of the appropriate attention builder."""
+ return AttentionBuilder()
+
+ def _get_attention_layer_class(self):
+ """Return the class for the layer that projects queries keys and
+ values."""
+ return RotateAttentionLayer
+
+ def _get_encoder_class(self):
+ """Return the class for the transformer encoder."""
+ return TransformerEncoder
+
+ def _get_encoder_layer_class(self):
+ """Return the class for the transformer encoder layer."""
+ return TransformerEncoderLayer
+
+
+class AutoEncoderLayer(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.encoder = self.Encoder(feature_size, latent_size)
+ self.decoder = self.Decoder(feature_size, latent_size)
+
+ class Encoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(feature_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.lat = nn.Linear(latent_size, latent_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.lat.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.lat(x)
+ return x # -> (N, D)
+
+ class Decoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(latent_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.rec = nn.Linear(latent_size, feature_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.rec.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.rec(x)
+ return x # -> (N, L*D)
+
+
+class LangLayer(nn.Module):
+
+ def __init__(self, n_embd, n_vocab):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.embed = nn.Linear(n_embd, n_embd)
+ self.ln_f = nn.LayerNorm(n_embd)
+ self.head = nn.Linear(n_embd, n_vocab, bias=False)
+
+ def forward(self, tensor):
+ if self.is_cuda_available:
+ self.embed.cuda()
+ self.ln_f.cuda()
+ self.head.cuda()
+ tensor = tensor.cuda()
+ tensor = self.embed(tensor)
+ tensor = F.gelu(tensor)
+ tensor = self.ln_f(tensor)
+ tensor = self.head(tensor)
+ return tensor
+
+
+class Net(nn.Module):
+
+ def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2):
+ super().__init__()
+ self.desc_skip_connection = True
+ self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.relu1 = nn.GELU()
+ self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout2 = nn.Dropout(dropout)
+ self.relu2 = nn.GELU()
+ self.final = nn.Linear(smiles_embed_dim, n_output)
+
+ def forward(self, smiles_emb, multitask=False):
+ x_out = self.fc1(smiles_emb)
+ x_out = self.dropout1(x_out)
+ x_out = self.relu1(x_out)
+
+ if self.desc_skip_connection is True:
+ x_out = x_out + smiles_emb
+
+ z = self.fc2(x_out)
+ z = self.dropout2(z)
+ z = self.relu2(z)
+ if self.desc_skip_connection is True:
+ z = self.final(z + x_out)
+ else:
+ z = self.final(z)
+
+ if multitask:
+ return F.sigmoid(z)
+ return z
+
+
+class MoLEncoder(nn.Module):
+
+ def __init__(self, config, n_vocab):
+ super(MoLEncoder, self).__init__()
+
+ # embeddings
+ self.config = config
+ self.tok_emb = nn.Embedding(n_vocab, config['n_embd'])
+ self.drop = nn.Dropout(config['d_dropout'])
+
+ # transformer
+ builder = RotateEncoderBuilder.from_kwargs(
+ n_layers=config['n_layer'],
+ n_heads=config['n_head'],
+ query_dimensions=config['n_embd']//config['n_head'],
+ value_dimensions=config['n_embd']//config['n_head'],
+ feed_forward_dimensions=None,
+ attention_type='linear',
+ # unless we do deterministic_eval here, we will have random outputs
+ feature_map=partial(GeneralizedRandomFeatures,
+ n_dims=config['num_feats'],
+ deterministic_eval=True),
+ activation='gelu'
+ )
+ self.blocks = builder.get()
+
+ # classification
+ self.lang_model = LangLayer(config['n_embd'], n_vocab)
+
+ def forward(self, idx, mask):
+ # transformer encoder
+ x = self.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.drop(x)
+ x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1]))
+
+ # add padding
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ mask_embeddings = (token_embeddings * input_mask_expanded)
+ token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0)
+
+ return token_embeddings
+
+
+class MoLDecoder(nn.Module):
+
+ def __init__(self, n_vocab, max_len, n_embd, n_gpu=None):
+ super(MoLDecoder, self).__init__()
+
+ self.max_len = max_len
+ self.n_embd = n_embd
+ self.n_gpu = n_gpu
+ self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd)
+ self.lang_model = LangLayer(n_embd, n_vocab)
+
+
+class Smi_ted(nn.Module):
+ """materials.smi-ted-Large 738M Parameters"""
+
+ def __init__(self, tokenizer, config=None):
+ super(Smi_ted, self).__init__()
+
+ # configuration
+ self.config = config
+ self.tokenizer = tokenizer
+ self.padding_idx = tokenizer.get_padding_idx()
+ self.n_vocab = len(self.tokenizer.vocab)
+ self.is_cuda_available = torch.cuda.is_available()
+
+ # instantiate modules
+ if self.config:
+ self.encoder = MoLEncoder(self.config, self.n_vocab)
+ self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
+ self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['d_dropout'])
+
+ def load_checkpoint(self, ckpt_path):
+ # load checkpoint file
+ checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+
+ # load hyparameters
+ self.config = checkpoint['hparams']
+ self.max_len = self.config['max_len']
+ self.n_embd = self.config['n_embd']
+ self._set_seed(self.config['seed'])
+
+ # instantiate modules
+ self.encoder = MoLEncoder(self.config, self.n_vocab)
+ self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
+ self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else 1, dropout=self.config['d_dropout'])
+
+ # load weights
+ if 'state_dict' in checkpoint:
+ if isinstance(checkpoint['state_dict'], list):
+ self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False)
+ self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False)
+ else:
+ self.load_state_dict(checkpoint['state_dict'], strict=False)
+ elif 'MODEL_STATE' in checkpoint:
+ self.load_state_dict(checkpoint['MODEL_STATE'], strict=False)
+
+ # load RNG states each time the model and states are loaded from checkpoint
+ if 'rng' in self.config:
+ rng = self.config['rng']
+ for key, value in rng.items():
+ if key =='torch_state':
+ torch.set_rng_state(value.cpu())
+ elif key =='cuda_state':
+ torch.cuda.set_rng_state(value.cpu())
+ elif key =='numpy_state':
+ np.random.set_state(value)
+ elif key =='python_state':
+ random.setstate(value)
+ else:
+ print('unrecognized state')
+
+ def _set_seed(self, value):
+ print('Random Seed:', value)
+ random.seed(value)
+ torch.manual_seed(value)
+ torch.cuda.manual_seed(value)
+ torch.cuda.manual_seed_all(value)
+ np.random.seed(value)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def forward(self, smiles, batch_size=100):
+ return self.decode(self.encode(smiles, batch_size=batch_size, return_torch=True))
+
+ def tokenize(self, smiles):
+ """Tokenize a string into tokens."""
+ if isinstance(smiles, str):
+ batch = [smiles]
+ else:
+ batch = smiles
+
+ tokens = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ max_length=self.max_len,
+ )
+
+ idx = tokens['input_ids'].clone().detach()
+ mask = tokens['attention_mask'].clone().detach()
+
+ if self.is_cuda_available:
+ return idx.cuda(), mask.cuda()
+
+ return idx, mask
+
+ def extract_all(self, smiles):
+ """Extract all elements from each part of smi-ted. Be careful."""
+ # evaluation mode
+ self.encoder.eval()
+ self.decoder.eval()
+ if self.is_cuda_available:
+ self.encoder.cuda()
+ self.decoder.cuda()
+
+ # handle single str or a list of str
+ smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+ # SMILES normalization
+ smiles = smiles.apply(normalize_smiles)
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
+ smiles = smiles.dropna()
+
+ # tokenizer
+ idx, mask = self.tokenize(smiles.to_list())
+
+ ###########
+ # Encoder #
+ ###########
+ # encoder forward
+ x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.encoder.drop(x)
+ x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1)))
+
+ # mean pooling
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+ true_set = sum_embeddings / sum_mask # DO NOT USE THIS FOR DOWNSTREAM TASKS, USE `pred_set` INSTEAD
+
+ # add padding
+ mask_embeddings = (token_embeddings * input_mask_expanded)
+ token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.max_len - mask_embeddings.shape[1]), value=0)
+ idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=2)
+
+ true_ids = idx
+ true_cte = token_embeddings
+ true_cte = true_cte.view(-1, self.max_len*self.n_embd)
+
+ ###########
+ # Decoder #
+ ###########
+ # CTE autoencoder
+ pred_set = self.decoder.autoencoder.encoder(true_cte)
+ pred_cte = self.decoder.autoencoder.decoder(pred_set)
+
+ # reconstruct tokens
+ pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
+ pred_ids = torch.argmax(pred_ids, axis=-1)
+
+ # replacing null SMILES with NaN values
+ for idx in null_idx:
+ true_ids = true_ids.tolist()
+ pred_ids = pred_ids.tolist()
+ true_cte = true_cte.tolist()
+ pred_cte = pred_cte.tolist()
+ true_set = true_set.tolist()
+ pred_set = pred_set.tolist()
+
+ true_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+ pred_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+ true_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+ pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+ true_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+ pred_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+
+ if len(null_idx) > 0:
+ true_ids = torch.tensor(true_ids)
+ pred_ids = torch.tensor(pred_ids)
+ true_cte = torch.tensor(true_cte)
+ pred_cte = torch.tensor(pred_cte)
+ true_set = torch.tensor(true_set)
+ pred_set = torch.tensor(pred_set)
+
+ return ((true_ids, pred_ids), # tokens
+ (true_cte, pred_cte), # token embeddings
+ (true_set, pred_set)) # smiles embeddings
+
+ def extract_embeddings(self, smiles):
+ """Extract token and SMILES embeddings."""
+ # evaluation mode
+ self.encoder.eval()
+ if self.is_cuda_available:
+ self.encoder.cuda()
+
+ # tokenizer
+ idx, mask = self.tokenize(smiles)
+
+ # encoder forward
+ token_embeddings = self.encoder(idx, mask)
+
+ # aggregate token embeddings (similar to mean pooling)
+ # CAUTION: use the embeddings from the autoencoder.
+ smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd))
+
+ # add padding
+ idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=self.padding_idx)
+
+ return idx, token_embeddings, smiles_embeddings
+
+ def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False):
+ """Extract efficiently SMILES embeddings per batches."""
+ # TODO: remove useCuda argument
+
+ # handle single str or a list of str
+ smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+ # SMILES normalization
+ smiles = smiles.apply(normalize_smiles)
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
+ smiles = smiles.dropna()
+
+ # process in batches
+ n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
+ embeddings = [
+ self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
+ for batch in tqdm(np.array_split(smiles, n_split))
+ ]
+ flat_list = [item for sublist in embeddings for item in sublist]
+
+ # clear GPU memory
+ if self.is_cuda_available:
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # replacing null SMILES with NaN values
+ for idx in null_idx:
+ flat_list.insert(idx, np.array([np.nan]*self.config['n_embd']))
+ flat_list = np.asarray(flat_list)
+
+ if return_torch:
+ return torch.tensor(flat_list)
+ return pd.DataFrame(flat_list)
+
+ def decode(self, smiles_embeddings):
+ """Decode SMILES embeddings back to SMILES."""
+ # evaluation mode
+ self.decoder.eval()
+ if self.is_cuda_available:
+ self.decoder.cuda()
+
+ # reconstruct token embeddings
+ pred_token_embds = self.decoder.autoencoder.decoder(smiles_embeddings)
+
+ # reconstruct tokens
+ pred_idx = self.decoder.lang_model(pred_token_embds.view(-1, self.max_len, self.n_embd))
+ pred_idx = torch.argmax(pred_idx, axis=-1).cpu().detach().numpy()
+
+ # convert idx to tokens
+ pred_smiles = []
+ for i in range(pred_idx.shape[0]):
+ idx = pred_idx[i]
+ smiles = self.tokenizer.idx_to_smiles(self, idx)
+ smiles = smiles.replace('', '') # begin token
+ smiles = smiles.replace('', '') # end token
+ smiles = smiles.replace('', '') # pad token
+ pred_smiles.append(smiles)
+
+ # clear GPU memory
+ if self.is_cuda_available:
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ return pred_smiles
+
+ def __str__(self):
+ return 'smi-ted-Large'
+
+
+def load_smi_ted(folder="./smi_ted_large",
+ ckpt_filename="smi-ted-Large_30.pt",
+ vocab_filename="bert_vocab_curated.txt"
+ ):
+ tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
+ model = Smi_ted(tokenizer)
+ model.load_checkpoint(os.path.join(folder, ckpt_filename))
+ model.eval()
+ print('Vocab size:', len(tokenizer.vocab))
+ print(f'[INFERENCE MODE - {str(model)}]')
+ return model
\ No newline at end of file
diff --git a/models/smi_ted/inference/smi_ted_light/bert_vocab_curated.txt b/models/smi_ted/inference/smi_ted_light/bert_vocab_curated.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd
--- /dev/null
+++ b/models/smi_ted/inference/smi_ted_light/bert_vocab_curated.txt
@@ -0,0 +1,2393 @@
+
+
+
+
+C
+c
+(
+)
+1
+O
+N
+2
+=
+n
+3
+[C@H]
+[C@@H]
+F
+S
+4
+Cl
+-
+o
+s
+[nH]
+#
+/
+Br
+[C@]
+[C@@]
+[N+]
+[O-]
+5
+\
+.
+I
+6
+[S@]
+[S@@]
+P
+[N-]
+[Si]
+7
+[n+]
+[2H]
+8
+[NH+]
+B
+9
+[C-]
+[Na+]
+[Cl-]
+[c-]
+[CH]
+%10
+[NH2+]
+[P+]
+[B]
+[I-]
+%11
+[CH2-]
+[O+]
+[NH3+]
+[C]
+[Br-]
+[IH2]
+[S-]
+[cH-]
+%12
+[nH+]
+[B-]
+[K+]
+[Sn]
+[Se]
+[CH-]
+[HH]
+[Y]
+[n-]
+[CH3-]
+[SiH]
+[S+]
+%13
+[SiH2]
+[Li+]
+[NH-]
+%14
+[Na]
+[CH2]
+[O-2]
+[U+2]
+[W]
+[Al]
+[P@]
+[Fe+2]
+[PH+]
+%15
+[Cl+3]
+[Zn+2]
+[Ir]
+[Mg+2]
+[Pt+2]
+[OH2+]
+[As]
+[Fe]
+[OH+]
+[Zr+2]
+[3H]
+[Ge]
+[SiH3]
+[OH-]
+[NH4+]
+[Cu+2]
+[P@@]
+p
+[Pt]
+%16
+[Ca+2]
+[Zr]
+[F-]
+[C+]
+[Ti]
+[P-]
+[V]
+[se]
+[U]
+[O]
+[Ni+2]
+[Zn]
+[Co]
+[Ni]
+[Pd+2]
+[Cu]
+%17
+[Cu+]
+[Te]
+[H+]
+[CH+]
+[Li]
+[Pd]
+[Mo]
+[Ru+2]
+[o+]
+[Re]
+[SH+]
+%18
+[Ac]
+[Cr]
+[NH2-]
+[K]
+[13CH2]
+[c]
+[Zr+4]
+[Tl]
+[13C]
+[Mn]
+[N@+]
+[Hg]
+[Rh]
+[Ti+4]
+[Sb]
+[Co+2]
+[Ag+]
+[Ru]
+%19
+[N@@+]
+[Ti+2]
+[Al+3]
+[Pb]
+[I+]
+[18F]
+[s+]
+[Rb+]
+[Ba+2]
+[H-]
+[Fe+3]
+[Ir+3]
+[13cH]
+%20
+[AlH2]
+[Au+]
+[13c]
+[SH2+]
+[Sn+2]
+[Mn+2]
+[Si-]
+[Ag]
+[N]
+[Bi]
+%21
+[In]
+[CH2+]
+[Y+3]
+[Ga]
+%22
+[Co+3]
+[Au]
+[13CH3]
+[Mg]
+[Cs+]
+[W+2]
+[Hf]
+[Zn+]
+[Se-]
+[S-2]
+[Ca]
+[pH]
+[ClH+]
+[Ti+3]
+%23
+[Ru+]
+[SH-]
+[13CH]
+[IH+]
+[Hf+4]
+[Rf]
+[OH3+]
+%24
+[Pt+4]
+[Zr+3]
+[PH3+]
+[Sr+2]
+[Cd+2]
+[Cd]
+%25
+[Os]
+[BH-]
+[Sn+4]
+[Cr+3]
+[Ru+3]
+[PH2+]
+[Rh+2]
+[V+2]
+%26
+[Gd+3]
+[Pb+2]
+[PH]
+[Hg+]
+[Mo+2]
+[AlH]
+[Sn+]
+%27
+[Pd+]
+b
+[Rh+3]
+[Hg+2]
+[15NH]
+[14C]
+%28
+[Mn+3]
+[Si+]
+[SeH]
+[13C@H]
+[NH]
+[Ga+3]
+[SiH-]
+[13C@@H]
+[Ce]
+[Au+3]
+[Bi+3]
+[15N]
+%29
+[BH3-]
+[14cH]
+[Ti+]
+[Gd]
+[cH+]
+[Cr+2]
+[Sb-]
+%30
+[Be+2]
+[Al+]
+[te]
+[11CH3]
+[Sm]
+[Pr]
+[La]
+%31
+[Al-]
+[Ta]
+[125I]
+[BH2-]
+[Nb]
+[Si@]
+%32
+[14c]
+[Sb+3]
+[Ba]
+%33
+[Os+2]
+[Si@@]
+[La+3]
+[15n]
+[15NH2]
+[Nd+3]
+%34
+[14CH2]
+[18O]
+[Nd]
+[GeH]
+[Ni+3]
+[Eu]
+[Dy+3]
+[Sc]
+%36
+[Se-2]
+[As+]
+%35
+[AsH]
+[Tb]
+[Sb+5]
+[Se+]
+[Ce+3]
+[c+]
+[In+3]
+[SnH]
+[Mo+4]
+%37
+[V+4]
+[Eu+3]
+[Hf+2]
+%38
+[Pt+]
+[p+]
+[123I]
+[Tl+]
+[Sm+3]
+%39
+[Yb+3]
+%40
+[Yb]
+[Os+]
+%41
+[10B]
+[Sc+3]
+[Al+2]
+%42
+[Sr]
+[Tb+3]
+[Po]
+[Tc]
+[PH-]
+[AlH3]
+[Ar]
+[U+4]
+[SnH2]
+[Cl+2]
+[si]
+[Fe+]
+[14CH3]
+[U+3]
+[Cl+]
+%43
+[GeH2]
+%44
+[Er+3]
+[Mo+3]
+[I+2]
+[Fe+4]
+[99Tc]
+%45
+[11C]
+%46
+[SnH3]
+[S]
+[Te+]
+[Er]
+[Lu+3]
+[11B]
+%47
+%48
+[P]
+[Tm]
+[Th]
+[Dy]
+[Pr+3]
+[Ta+5]
+[Nb+5]
+[Rb]
+[GeH3]
+[Br+2]
+%49
+[131I]
+[Fm]
+[Cs]
+[BH4-]
+[Lu]
+[15nH]
+%50
+[Ru+6]
+[b-]
+[Ho]
+[Th+4]
+[Ru+4]
+%52
+[14CH]
+%51
+[Cr+6]
+[18OH]
+[Ho+3]
+[Ce+4]
+[Bi+2]
+[Co+]
+%53
+[Yb+2]
+[Fe+6]
+[Be]
+%54
+[SH3+]
+[Np]
+[As-]
+%55
+[14C@@H]
+[Ir+2]
+[GaH3]
+[p-]
+[GeH4]
+[Sn+3]
+[Os+4]
+%56
+[14C@H]
+[sH+]
+[19F]
+[Eu+2]
+[TlH]
+%57
+[Cr+4]
+%58
+[B@@-]
+[SiH+]
+[At]
+[Am]
+[Fe+5]
+[AsH2]
+[Si+4]
+[B@-]
+[Pu]
+[SbH]
+[P-2]
+[Tm+3]
+*
+%59
+[se+]
+[IH-]
+%60
+[oH+]
+[1H]
+[15N+]
+[124I]
+[S@@+]
+[P-3]
+[H]
+[IH2+]
+[TeH]
+[Xe]
+[PH4+]
+[Cr+]
+[Cm]
+[I+3]
+%61
+[Nb+2]
+[Ru+5]
+%62
+[Ta+2]
+[Tc+4]
+[CH3+]
+[Pm]
+[Si@H]
+[No]
+%63
+[Cr+5]
+[Th+2]
+[Zn-2]
+[13C@]
+[Lr]
+%64
+[99Tc+3]
+%65
+[13C@@]
+%66
+[Fe-]
+[17O]
+[siH]
+[Sb+]
+[OH]
+[IH]
+[11CH2]
+[Cf]
+[SiH2+]
+[Gd+2]
+[In+]
+[Si@@H]
+[Mn+]
+[99Tc+4]
+[Ga-]
+%67
+[S@+]
+[Ge+4]
+[Tl+3]
+[16OH]
+%68
+[2H-]
+[Ra]
+[si-]
+[NiH2]
+[P@@H]
+[Rh+]
+[12C]
+[35S]
+[32P]
+[SiH2-]
+[AlH2+]
+[16O]
+%69
+[BiH]
+[BiH2]
+[Zn-]
+[BH]
+[Tc+3]
+[Ir+]
+[Ni+]
+%70
+[InH2]
+[InH]
+[Nb+3]
+[PbH]
+[Bi+]
+%71
+[As+3]
+%72
+[18O-]
+[68Ga+3]
+%73
+[Pa]
+[76Br]
+[Tc+5]
+[pH+]
+[64Cu+2]
+[Ru+8]
+%74
+[PH2-]
+[Si+2]
+[17OH]
+[RuH]
+[111In+3]
+[AlH+]
+%75
+%76
+[W+]
+[SbH2]
+[PoH]
+[Ru-]
+[XeH]
+[Tc+2]
+[13C-]
+[Br+]
+[Pt-2]
+[Es]
+[Cu-]
+[Mg+]
+[3HH]
+[P@H]
+[ClH2+]
+%77
+[SH]
+[Au-]
+[2HH]
+%78
+[Sn-]
+[11CH]
+[PdH2]
+0
+[Os+6]
+%79
+[Mo+]
+%80
+[al]
+[PbH2]
+[64Cu]
+[Cl]
+[12CH3]
+%81
+[Tc+7]
+[11c]
+%82
+[Li-]
+[99Tc+5]
+[He]
+[12c]
+[Kr]
+[RuH+2]
+[35Cl]
+[Pd-2]
+[GaH2]
+[4H]
+[Sg]
+[Cu-2]
+[Br+3]
+%83
+[37Cl]
+[211At]
+[IrH+2]
+[Mt]
+[Ir-2]
+[In-]
+[12cH]
+[12CH2]
+[RuH2]
+[99Tc+7]
+%84
+[15n+]
+[ClH2+2]
+[16N]
+[111In]
+[Tc+]
+[Ru-2]
+[12CH]
+[si+]
+[Tc+6]
+%85
+%86
+[90Y]
+[Pd-]
+[188Re]
+[RuH+]
+[NiH]
+[SiH3-]
+[14n]
+[CH3]
+[14N]
+[10BH2]
+%88
+%89
+%90
+[34S]
+[77Br]
+[GaH]
+[Br]
+[Ge@]
+[B@@H-]
+[CuH]
+[SiH4]
+[3H-]
+%87
+%91
+%92
+[67Cu]
+[I]
+[177Lu]
+[ReH]
+[67Ga+3]
+[Db]
+[177Lu+3]
+[AlH2-]
+[Si+3]
+[Ti-2]
+[RuH+3]
+[al+]
+[68Ga]
+[2H+]
+[B@H-]
+[WH2]
+[OsH]
+[Ir-3]
+[AlH-]
+[Bk]
+[75Se]
+[14C@]
+[Pt-]
+[N@@H+]
+[Nb-]
+[13NH2]
+%93
+[186Re]
+[Tb+4]
+[PtH]
+[IrH2]
+[Hg-2]
+[AlH3-]
+[PdH+]
+[Md]
+[RhH+2]
+[11cH]
+[Co-2]
+[15N-]
+[ZrH2]
+%94
+[Hg-]
+[127I]
+[AsH2+]
+[MoH2]
+[Te+4]
+[14C@@]
+[As+5]
+[SnH+3]
+[Ge@@]
+[6Li+]
+[WH]
+[Ne]
+[14NH2]
+[14NH]
+[12C@@H]
+[Os+7]
+[RhH]
+[Al-3]
+[SnH+]
+[15NH3+]
+[Zr+]
+[197Hg+]
+%95
+%96
+[90Y+3]
+[Os-2]
+[98Tc+5]
+[15NH3]
+[bH-]
+[33P]
+[Zr-2]
+[15O]
+[Rh-]
+[PbH3]
+[PH2]
+[Ni-]
+[CuH+]
+%97
+%98
+%99
+[Os+5]
+[PtH+]
+[ReH4]
+[16NH]
+[82Br]
+[W-]
+[18F-]
+[15NH4+]
+[Se+4]
+[SeH-]
+[SH4]
+[67Cu+2]
+[12C@H]
+[AsH3]
+[HgH]
+[10B-]
+[99Tc+6]
+[117Sn+4]
+[Te@]
+[P@+]
+[35SH]
+[SeH+]
+[Ni-2]
+[Al-2]
+[TeH2]
+[Bh]
+[99Tc+2]
+[Os+8]
+[PH-2]
+[7Li+]
+[14nH]
+[AlH+2]
+[18FH]
+[SnH4]
+[18O-2]
+[IrH]
+[13N]
+[Te@@]
+[Rh-3]
+[15NH+]
+[AsH3+]
+[SeH2]
+[AsH+]
+[CoH2]
+[16NH2]
+[AsH-]
+[203Hg+]
+[P@@+]
+[166Ho+3]
+[60Co+3]
+[13CH2-]
+[SeH2+]
+[75Br]
+[TlH2]
+[80Br]
+[siH+]
+[Ca+]
+[153Sm+3]
+[PdH]
+[225Ac]
+[13CH3-]
+[AlH4-]
+[FeH]
+[13CH-]
+[14C-]
+[11C-]
+[153Sm]
+[Re-]
+[te+]
+[13CH4]
+[ClH+2]
+[8CH2]
+[99Mo]
+[ClH3+3]
+[SbH3]
+[25Mg+2]
+[16N+]
+[SnH2+]
+[PH4]
+[11C@H]
+[122I]
+[Re-2]
+[RuH2+2]
+[ZrH]
+[Bi-]
+[Pr+]
+[Rn]
+[Fr]
+[36Cl]
+[18o]
+[YH]
+[79Br]
+[121I]
+[113In+3]
+[InH4-]
+[TaH]
+[RhH2]
+[Ta-]
+[67Ga]
+[ZnH+]
+[SnH2-]
+[OsH2]
+[16F]
+[FeH2]
+[14O]
+[PbH2+2]
+[BH2]
+[6H]
+[125Te]
+[197Hg]
+[TaH2]
+[TaH3]
+[76As]
+[Nb-2]
+[14N+]
+[125I-]
+[33S]
+[IH2+2]
+[NH2]
+[PtH2]
+[MnH]
+[19C]
+[17F]
+[1H-]
+[SnH4+2]
+[Mn-2]
+[15NH2+]
+[TiH2]
+[ReH7]
+[Cd-2]
+[Fe-3]
+[SH2]
+[17O-]
+[siH-]
+[CoH+]
+[VH]
+[10BH]
+[Ru-3]
+[13O]
+[5H]
+[CoH]
+[PH5]
+[15n-]
+[153Gd]
+[12C@]
+[11CH3-]
+[IrH3]
+[RuH3]
+[74Se]
+[Se@]
+[Hf+]
+[77Se]
+[166Ho]
+[59Fe+2]
+[203Hg]
+[18OH-]
+[8CH]
+[12C@@]
+[11CH4]
+[15C]
+[249Cf]
+[PbH4]
+[64Zn]
+[PH3]
+[99Tc+]
+[14c-]
+[149Pm]
+[IrH4]
+[Se@@]
+[13OH]
+[14CH3-]
+[28Si]
+[Rh-2]
+[Fe-2]
+[131I-]
+[51Cr]
+[62Cu+2]
+[81Br]
+[121Sb]
+[7Li]
+[89Zr+4]
+[SbH3+]
+[11C@@H]
+[98Tc]
+[59Fe+3]
+[BiH2+]
+[SbH+]
+[TiH]
+[14NH3]
+[15OH]
+[119Sn]
+[201Hg]
+[MnH+]
+[201Tl]
+[51Cr+3]
+[123I-]
+[MoH]
+[AlH6-3]
+[MnH2]
+[WH3]
+[213Bi+3]
+[SnH2+2]
+[123IH]
+[13CH+]
+[Zr-]
+[74As]
+[13C+]
+[32P+]
+[KrH]
+[SiH+2]
+[ClH3+2]
+[13NH]
+[9CH2]
+[ZrH2+2]
+[87Sr+2]
+[35s]
+[239Pu]
+[198Au]
+[241Am]
+[203Hg+2]
+[V+]
+[YH2]
+[SH5]
+[195Pt]
+[203Pb]
+[RuH4]
+[ThH2]
+[AuH]
+[66Ga+3]
+[11B-]
+[F]
+[24Na+]
+[85Sr+2]
+[201Tl+]
+[14CH4]
+[32S]
+[TeH2+]
+[ClH2+3]
+[AgH]
+[Ge@H]
+[44Ca+2]
+[Os-]
+[31P]
+[15nH+]
+[SbH4]
+[TiH+]
+[Ba+]
+[57Co+2]
+[Ta+]
+[125IH]
+[77As]
+[129I]
+[Fe-4]
+[Ta-2]
+[19O]
+[12O]
+[BiH3]
+[237Np]
+[252Cf]
+[86Y]
+[Cr-2]
+[89Y]
+[195Pt+2]
+[si+2]
+[58Fe+2]
+[Hs]
+[S@@H]
+[OsH6]
+[GdH2]
+[IH3]
+[8CH4]
+[164Dy+3]
+[47Ca+2]
+[57Co]
+[NbH2]
+[ReH2]
+[ZnH2]
+[CrH2]
+[17NH]
+[ZrH3]
+[RhH3]
+[12C-]
+[18O+]
+[Bi-2]
+[ClH4+3]
+[Ni-3]
+[Ag-]
+[111In-]
+[Mo-2]
+[55Fe+3]
+[204Hg+]
+[35Cl-]
+[211Pb]
+[75Ge]
+[8B]
+[TeH3]
+[SnH3+]
+[Zr-3]
+[28F]
+[249Bk]
+[169Yb]
+[34SH]
+[6Li]
+[94Tc]
+[197Au]
+[195Pt+4]
+[169Yb+3]
+[32Cl]
+[82Se]
+[159Gd+3]
+[213Bi]
+[CoH+2]
+[36S]
+[35P]
+[Ru-4]
+[Cr-3]
+[60Co]
+[1H+]
+[18CH2]
+[Cd-]
+[152Sm+3]
+[106Ru]
+[238Pu]
+[220Rn]
+[45Ca+2]
+[89Sr+2]
+[239Np]
+[90Sr+2]
+[137Cs+]
+[165Dy]
+[68GaH3]
+[65Zn+2]
+[89Zr]
+[BiH2+2]
+[62Cu]
+[165Dy+3]
+[238U]
+[105Rh+3]
+[70Zn]
+[12B]
+[12OH]
+[18CH]
+[17CH]
+[OsH3]
+[SbH-]
+[SH6]
+[AlH2-2]
+[42K]
+[76Br-]
+[71As]
+[NbH3]
+[ReH3]
+[OsH-]
+[WH4]
+[MoH3]
+[OsH4]
+[RuH6]
+[PtH3]
+[CuH2]
+[CoH3]
+[TiH4]
+[64Zn+2]
+[Si-2]
+[79BrH]
+[14CH2-]
+[PtH2+2]
+[Os-3]
+[29Si]
+[Ti-]
+[Se+6]
+[22Na+]
+[42K+]
+[131Cs+]
+[86Rb+]
+[134Cs+]
+[209Po]
+[208Po]
+[81Rb+]
+[203Tl+]
+[Zr-4]
+[148Sm]
+[147Sm]
+[37Cl-]
+[12CH4]
+[Ge@@H]
+[63Cu]
+[13CH2+]
+[AsH2-]
+[CeH]
+[SnH-]
+[UH]
+[9c]
+[21CH3]
+[TeH+]
+[57Co+3]
+[8BH2]
+[12BH2]
+[19BH2]
+[9BH2]
+[YbH2]
+[CrH+2]
+[208Bi]
+[152Gd]
+[61Cu]
+[115In]
+[60Co+2]
+[13NH2-]
+[120I]
+[18OH2]
+[75SeH]
+[SbH2+]
+[144Ce]
+[16n]
+[113In]
+[22nH]
+[129I-]
+[InH3]
+[32PH3]
+[234U]
+[235U]
+[59Fe]
+[82Rb+]
+[65Zn]
+[244Cm]
+[147Pm]
+[91Y]
+[237Pu]
+[231Pa]
+[253Cf]
+[127Te]
+[187Re]
+[236Np]
+[235Np]
+[72Zn]
+[253Es]
+[159Dy]
+[62Zn]
+[101Tc]
+[149Tb]
+[124I-]
+[SeH3+]
+[210Pb]
+[40K]
+[210Po]
+[214Pb]
+[218Po]
+[214Po]
+[7Be]
+[212Pb]
+[205Pb]
+[209Pb]
+[123Te]
+[202Pb]
+[72As]
+[201Pb]
+[70As]
+[73Ge]
+[200Pb]
+[198Pb]
+[66Ga]
+[73Se]
+[195Pb]
+[199Pb]
+[144Ce+3]
+[235U+2]
+[90Tc]
+[114In+3]
+[128I]
+[100Tc+]
+[82Br-]
+[191Pt+2]
+[191Pt+4]
+[193Pt+4]
+[31PH3]
+[125I+2]
+[131I+2]
+[125Te+4]
+[82Sr+2]
+[149Sm]
+[81BrH]
+[129Xe]
+[193Pt+2]
+[123I+2]
+[Cr-]
+[Co-]
+[227Th+4]
+[249Cf+3]
+[252Cf+3]
+[187Os]
+[16O-]
+[17O+]
+[16OH-]
+[98Tc+7]
+[58Co+2]
+[69Ga+3]
+[57Fe+2]
+[43K+]
+[16C]
+[52Fe+3]
+[SeH5]
+[194Pb]
+[196Pb]
+[197Pb]
+[213Pb]
+[9B]
+[19B]
+[11CH-]
+[9CH]
+[20OH]
+[25OH]
+[8cH]
+[TiH+3]
+[SnH6+3]
+[N@H+]
+[ZnH]
+[VH3]
+[52Mn+2]
+[64Ga]
+[13B]
+[216Bi]
+[117Sn+2]
+[232Th]
+[SnH+2]
+[BiH5]
+[77Kr]
+[103Cd]
+[62Ni]
+[LaH3]
+[SmH3]
+[EuH3]
+[MoH5]
+[64Ni]
+[66Zn]
+[68Zn]
+[186W]
+[FeH4]
+[MoH4]
+[HgH2]
+[15NH2-]
+[UH2]
+[204Hg]
+[GaH4-]
+[ThH4]
+[WH6]
+[PtH4]
+[VH2]
+[UH3]
+[FeH3]
+[RuH5]
+[BiH4]
+[80Br-]
+[CeH3]
+[37ClH]
+[157Gd+3]
+[205Tl]
+[203Tl]
+[62Cu+]
+[64Cu+]
+[61Cu+]
+[37SH2]
+[30Si]
+[28Al]
+[19OH2]
+[8He]
+[6He]
+[153Pm]
+[209Bi]
+[66Zn+2]
+[10CH4]
+[191Ir]
+[66Cu]
+[16O+]
+[25O]
+[10c]
+[Co-3]
+[Sn@@]
+[17OH-]
+[206Po]
+[204Po]
+[202Po]
+[201Po]
+[200Po]
+[199Po]
+[198Po]
+[197Po]
+[196Po]
+[195Po]
+[194Po]
+[193Po]
+[192Po]
+[191Po]
+[190Po]
+[217Po]
+[BiH4-]
+[TeH4]
+[222Ra]
+[62Ga]
+[39Ar]
+[144Sm]
+[58Fe]
+[153Eu]
+[85Rb]
+[171Yb]
+[172Yb]
+[114Cd]
+[51Fe]
+[142Ce]
+[207Tl]
+[92Mo]
+[115Sn]
+[140Ce]
+[202Hg]
+[180W]
+[182W]
+[183W]
+[184W]
+[96Mo]
+[47Ti]
+[111Cd]
+[143Nd]
+[145Nd]
+[126Te]
+[128Te]
+[130Te]
+[185Re]
+[97Mo]
+[98Mo]
+[183Re]
+[52V]
+[80Se]
+[87Kr]
+[137Xe]
+[196Au]
+[146Ce]
+[88Kr]
+[51Ti]
+[138Xe]
+[112Cd]
+[116Sn]
+[120Sn]
+[28SiH3]
+[35S-]
+[15NH-]
+[13CH3+]
+[34S+]
+[34s]
+[SiH4-]
+[100Tc+5]
+[NiH2+2]
+[239Th]
+[186Lu]
+[AuH3]
+[I@@-]
+[XeH2]
+[B+]
+[16CH2]
+[8C]
+[TaH5]
+[FeH4-]
+[19C@H]
+[10NH]
+[FeH6-3]
+[22CH]
+[25N]
+[25N+]
+[25N-]
+[21CH2]
+[18cH]
+[113I]
+[ScH3]
+[30PH3]
+[43Ca+2]
+[41Ca+2]
+[106Cd]
+[122Sn]
+[18CH3]
+[58Co+3]
+[98Tc+4]
+[70Ge]
+[76Ge]
+[108Cd]
+[116Cd]
+[130Xe]
+[94Mo]
+[124Sn]
+[186Os]
+[188Os]
+[190Os]
+[192Os]
+[106Pd]
+[110Pd]
+[120Te]
+[132Ba]
+[134Ba]
+[136Ba]
+[136Ce]
+[138Ce]
+[156Dy]
+[158Dy]
+[160Dy]
+[163Dy]
+[162Er]
+[164Er]
+[167Er]
+[176Hf]
+[26Mg]
+[144Nd]
+[150Nd]
+[41K]
+[46Ti]
+[48Ti]
+[49Ti]
+[50Ti]
+[170Yb]
+[173Yb]
+[91Zr]
+[92Zr]
+[96Zr]
+[34S-]
+[CuH2-]
+[38Cl]
+[25Mg]
+[51V]
+[93Nb]
+[95Mo]
+[45Sc]
+[123Sb]
+[139La]
+[9Be]
+[99Y+3]
+[99Y]
+[156Ho]
+[67Zn]
+[144Ce+4]
+[210Tl]
+[42Ca]
+[54Fe]
+[193Ir]
+[92Nb]
+[141Cs]
+[52Cr]
+[35ClH]
+[46Ca]
+[139Cs]
+[65Cu]
+[71Ga]
+[60Ni]
+[16NH3]
+[148Nd]
+[72Ge]
+[161Dy]
+[49Ca]
+[43Ca]
+[8Be]
+[48Ca]
+[44Ca]
+[120Xe]
+[80Rb]
+[215At]
+[180Re]
+[146Sm]
+[19Ne]
+[74Kr]
+[134La]
+[76Kr]
+[219Fr]
+[121Xe]
+[220Fr]
+[216At]
+[223Ac]
+[218At]
+[37Ar]
+[135I]
+[110Cd]
+[94Tc+7]
+[86Y+3]
+[135I-]
+[15O-2]
+[151Eu+3]
+[161Tb+3]
+[197Hg+2]
+[109Cd+2]
+[191Os+4]
+[170Tm+3]
+[205Bi+3]
+[233U+4]
+[126Sb+3]
+[127Sb+3]
+[132Cs+]
+[136Eu+3]
+[136Eu]
+[125Sn+4]
+[175Yb+3]
+[100Mo]
+[22Ne]
+[13c-]
+[13NH4+]
+[17C]
+[9C]
+[31S]
+[31SH]
+[133I]
+[126I]
+[36SH]
+[30S]
+[32SH]
+[19CH2]
+[19c]
+[18c]
+[15F]
+[10C]
+[RuH-]
+[62Zn+2]
+[32ClH]
+[33ClH]
+[78BrH]
+[12Li+]
+[12Li]
+[233Ra]
+[68Ge+4]
+[44Sc+3]
+[91Y+3]
+[106Ru+3]
+[PoH2]
+[AtH]
+[55Fe]
+[233U]
+[210PoH2]
+[230Th]
+[228Th]
+[222Rn]
+[35SH2]
+[227Th]
+[192Ir]
+[133Xe]
+[81Kr]
+[95Zr]
+[240Pu]
+[54Mn]
+[103Ru]
+[95Nb]
+[109Cd]
+[141Ce]
+[85Kr]
+[110Ag]
+[58Co]
+[241Pu]
+[234Th]
+[140La]
+[63Ni]
+[152Eu]
+[132IH]
+[226Rn]
+[154Eu]
+[36ClH]
+[228Ac]
+[155Eu]
+[106Rh]
+[243Am]
+[227Ac]
+[243Cm]
+[236U]
+[144Pr]
+[232U]
+[32SH2]
+[88Y]
+[82BrH]
+[135IH]
+[242Cm]
+[115Cd]
+[242Pu]
+[46Sc]
+[56Mn]
+[234Pa]
+[41Ar]
+[147Nd]
+[187W]
+[151Sm]
+[59Ni]
+[233Pa]
+[52Mn]
+[94Nb]
+[219Rn]
+[236Pu]
+[13NH3]
+[93Zr]
+[51Cr+6]
+[TlH3]
+[123Xe]
+[160Tb]
+[170Tm]
+[182Ta]
+[175Yb]
+[93Mo]
+[143Ce]
+[191Os]
+[126IH]
+[48V]
+[113Cd]
+[47Sc]
+[181Hf]
+[185W]
+[143Pr]
+[191Pt]
+[181W]
+[33PH3]
+[97Ru]
+[97Tc]
+[111Ag]
+[169Er]
+[107Pd]
+[103Ru+2]
+[34SH2]
+[137Ce]
+[242Am]
+[117SnH2]
+[57Ni]
+[239U]
+[60Cu]
+[250Cf]
+[193Au]
+[69Zn]
+[55Co]
+[139Ce]
+[127Xe]
+[159Gd]
+[56Co]
+[177Hf]
+[244Pu]
+[38ClH]
+[142Pr]
+[199Hg]
+[179Hf]
+[178Hf]
+[237U]
+[156Eu]
+[157Eu]
+[105Ru]
+[171Tm]
+[199Au]
+[155Sm]
+[80BrH]
+[108Ag]
+[128IH]
+[48Sc]
+[45Ti]
+[176Lu]
+[121SnH2]
+[148Pm]
+[57Fe]
+[10BH3]
+[96Tc]
+[133IH]
+[143Pm]
+[105Rh]
+[130IH]
+[134IH]
+[131IH]
+[71Zn]
+[105Ag]
+[97Zr]
+[235Pu]
+[231Th]
+[109Pd]
+[93Y]
+[190Ir]
+[135Xe]
+[53Mn]
+[134Ce]
+[234Np]
+[240Am]
+[246Cf]
+[240Cm]
+[241Cm]
+[226Th]
+[39ClH]
+[229Th]
+[245Cm]
+[240U]
+[240Np]
+[249Cm]
+[243Pu]
+[145Pm]
+[199Pt]
+[246Bk]
+[193Pt]
+[230U]
+[250Cm]
+[44Ti]
+[175Hf]
+[254Fm]
+[255Fm]
+[257Fm]
+[92Y]
+[188Ir]
+[171Lu]
+[257Md]
+[247Bk]
+[121IH]
+[250Bk]
+[179Lu]
+[224Ac]
+[195Hg]
+[244Am]
+[246Pu]
+[194Au]
+[252Fm]
+[173Hf]
+[246Cm]
+[135Ce]
+[49Cr]
+[248Cf]
+[247Cm]
+[248Cm]
+[174Ta]
+[176Ta]
+[154Tb]
+[172Ta]
+[177Ta]
+[175Ta]
+[180Ta]
+[158Tb]
+[115Ag]
+[189Os]
+[251Cf]
+[145Pr]
+[147Pr]
+[76BrH]
+[102Rh]
+[238Np]
+[185Os]
+[246Am]
+[233Np]
+[166Dy]
+[254Es]
+[244Cf]
+[193Os]
+[245Am]
+[245Bk]
+[239Am]
+[238Am]
+[97Nb]
+[245Pu]
+[254Cf]
+[188W]
+[250Es]
+[251Es]
+[237Am]
+[182Hf]
+[258Md]
+[232Np]
+[238Cm]
+[60Fe]
+[109Pd+2]
+[234Pu]
+[141Ce+3]
+[136Nd]
+[136Pr]
+[173Ta]
+[110Ru]
+[147Tb]
+[253Fm]
+[139Nd]
+[178Re]
+[177Re]
+[200Au]
+[182Re]
+[156Tb]
+[155Tb]
+[157Tb]
+[161Tb]
+[161Ho]
+[167Tm]
+[173Lu]
+[179Ta]
+[171Er]
+[44Sc]
+[49Sc]
+[49V]
+[51Mn]
+[90Nb]
+[88Nb]
+[88Zr]
+[36SH2]
+[174Yb]
+[178Lu]
+[179W]
+[83BrH]
+[107Cd]
+[75BrH]
+[62Co]
+[48Cr]
+[63Zn]
+[102Ag]
+[154Sm]
+[168Er]
+[65Ni]
+[137La]
+[187Ir]
+[144Pm]
+[146Pm]
+[160Gd]
+[166Yb]
+[162Dy]
+[47V]
+[141Nd]
+[141Sm]
+[166Er]
+[150Sm]
+[146Eu]
+[149Eu]
+[174Lu]
+[17NH3]
+[102Ru]
+[170Hf]
+[188Pt]
+[61Ni]
+[56Ni]
+[149Gd]
+[151Gd]
+[141Pm]
+[147Gd]
+[146Gd]
+[161Er]
+[103Ag]
+[145Eu]
+[153Tb]
+[155Dy]
+[184Re]
+[180Os]
+[182Os]
+[186Pt]
+[181Os]
+[181Re]
+[151Tb]
+[178Ta]
+[178W]
+[189Pt]
+[194Hg]
+[145Sm]
+[150Tb]
+[132La]
+[158Gd]
+[104Ag]
+[193Hg]
+[94Ru]
+[137Pr]
+[155Ho]
+[117Cd]
+[99Ru]
+[146Nd]
+[218Rn]
+[95Y]
+[79Kr]
+[120IH]
+[138Pr]
+[100Pd]
+[166Tm]
+[90Mo]
+[151Nd]
+[231U]
+[138Nd]
+[89Nb]
+[98Nb]
+[162Ho]
+[142Sm]
+[186Ta]
+[104Tc]
+[184Ta]
+[185Ta]
+[170Er]
+[107Rh]
+[131La]
+[169Lu]
+[74BrH]
+[150Pm]
+[172Tm]
+[197Pt]
+[230Pu]
+[170Lu]
+[86Zr]
+[176W]
+[177W]
+[101Pd]
+[105Pd]
+[108Pd]
+[149Nd]
+[164Ho]
+[159Ho]
+[167Ho]
+[176Yb]
+[156Sm]
+[77BrH]
+[189Re]
+[99Rh]
+[100Rh]
+[151Pm]
+[232Pa]
+[228Pa]
+[230Pa]
+[66Ni]
+[194Os]
+[135La]
+[138La]
+[141La]
+[142La]
+[195Ir]
+[96Nb]
+[157Ho]
+[183Hf]
+[162Tm]
+[172Er]
+[148Eu]
+[150Eu]
+[15CH4]
+[89Kr]
+[143La]
+[58Ni]
+[61Co]
+[158Eu]
+[165Er]
+[167Yb]
+[173Tm]
+[175Tm]
+[172Hf]
+[172Lu]
+[93Tc]
+[177Yb]
+[124IH]
+[194Ir]
+[147Eu]
+[101Mo]
+[180Hf]
+[189Ir]
+[87Y]
+[43Sc]
+[195Au]
+[112Ag]
+[84BrH]
+[106Ag]
+[109Ag]
+[101Rh]
+[162Yb]
+[228Rn]
+[139Pr]
+[94Y]
+[201Au]
+[40PH3]
+[110Ag+]
+[104Cd]
+[133Ba+2]
+[226Ac]
+[145Gd]
+[186Ir]
+[184Ir]
+[224Rn]
+[185Ir]
+[182Ir]
+[184Hf]
+[200Pt]
+[227Pa]
+[178Yb]
+[72Br-]
+[72BrH]
+[248Am]
+[238Th]
+[161Gd]
+[35S-2]
+[107Ag]
+[FeH6-4]
+[89Sr]
+[SnH3-]
+[SeH3]
+[TeH3+]
+[SbH4+]
+[AsH4+]
+[4He]
+[AsH3-]
+[1HH]
+[3H+]
+[82Rb]
+[85Sr]
+[90Sr]
+[137Cs]
+[133Ba]
+[131Cs]
+[SbH5]
+[224Ra]
+[22Na]
+[210Bi]
+[214Bi]
+[228Ra]
+[127Sb]
+[136Cs]
+[125Sb]
+[134Cs]
+[140Ba]
+[45Ca]
+[206Pb]
+[207Pb]
+[24Na]
+[86Rb]
+[212Bi]
+[208Pb]
+[124Sb]
+[204Pb]
+[44K]
+[129Te]
+[113Sn]
+[204Tl]
+[87Sr]
+[208Tl]
+[87Rb]
+[47Ca]
+[135Cs]
+[216Po]
+[137Ba]
+[207Bi]
+[212Po]
+[79Se]
+[223Ra]
+[86Sr]
+[122Sb]
+[26Al]
+[32Si]
+[126Sn]
+[225Ra]
+[114In]
+[72Ga]
+[132Te]
+[10Be]
+[125Sn]
+[73As]
+[206Bi]
+[117Sn]
+[40Ca]
+[41Ca]
+[89Rb]
+[116In]
+[129Sb]
+[91Sr]
+[71Ge]
+[139Ba]
+[69Ga]
+[120Sb]
+[121Sn]
+[123Sn]
+[131Te]
+[77Ge]
+[135Ba]
+[82Sr]
+[43K]
+[131Ba]
+[92Sr]
+[88Rb]
+[129Cs]
+[144Cs]
+[127Cs]
+[200Tl]
+[202Tl]
+[141Ba]
+[117Sb]
+[116Sb]
+[78As]
+[131Sb]
+[126Sb]
+[128Sb]
+[130Sb]
+[67Ge]
+[68Ge]
+[78Ge]
+[66Ge]
+[223Fr]
+[132Cs]
+[125Cs]
+[138Cs]
+[133Te]
+[84Rb]
+[83Rb]
+[81Rb]
+[142Ba]
+[200Bi]
+[115Sb]
+[194Tl]
+[70Se]
+[112In]
+[118Sb]
+[70Ga]
+[27Mg]
+[202Bi]
+[83Se]
+[9Li]
+[69As]
+[79Rb]
+[81Sr]
+[83Sr]
+[78Se]
+[109In]
+[29Al]
+[118Sn]
+[117In]
+[119Sb]
+[114Sn]
+[138Ba]
+[69Ge]
+[73Ga]
+[74Ge]
+[206Tl]
+[199Tl]
+[130Cs]
+[28Mg]
+[116Te]
+[112Sn]
+[126Ba]
+[211Bi]
+[81Se]
+[127Sn]
+[143Cs]
+[134Te]
+[80Sr]
+[45K]
+[215Po]
+[207Po]
+[111Sn]
+[211Po]
+[128Ba]
+[198Tl]
+[227Ra]
+[213Po]
+[220Ra]
+[128Sn]
+[203Po]
+[205Po]
+[65Ga]
+[197Tl]
+[88Sr]
+[110In]
+[31Si]
+[201Bi]
+[121Te]
+[205Bi]
+[203Bi]
+[195Tl]
+[209Tl]
+[110Sn]
+[222Fr]
+[207At]
+[119In]
+[As@]
+[129IH]
+[157Dy]
+[111IH]
+[230Ra]
+[144Pr+3]
+[SiH3+]
+[3He]
+[AsH5]
+[72Se]
+[95Tc]
+[103Pd]
+[121Sn+2]
+[211Rn]
+[38SH2]
+[127IH]
+[74Br-]
+[133I-]
+[100Tc+4]
+[100Tc]
+[36Cl-]
+[89Y+3]
+[104Rh]
+[152Sm]
+[226Ra]
+[19FH]
+[104Pd]
+[148Gd]
+[157Lu]
+[33SH2]
+[121I-]
+[17FH]
+[71Se]
+[157Sm]
+[148Tb]
+[164Dy]
+[15OH2]
+[15O+]
+[39K]
+[40Ar]
+[50Cr+3]
+[50Cr]
+[52Ti]
+[103Pd+2]
+[130Ba]
+[142Pm]
+[153Gd+3]
+[151Eu]
+[103Rh]
+[124Xe]
+[152Tb]
+[17OH2]
+[20Ne]
+[52Fe]
+[94Zr+4]
+[94Zr]
+[149Pr]
+[16OH2]
+[53Cr+6]
+[53Cr]
+[81Br-]
+[112Pd]
+[125Xe]
+[155Gd]
+[157Gd]
+[168Yb]
+[184Os]
+[166Tb]
+[221Fr]
+[212Ra]
+[75Br-]
+[79Br-]
+[113Ag]
+[23Na]
+[34Cl-]
+[34ClH]
+[38Cl-]
+[56Fe]
+[68Cu]
+[77Br-]
+[90Zr+4]
+[90Zr]
+[102Pd]
+[154Eu+3]
+[57Mn]
+[165Tm]
+[152Dy]
+[217At]
+[77se]
+[13cH-]
+[122Te]
+[156Gd]
+[124Te]
+[53Ni]
+[131Xe]
+[174Hf+4]
+[174Hf]
+[76Se]
+[168Tm]
+[167Dy]
+[154Gd]
+[95Ru]
+[210At]
+[85Br]
+[59Co]
+[122Xe]
+[27Al]
+[54Cr]
+[198Hg]
+[85Rb+]
+[214Tl]
+[229Rn]
+[218Pb]
+[218Bi]
+[167Tm+3]
+[18o+]
+[P@@H+]
+[P@H+]
+[13N+]
+[212Pb+2]
+[217Bi]
+[249Cf+2]
+[18OH3+]
+[90Sr-]
+[Cf+3]
+[200Hg]
+[86Tc]
+[141Pr+3]
+[141Pr]
+[16nH]
+[14NH4+]
+[132Xe]
+[83Kr]
+[70Zn+2]
+[137Ba+2]
+[36Ar]
+[38Ar]
+[21Ne]
+[126Xe]
+[136Xe]
+[128Xe]
+[134Xe]
+[84Kr]
+[86Kr]
+[78Kr]
+[80Kr]
+[82Kr]
+[67Zn+2]
+[65Cu+2]
+[110Te]
+[58Fe+3]
+[142Nd]
+[38K]
+[198Au+3]
+[122IH]
+[38PH3]
+[130I-]
+[40K+]
+[38K+]
+[28Mg+2]
+[208Tl+]
+[13OH2]
+[198Bi]
+[192Bi]
+[194Bi]
+[196Bi]
+[132I-]
+[83Sr+2]
+[169Er+3]
+[122I-]
+[120I-]
+[92Sr+2]
+[126I-]
+[24Mg]
+[84Sr]
+[118Pd+2]
+[118Pd]
+[AsH4]
+[127I-]
+[9C-]
+[11CH3+]
+[17B]
+[7B]
+[4HH]
+[18C-]
+[22CH3-]
+[22CH4]
+[17C-]
+[15CH3]
+[16CH3]
+[11NH3]
+[21NH3]
+[11N-]
+[11NH]
+[16CH]
+[17CH2]
+[99Ru+2]
+[181Ta+2]
+[181Ta]
+[20CH]
+[32PH2]
+[55Fe+2]
+[SH3]
+[S@H]
+[Mn-]
+[IH4]
+[ThH]
+[GaH-]
+[BiH+]
+[EuH2]
+[FeH4-3]
+[FeH6]
+[IH5]
+[NiH+]
+[SrH2]
+[VH4]
+[YH3]
+[seH+]
+
diff --git a/models/smi_ted/inference/smi_ted_light/load.py b/models/smi_ted/inference/smi_ted_light/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..54d8422b6e54ca164f42c8ea6bc9e84b8c3e3103
--- /dev/null
+++ b/models/smi_ted/inference/smi_ted_light/load.py
@@ -0,0 +1,672 @@
+PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+# Transformers
+from fast_transformers.attention import AttentionLayer
+from fast_transformers.events import QKVEvent
+from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer
+from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder
+from fast_transformers.builders.attention_builders import AttentionBuilder
+from fast_transformers.feature_maps import GeneralizedRandomFeatures
+from fast_transformers.masking import LengthMask
+from transformers import BertTokenizer
+
+# Data
+import numpy as np
+import pandas as pd
+
+# Chemistry
+from rdkit import Chem
+from rdkit.Chem import PandasTools
+from rdkit.Chem import Descriptors
+PandasTools.RenderImagesInAllDataFrames(True)
+
+# Standard library
+from functools import partial
+import regex as re
+import random
+import os
+import gc
+from tqdm import tqdm
+tqdm.pandas()
+
+
+# function to canonicalize SMILES
+def normalize_smiles(smi, canonical=True, isomeric=False):
+ try:
+ normalized = Chem.MolToSmiles(
+ Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
+ )
+ except:
+ normalized = None
+ return normalized
+
+
+class MolTranBertTokenizer(BertTokenizer):
+ def __init__(self, vocab_file: str = '',
+ do_lower_case=False,
+ unk_token='',
+ sep_token='',
+ pad_token='',
+ cls_token='',
+ mask_token='',
+ **kwargs):
+ super().__init__(vocab_file,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs)
+
+ self.regex_tokenizer = re.compile(PATTERN)
+ self.wordpiece_tokenizer = None
+ self.basic_tokenizer = None
+ with open(vocab_file) as f:
+ self.padding_idx = f.readlines().index(pad_token+'\n')
+
+ def _tokenize(self, text):
+ split_tokens = self.regex_tokenizer.findall(text)
+ return split_tokens
+
+ def convert_idx_to_tokens(self, idx_tensor):
+ tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
+ return tokens
+
+ def convert_tokens_to_string(self, tokens):
+ stopwords = ['', '']
+ clean_tokens = [word for word in tokens if word not in stopwords]
+ out_string = ''.join(clean_tokens)
+ return out_string
+
+ def get_padding_idx(self):
+ return self.padding_idx
+
+ def idx_to_smiles(self, torch_model, idx):
+ '''Convert tokens idx back to SMILES text'''
+ rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
+ flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
+ decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)
+ return decoded_smiles
+
+
+## Transformer layers
+class RotaryEmbedding(torch.nn.Module):
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = 0
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self.cos_cached = emb.cos()[None,:, None, :]
+ self.sin_cached = emb.sin()[None,:, None, :]
+
+ return self.cos_cached, self.sin_cached
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+@torch.jit.script
+def apply_rotary_pos_emb(q, k, cos, sin):
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+class RotateAttentionLayer(AttentionLayer):
+ """Rotate attention layer inherits from fast_transformer attention layer.
+ The only thing added is an Embedding encoding, for more information
+ on the attention layer see the fast_transformers code
+ """
+ def __init__(self, attention, d_model, n_heads, d_keys=None,
+ d_values=None, event_dispatcher=""):
+ super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys,
+ d_values=d_values, event_dispatcher=event_dispatcher)
+
+ self.rotaryemb = RotaryEmbedding(d_keys)
+ print('Using Rotation Embedding')
+
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
+ key_lengths):
+ """
+ Using the same frame work as the fast_Transformers attention layer
+ but injecting rotary information to the queries and the keys
+ after the keys and queries are projected.
+ In the argument description we make use of the following sizes
+ - N: the batch size
+ - L: The maximum length of the queries
+ - S: The maximum length of the keys (the actual length per sequence
+ is given by the length mask)
+ - D: The input feature dimensionality passed in the constructor as
+ 'd_model'
+ Arguments
+ ---------
+ queries: (N, L, D) The tensor containing the queries
+ keys: (N, S, D) The tensor containing the keys
+ values: (N, S, D) The tensor containing the values
+ attn_mask: An implementation of BaseMask that encodes where each
+ query can attend to
+ query_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ key_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ Returns
+ -------
+ The new value for each query as a tensor of shape (N, L, D).
+ """
+ # Extract the dimensions into local variables
+ N, L, _ = queries.shape
+ _, S, _ = keys.shape
+ H = self.n_heads
+
+ # Project the queries/keys/values
+ queries = self.query_projection(queries).view(N, L, H, -1)
+ keys = self.key_projection(keys).view(N, S, H, -1)
+ cos, sin = self.rotaryemb(queries)
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+ values = self.value_projection(values).view(N, S, H, -1)
+ # Let the world know of the qkv
+ self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+
+ # Compute the attention
+ new_values = self.inner_attention(
+ queries,
+ keys,
+ values,
+ attn_mask,
+ query_lengths,
+ key_lengths
+ ).view(N, L, -1)
+
+ # Project the output and return
+ return self.out_projection(new_values)
+
+class RotateEncoderBuilder(BaseTransformerEncoderBuilder):
+ """Build a batch transformer encoder with Relative Rotary embeddings
+ for training or processing of sequences all elements at a time.
+ Example usage:
+ builder = RotateEncoderBuilder()
+ builder.n_layers = 12
+ builder.n_heads = 8
+ builder.feed_forward_dimensions = 1024
+ builder.query_dimensions = 64
+ builder.value_dimensions = 64
+ builder.dropout = 0.1
+ builder.attention_dropout = 0.1
+ builder.attention_type = "linear"
+ transformer = builder.get()
+ """
+ def _get_attention_builder(self):
+ """Return an instance of the appropriate attention builder."""
+ return AttentionBuilder()
+
+ def _get_attention_layer_class(self):
+ """Return the class for the layer that projects queries keys and
+ values."""
+ return RotateAttentionLayer
+
+ def _get_encoder_class(self):
+ """Return the class for the transformer encoder."""
+ return TransformerEncoder
+
+ def _get_encoder_layer_class(self):
+ """Return the class for the transformer encoder layer."""
+ return TransformerEncoderLayer
+
+
+class AutoEncoderLayer(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.encoder = self.Encoder(feature_size, latent_size)
+ self.decoder = self.Decoder(feature_size, latent_size)
+
+ class Encoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(feature_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.lat = nn.Linear(latent_size, latent_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.lat.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.lat(x)
+ return x # -> (N, D)
+
+ class Decoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(latent_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.rec = nn.Linear(latent_size, feature_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.rec.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.rec(x)
+ return x # -> (N, L*D)
+
+
+class LangLayer(nn.Module):
+
+ def __init__(self, n_embd, n_vocab):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.embed = nn.Linear(n_embd, n_embd)
+ self.ln_f = nn.LayerNorm(n_embd)
+ self.head = nn.Linear(n_embd, n_vocab, bias=False)
+
+ def forward(self, tensor):
+ if self.is_cuda_available:
+ self.embed.cuda()
+ self.ln_f.cuda()
+ self.head.cuda()
+ tensor = tensor.cuda()
+ tensor = self.embed(tensor)
+ tensor = F.gelu(tensor)
+ tensor = self.ln_f(tensor)
+ tensor = self.head(tensor)
+ return tensor
+
+
+class Net(nn.Module):
+
+ def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2):
+ super().__init__()
+ self.desc_skip_connection = True
+ self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.relu1 = nn.GELU()
+ self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
+ self.dropout2 = nn.Dropout(dropout)
+ self.relu2 = nn.GELU()
+ self.final = nn.Linear(smiles_embed_dim, n_output)
+
+ def forward(self, smiles_emb, multitask=False):
+ x_out = self.fc1(smiles_emb)
+ x_out = self.dropout1(x_out)
+ x_out = self.relu1(x_out)
+
+ if self.desc_skip_connection is True:
+ x_out = x_out + smiles_emb
+
+ z = self.fc2(x_out)
+ z = self.dropout2(z)
+ z = self.relu2(z)
+ if self.desc_skip_connection is True:
+ z = self.final(z + x_out)
+ else:
+ z = self.final(z)
+
+ if multitask:
+ return F.sigmoid(z)
+ return z
+
+
+class MoLEncoder(nn.Module):
+
+ def __init__(self, config, n_vocab):
+ super(MoLEncoder, self).__init__()
+
+ # embeddings
+ self.config = config
+ self.tok_emb = nn.Embedding(n_vocab, config['n_embd'])
+ self.drop = nn.Dropout(config['d_dropout'])
+
+ # transformer
+ builder = RotateEncoderBuilder.from_kwargs(
+ n_layers=config['n_layer'],
+ n_heads=config['n_head'],
+ query_dimensions=config['n_embd']//config['n_head'],
+ value_dimensions=config['n_embd']//config['n_head'],
+ feed_forward_dimensions=config['n_embd'],
+ attention_type='linear',
+ # unless we do deterministic_eval here, we will have random outputs
+ feature_map=partial(GeneralizedRandomFeatures,
+ n_dims=config['num_feats'],
+ deterministic_eval=True),
+ activation='gelu'
+ )
+ self.blocks = builder.get()
+
+ # classification
+ self.lang_model = LangLayer(config['n_embd'], n_vocab)
+
+ def forward(self, idx, mask):
+ # transformer encoder
+ x = self.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.drop(x)
+ x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1]))
+
+ # add padding
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ mask_embeddings = (token_embeddings * input_mask_expanded)
+ token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.config['max_len'] - mask_embeddings.shape[1]), value=0)
+
+ return token_embeddings
+
+
+class MoLDecoder(nn.Module):
+
+ def __init__(self, n_vocab, max_len, n_embd, n_gpu=None):
+ super(MoLDecoder, self).__init__()
+
+ self.max_len = max_len
+ self.n_embd = n_embd
+ self.n_gpu = n_gpu
+ self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd)
+ self.lang_model = LangLayer(n_embd, n_vocab)
+
+
+class Smi_ted(nn.Module):
+ """materials.smi-ted-Light 289M Parameters"""
+
+ def __init__(self, tokenizer, config=None):
+ super(Smi_ted, self).__init__()
+
+ # configuration
+ self.config = config
+ self.tokenizer = tokenizer
+ self.padding_idx = tokenizer.get_padding_idx()
+ self.n_vocab = len(self.tokenizer.vocab)
+ self.is_cuda_available = torch.cuda.is_available()
+
+ # instantiate modules
+ if self.config:
+ self.encoder = MoLEncoder(self.config, self.n_vocab)
+ self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
+ self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['d_dropout'])
+
+ def load_checkpoint(self, ckpt_path):
+ # load checkpoint file
+ checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+
+ # load hyparameters
+ self.config = checkpoint['hparams']
+ self.max_len = self.config['max_len']
+ self.n_embd = self.config['n_embd']
+ self._set_seed(self.config['seed'])
+
+ # instantiate modules
+ self.encoder = MoLEncoder(self.config, self.n_vocab)
+ self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
+ self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else 1, dropout=self.config['d_dropout'])
+
+ # load weights
+ if 'state_dict' in checkpoint:
+ if isinstance(checkpoint['state_dict'], list):
+ self.encoder.load_state_dict(checkpoint['state_dict'][0], strict=False)
+ self.decoder.load_state_dict(checkpoint['state_dict'][1], strict=False)
+ else:
+ self.load_state_dict(checkpoint['state_dict'], strict=False)
+ elif 'MODEL_STATE' in checkpoint:
+ self.load_state_dict(checkpoint['MODEL_STATE'], strict=False)
+
+ # load RNG states each time the model and states are loaded from checkpoint
+ if 'rng' in self.config:
+ rng = self.config['rng']
+ for key, value in rng.items():
+ if key =='torch_state':
+ torch.set_rng_state(value.cpu())
+ elif key =='cuda_state':
+ torch.cuda.set_rng_state(value.cpu())
+ elif key =='numpy_state':
+ np.random.set_state(value)
+ elif key =='python_state':
+ random.setstate(value)
+ else:
+ print('unrecognized state')
+
+ def _set_seed(self, value):
+ print('Random Seed:', value)
+ random.seed(value)
+ torch.manual_seed(value)
+ torch.cuda.manual_seed(value)
+ torch.cuda.manual_seed_all(value)
+ np.random.seed(value)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def forward(self, smiles, batch_size=100):
+ return self.decode(self.encode(smiles, batch_size=batch_size, return_torch=True))
+
+ def tokenize(self, smiles):
+ """Tokenize a string into tokens."""
+ if isinstance(smiles, str):
+ batch = [smiles]
+ else:
+ batch = smiles
+
+ tokens = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ max_length=self.max_len,
+ )
+
+ idx = tokens['input_ids'].clone().detach()
+ mask = tokens['attention_mask'].clone().detach()
+
+ if self.is_cuda_available:
+ return idx.cuda(), mask.cuda()
+
+ return idx, mask
+
+ def extract_all(self, smiles):
+ """Extract all elements from each part of smi-ted. Be careful."""
+ # evaluation mode
+ self.encoder.eval()
+ self.decoder.eval()
+ if self.is_cuda_available:
+ self.encoder.cuda()
+ self.decoder.cuda()
+
+ # handle single str or a list of str
+ smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+ # SMILES normalization
+ smiles = smiles.apply(normalize_smiles)
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
+ smiles = smiles.dropna()
+
+ # tokenizer
+ idx, mask = self.tokenize(smiles.to_list())
+
+ ###########
+ # Encoder #
+ ###########
+ # encoder forward
+ x = self.encoder.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.encoder.drop(x)
+ x = self.encoder.blocks(x, length_mask=LengthMask(mask.sum(-1)))
+
+ # mean pooling
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+ true_set = sum_embeddings / sum_mask # DO NOT USE THIS FOR DOWNSTREAM TASKS, USE `pred_set` INSTEAD
+
+ # add padding
+ mask_embeddings = (token_embeddings * input_mask_expanded)
+ token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.max_len - mask_embeddings.shape[1]), value=0)
+ idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=2)
+
+ true_ids = idx
+ true_cte = token_embeddings
+ true_cte = true_cte.view(-1, self.max_len*self.n_embd)
+
+ ###########
+ # Decoder #
+ ###########
+ # CTE autoencoder
+ pred_set = self.decoder.autoencoder.encoder(true_cte)
+ pred_cte = self.decoder.autoencoder.decoder(pred_set)
+
+ # reconstruct tokens
+ pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
+ pred_ids = torch.argmax(pred_ids, axis=-1)
+
+ # replacing null SMILES with NaN values
+ for idx in null_idx:
+ true_ids = true_ids.tolist()
+ pred_ids = pred_ids.tolist()
+ true_cte = true_cte.tolist()
+ pred_cte = pred_cte.tolist()
+ true_set = true_set.tolist()
+ pred_set = pred_set.tolist()
+
+ true_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+ pred_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+ true_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+ pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+ true_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+ pred_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+
+ if len(null_idx) > 0:
+ true_ids = torch.tensor(true_ids)
+ pred_ids = torch.tensor(pred_ids)
+ true_cte = torch.tensor(true_cte)
+ pred_cte = torch.tensor(pred_cte)
+ true_set = torch.tensor(true_set)
+ pred_set = torch.tensor(pred_set)
+
+ return ((true_ids, pred_ids), # tokens
+ (true_cte, pred_cte), # token embeddings
+ (true_set, pred_set)) # smiles embeddings
+
+ def extract_embeddings(self, smiles):
+ """Extract token and SMILES embeddings."""
+ # evaluation mode
+ self.encoder.eval()
+ if self.is_cuda_available:
+ self.encoder.cuda()
+
+ # tokenizer
+ idx, mask = self.tokenize(smiles)
+
+ # encoder forward
+ token_embeddings = self.encoder(idx, mask)
+
+ # aggregate token embeddings (similar to mean pooling)
+ # CAUTION: use the embeddings from the autoencoder.
+ smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len*self.n_embd))
+
+ # add padding
+ idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=self.padding_idx)
+
+ return idx, token_embeddings, smiles_embeddings
+
+ def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False):
+ """Extract efficiently SMILES embeddings per batches."""
+ # TODO: remove useCuda argument
+
+ # handle single str or a list of str
+ smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+ # SMILES normalization
+ smiles = smiles.apply(normalize_smiles)
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
+ smiles = smiles.dropna()
+
+ # process in batches
+ n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
+ embeddings = [
+ self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
+ for batch in tqdm(np.array_split(smiles, n_split))
+ ]
+ flat_list = [item for sublist in embeddings for item in sublist]
+
+ # clear GPU memory
+ if self.is_cuda_available:
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # replacing null SMILES with NaN values
+ for idx in null_idx:
+ flat_list.insert(idx, np.array([np.nan]*self.config['n_embd']))
+ flat_list = np.asarray(flat_list)
+
+ if return_torch:
+ return torch.tensor(flat_list)
+ return pd.DataFrame(flat_list)
+
+ def decode(self, smiles_embeddings):
+ """Decode SMILES embeddings back to SMILES."""
+ # evaluation mode
+ self.decoder.eval()
+ if self.is_cuda_available:
+ self.decoder.cuda()
+
+ # reconstruct token embeddings
+ pred_token_embds = self.decoder.autoencoder.decoder(smiles_embeddings)
+
+ # reconstruct tokens
+ pred_idx = self.decoder.lang_model(pred_token_embds.view(-1, self.max_len, self.n_embd))
+ pred_idx = torch.argmax(pred_idx, axis=-1).cpu().detach().numpy()
+
+ # convert idx to tokens
+ pred_smiles = []
+ for i in range(pred_idx.shape[0]):
+ idx = pred_idx[i]
+ smiles = self.tokenizer.idx_to_smiles(self, idx)
+ smiles = smiles.replace('', '') # begin token
+ smiles = smiles.replace('', '') # end token
+ smiles = smiles.replace('', '') # pad token
+ pred_smiles.append(smiles)
+
+ # clear GPU memory
+ if self.is_cuda_available:
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ return pred_smiles
+
+ def __str__(self):
+ return 'smi-ted-Light'
+
+
+def load_smi_ted(folder="./smi_ted_light",
+ ckpt_filename="smi-ted-Light_40.pt",
+ vocab_filename="bert_vocab_curated.txt"
+ ):
+ tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
+ model = Smi_ted(tokenizer)
+ model.load_checkpoint(os.path.join(folder, ckpt_filename))
+ model.eval()
+ print('Vocab size:', len(tokenizer.vocab))
+ print(f'[INFERENCE MODE - {str(model)}]')
+ return model
\ No newline at end of file
diff --git a/models/smi_ted/notebooks/data/moses_test.csv b/models/smi_ted/notebooks/data/moses_test.csv
new file mode 100644
index 0000000000000000000000000000000000000000..8160fe3ed057ad7524ce1191e389e08558a5e864
--- /dev/null
+++ b/models/smi_ted/notebooks/data/moses_test.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0248e682a9c29ca7649184dd88a4edc83f48f0e84af6b923990069c3da4501b6
+size 6490097
diff --git a/models/smi_ted/notebooks/smi_ted_encoder_decoder_example.ipynb b/models/smi_ted/notebooks/smi_ted_encoder_decoder_example.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..6313b3abfbc282d941f07fd01d6e65582d2585e5
--- /dev/null
+++ b/models/smi_ted/notebooks/smi_ted_encoder_decoder_example.ipynb
@@ -0,0 +1,334 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# granite.materials.smi-TED - Encoder & Decoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append('../inference')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# materials.smi-ted (smi-ted)\n",
+ "from smi_ted_light.load import load_smi_ted\n",
+ "\n",
+ "# Data\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "\n",
+ "# Chemistry\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem import PandasTools\n",
+ "from rdkit.Chem import Descriptors\n",
+ "from rdkit.Chem import AllChem\n",
+ "from rdkit.DataStructs import FingerprintSimilarity\n",
+ "from rdkit.DataStructs import TanimotoSimilarity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# function to canonicalize SMILES\n",
+ "def normalize_smiles(smi, canonical=True, isomeric=False):\n",
+ " try:\n",
+ " normalized = Chem.MolToSmiles(\n",
+ " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n",
+ " )\n",
+ " except:\n",
+ " normalized = None\n",
+ " return normalized\n",
+ "\n",
+ "# function to calculate pairwise Tanimoto similarity\n",
+ "def calculate_tanimoto_similarities(fps1, fps2):\n",
+ " similarities = []\n",
+ " for i in range(len(fps1)):\n",
+ " sim = TanimotoSimilarity(fps1[i], fps2[i])\n",
+ " similarities.append(sim)\n",
+ " return similarities"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load smi-ted"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Random Seed: 12345\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Vocab size: 2393\n",
+ "[INFERENCE MODE - smi-ted-Light]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_smi_ted = load_smi_ted(\n",
+ " folder='../inference/smi_ted_light',\n",
+ " ckpt_filename='smi-ted-Light_40.pt'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_moses = pd.read_csv(\"./data/moses_test.csv\", nrows=1000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1000, 1)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " SMILES \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " Clc1ccccc1-c1nc(-c2ccncc2)no1 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " SMILES\n",
+ "0 CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1\n",
+ "1 COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O\n",
+ "2 CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2\n",
+ "3 Clc1ccccc1-c1nc(-c2ccncc2)no1\n",
+ "4 CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_moses['SMILES'] = df_moses['SMILES'].apply(normalize_smiles)\n",
+ "df_test_normalized = df_moses.dropna()\n",
+ "print(df_test_normalized.shape)\n",
+ "df_test_normalized.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Encode SMILES - smi-ted"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 10/10 [00:06<00:00, 1.52it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " encode_embeddings = model_smi_ted.encode(df_moses['SMILES'], return_torch=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Decode smi-ted embeddings into SMILES"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " decoded_smiles = model_smi_ted.decode(encode_embeddings)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1',\n",
+ " 'COc1ccc(-c2cc(=O)c3c(O)c(OC)c(OC)cc3o2)cc1O',\n",
+ " 'CCOC(=O)c1ncn2c1CN(C)C(=O)c1cc(F)ccc1-2',\n",
+ " 'Clc1ccccc1-c1nc(-c2ccncc2)no1',\n",
+ " 'CC(C)(Oc1ccc(Cl)cc1)C(=O)OCc1cccc(CO)n1']"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "decoded_smiles[0:5]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Compare similarities"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mean Tanimoto Similarity: 1.00\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Convert SMILES to RDKit molecule objects\n",
+ "mols1 = [Chem.MolFromSmiles(smiles) for smiles in df_moses['SMILES'].to_list()]\n",
+ "mols2 = [Chem.MolFromSmiles(smiles) for smiles in decoded_smiles]\n",
+ "\n",
+ "# Compute fingerprints for each molecule\n",
+ "fps1 = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols1]\n",
+ "fps2 = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols2]\n",
+ "\n",
+ "# Calculate Tanimoto similarities\n",
+ "tanimoto_similarities = calculate_tanimoto_similarities(fps1, fps2)\n",
+ "\n",
+ "# Calculate the mean similarity\n",
+ "mean_similarity = np.mean(tanimoto_similarities)\n",
+ "\n",
+ "# Print the mean similarity\n",
+ "print(f\"Mean Tanimoto Similarity: {mean_similarity:.2f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/models/smi_ted/notebooks/smi_ted_frozen_inference_example1.ipynb b/models/smi_ted/notebooks/smi_ted_frozen_inference_example1.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..5ab26dd0c948661119caeef0ab6437384d74151a
--- /dev/null
+++ b/models/smi_ted/notebooks/smi_ted_frozen_inference_example1.ipynb
@@ -0,0 +1,1412 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# granite.materials.smi-TED - INFERENCE (Classification)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install extra packages for notebook\n",
+ "%pip install seaborn xgboost"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append('../inference')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# materials.smi-ted\n",
+ "from smi_ted_light.load import load_smi_ted\n",
+ "\n",
+ "# Data\n",
+ "import torch\n",
+ "import pandas as pd\n",
+ "\n",
+ "# Chemistry\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem import PandasTools\n",
+ "from rdkit.Chem import Descriptors\n",
+ "PandasTools.RenderImagesInAllDataFrames(True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# function to canonicalize SMILES\n",
+ "def normalize_smiles(smi, canonical=True, isomeric=False):\n",
+ " try:\n",
+ " normalized = Chem.MolToSmiles(\n",
+ " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n",
+ " )\n",
+ " except:\n",
+ " normalized = None\n",
+ " return normalized"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Import smi-ted"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Random Seed: 12345\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Vocab size: 2393\n",
+ "[INFERENCE MODE - smi-ted-Light]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_smi_ted = load_smi_ted(\n",
+ " folder='../inference/smi_ted_light',\n",
+ " ckpt_filename='smi-ted-Light_40.pt'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## BBBP Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Experiments - Data Load"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_train = pd.read_csv(\"../finetune/moleculenet/bbbp/train.csv\")\n",
+ "df_test = pd.read_csv(\"../finetune/moleculenet/bbbp/test.csv\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### SMILES canonization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1634, 5)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[22:56:14] Explicit valence for atom # 1 N, 4, is greater than permitted\n",
+ "[22:56:14] Explicit valence for atom # 6 N, 4, is greater than permitted\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] Explicit valence for atom # 6 N, 4, is greater than permitted\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] Explicit valence for atom # 11 N, 4, is greater than permitted\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] Explicit valence for atom # 5 N, 4, is greater than permitted\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:14] WARNING: not removing hydrogen atom without neighbors\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " num \n",
+ " name \n",
+ " p_np \n",
+ " smiles \n",
+ " norm_smiles \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " Propanolol \n",
+ " 1 \n",
+ " [Cl].CC(C)NCC(O)COc1cccc2ccccc12 \n",
+ " CC(C)NCC(O)COc1cccc2ccccc12.[Cl] \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 2 \n",
+ " Terbutylchlorambucil \n",
+ " 1 \n",
+ " C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl \n",
+ " CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 3 \n",
+ " 40730 \n",
+ " 1 \n",
+ " c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO... \n",
+ " CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 4 \n",
+ " 24 \n",
+ " 1 \n",
+ " C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C \n",
+ " CC(=O)NCCCOc1cccc(CN2CCCCC2)c1 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 6 \n",
+ " cefoperazone \n",
+ " 1 \n",
+ " CCN1CCN(C(=O)N[C@@H](C(=O)N[C@H]2[C@H]3SCC(=C(... \n",
+ " CCN1CCN(C(=O)NC(C(=O)NC2C(=O)N3C(C(=O)O)=C(CSc... \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num name p_np \\\n",
+ "0 1 Propanolol 1 \n",
+ "1 2 Terbutylchlorambucil 1 \n",
+ "2 3 40730 1 \n",
+ "3 4 24 1 \n",
+ "4 6 cefoperazone 1 \n",
+ "\n",
+ " smiles \\\n",
+ "0 [Cl].CC(C)NCC(O)COc1cccc2ccccc12 \n",
+ "1 C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl \n",
+ "2 c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO... \n",
+ "3 C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C \n",
+ "4 CCN1CCN(C(=O)N[C@@H](C(=O)N[C@H]2[C@H]3SCC(=C(... \n",
+ "\n",
+ " norm_smiles \n",
+ "0 CC(C)NCC(O)COc1cccc2ccccc12.[Cl] \n",
+ "1 CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1 \n",
+ "2 CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23 \n",
+ "3 CC(=O)NCCCOc1cccc(CN2CCCCC2)c1 \n",
+ "4 CCN1CCN(C(=O)NC(C(=O)NC2C(=O)N3C(C(=O)O)=C(CSc... "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_train['norm_smiles'] = df_train['smiles'].apply(normalize_smiles)\n",
+ "df_train_normalized = df_train.dropna()\n",
+ "print(df_train_normalized.shape)\n",
+ "df_train_normalized.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(192, 5)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[22:56:17] Explicit valence for atom # 12 N, 4, is greater than permitted\n",
+ "[22:56:17] Explicit valence for atom # 5 N, 4, is greater than permitted\n",
+ "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n",
+ "[22:56:17] WARNING: not removing hydrogen atom without neighbors\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " num \n",
+ " name \n",
+ " p_np \n",
+ " smiles \n",
+ " norm_smiles \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 13 \n",
+ " 18 \n",
+ " 1 \n",
+ " C(Cl)Cl \n",
+ " ClCCl \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 23 \n",
+ " SKF-93619 \n",
+ " 0 \n",
+ " c1cc2c(cc(CC3=CNC(=NC3=O)NCCSCc3oc(cc3)CN(C)C)... \n",
+ " CN(C)Cc1ccc(CSCCNc2nc(=O)c(Cc3ccc4ccccc4c3)c[n... \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 36 \n",
+ " etomidate \n",
+ " 1 \n",
+ " CCOC(=O)c1cncn1C(C)c2ccccc2 \n",
+ " CCOC(=O)c1cncn1C(C)c1ccccc1 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 37 \n",
+ " 11a \n",
+ " 0 \n",
+ " CN(C)c1cc(C2=NC(N)=NN2)ccn1 \n",
+ " CN(C)c1cc(-c2nc(N)n[nH]2)ccn1 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 79 \n",
+ " compound 45 \n",
+ " 1 \n",
+ " N1(Cc2cc(OCCCNc3oc4ccccc4n3)ccc2)CCCCC1 \n",
+ " c1cc(CN2CCCCC2)cc(OCCCNc2nc3ccccc3o2)c1 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " num name p_np smiles \\\n",
+ "0 13 18 1 C(Cl)Cl \n",
+ "1 23 SKF-93619 0 c1cc2c(cc(CC3=CNC(=NC3=O)NCCSCc3oc(cc3)CN(C)C)... \n",
+ "2 36 etomidate 1 CCOC(=O)c1cncn1C(C)c2ccccc2 \n",
+ "3 37 11a 0 CN(C)c1cc(C2=NC(N)=NN2)ccn1 \n",
+ "4 79 compound 45 1 N1(Cc2cc(OCCCNc3oc4ccccc4n3)ccc2)CCCCC1 \n",
+ "\n",
+ " norm_smiles \n",
+ "0 ClCCl \n",
+ "1 CN(C)Cc1ccc(CSCCNc2nc(=O)c(Cc3ccc4ccccc4c3)c[n... \n",
+ "2 CCOC(=O)c1cncn1C(C)c1ccccc1 \n",
+ "3 CN(C)c1cc(-c2nc(N)n[nH]2)ccn1 \n",
+ "4 c1cc(CN2CCCCC2)cc(OCCCNc2nc3ccccc3o2)c1 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_test['norm_smiles'] = df_test['smiles'].apply(normalize_smiles)\n",
+ "df_test_normalized = df_test.dropna()\n",
+ "print(df_test_normalized.shape)\n",
+ "df_test_normalized.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Embeddings extraction "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### smi-ted embeddings extraction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 16/16 [00:21<00:00, 1.35s/it]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " 2 \n",
+ " 3 \n",
+ " 4 \n",
+ " 5 \n",
+ " 6 \n",
+ " 7 \n",
+ " 8 \n",
+ " 9 \n",
+ " ... \n",
+ " 758 \n",
+ " 759 \n",
+ " 760 \n",
+ " 761 \n",
+ " 762 \n",
+ " 763 \n",
+ " 764 \n",
+ " 765 \n",
+ " 766 \n",
+ " 767 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.437218 \n",
+ " -0.591727 \n",
+ " 0.064328 \n",
+ " 0.374019 \n",
+ " 0.530676 \n",
+ " -0.644067 \n",
+ " 1.308136 \n",
+ " 0.089772 \n",
+ " 0.790524 \n",
+ " 0.208749 \n",
+ " ... \n",
+ " -1.325162 \n",
+ " -0.083578 \n",
+ " 0.169544 \n",
+ " 0.359247 \n",
+ " -0.652742 \n",
+ " 0.720496 \n",
+ " -0.674184 \n",
+ " 0.693000 \n",
+ " 0.586143 \n",
+ " -0.159641 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 0.344508 \n",
+ " -0.417009 \n",
+ " 0.095745 \n",
+ " 0.355959 \n",
+ " 0.573049 \n",
+ " -0.590279 \n",
+ " 1.069699 \n",
+ " 0.067724 \n",
+ " 0.788815 \n",
+ " 0.159197 \n",
+ " ... \n",
+ " -1.312421 \n",
+ " -0.108732 \n",
+ " 0.217020 \n",
+ " 0.303697 \n",
+ " -0.598966 \n",
+ " 0.647903 \n",
+ " -0.665967 \n",
+ " 0.791804 \n",
+ " 0.620691 \n",
+ " -0.107859 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 0.429205 \n",
+ " -0.463542 \n",
+ " 0.056441 \n",
+ " 0.449925 \n",
+ " 0.536788 \n",
+ " -0.749906 \n",
+ " 1.193816 \n",
+ " 0.082596 \n",
+ " 0.860276 \n",
+ " 0.162548 \n",
+ " ... \n",
+ " -1.304979 \n",
+ " -0.148620 \n",
+ " 0.242045 \n",
+ " 0.344730 \n",
+ " -0.704636 \n",
+ " 0.644773 \n",
+ " -0.781017 \n",
+ " 0.737207 \n",
+ " 0.585380 \n",
+ " -0.101722 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 0.433097 \n",
+ " -0.523078 \n",
+ " 0.089728 \n",
+ " 0.410127 \n",
+ " 0.543400 \n",
+ " -0.643014 \n",
+ " 1.203858 \n",
+ " 0.034177 \n",
+ " 0.769413 \n",
+ " 0.202445 \n",
+ " ... \n",
+ " -1.358915 \n",
+ " -0.077463 \n",
+ " 0.228710 \n",
+ " 0.317884 \n",
+ " -0.680220 \n",
+ " 0.531601 \n",
+ " -0.709799 \n",
+ " 0.731386 \n",
+ " 0.567806 \n",
+ " -0.087713 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 0.388423 \n",
+ " -0.505908 \n",
+ " 0.072539 \n",
+ " 0.366502 \n",
+ " 0.533689 \n",
+ " -0.701559 \n",
+ " 1.035554 \n",
+ " 0.038419 \n",
+ " 0.822917 \n",
+ " 0.163062 \n",
+ " ... \n",
+ " -1.271012 \n",
+ " -0.176412 \n",
+ " 0.119734 \n",
+ " 0.294143 \n",
+ " -0.677721 \n",
+ " 0.647655 \n",
+ " -0.844419 \n",
+ " 0.756321 \n",
+ " 0.570513 \n",
+ " -0.240003 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
5 rows × 768 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 6 \\\n",
+ "0 0.437218 -0.591727 0.064328 0.374019 0.530676 -0.644067 1.308136 \n",
+ "1 0.344508 -0.417009 0.095745 0.355959 0.573049 -0.590279 1.069699 \n",
+ "2 0.429205 -0.463542 0.056441 0.449925 0.536788 -0.749906 1.193816 \n",
+ "3 0.433097 -0.523078 0.089728 0.410127 0.543400 -0.643014 1.203858 \n",
+ "4 0.388423 -0.505908 0.072539 0.366502 0.533689 -0.701559 1.035554 \n",
+ "\n",
+ " 7 8 9 ... 758 759 760 761 \\\n",
+ "0 0.089772 0.790524 0.208749 ... -1.325162 -0.083578 0.169544 0.359247 \n",
+ "1 0.067724 0.788815 0.159197 ... -1.312421 -0.108732 0.217020 0.303697 \n",
+ "2 0.082596 0.860276 0.162548 ... -1.304979 -0.148620 0.242045 0.344730 \n",
+ "3 0.034177 0.769413 0.202445 ... -1.358915 -0.077463 0.228710 0.317884 \n",
+ "4 0.038419 0.822917 0.163062 ... -1.271012 -0.176412 0.119734 0.294143 \n",
+ "\n",
+ " 762 763 764 765 766 767 \n",
+ "0 -0.652742 0.720496 -0.674184 0.693000 0.586143 -0.159641 \n",
+ "1 -0.598966 0.647903 -0.665967 0.791804 0.620691 -0.107859 \n",
+ "2 -0.704636 0.644773 -0.781017 0.737207 0.585380 -0.101722 \n",
+ "3 -0.680220 0.531601 -0.709799 0.731386 0.567806 -0.087713 \n",
+ "4 -0.677721 0.647655 -0.844419 0.756321 0.570513 -0.240003 \n",
+ "\n",
+ "[5 rows x 768 columns]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " df_embeddings_train = model_smi_ted.encode(df_train_normalized['norm_smiles'])\n",
+ "df_embeddings_train.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1/1 [00:04<00:00, 4.23s/it]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " 2 \n",
+ " 3 \n",
+ " 4 \n",
+ " 5 \n",
+ " 6 \n",
+ " 7 \n",
+ " 8 \n",
+ " 9 \n",
+ " ... \n",
+ " 758 \n",
+ " 759 \n",
+ " 760 \n",
+ " 761 \n",
+ " 762 \n",
+ " 763 \n",
+ " 764 \n",
+ " 765 \n",
+ " 766 \n",
+ " 767 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.374249 \n",
+ " -0.319257 \n",
+ " -0.007041 \n",
+ " 0.444741 \n",
+ " 0.326734 \n",
+ " -0.791476 \n",
+ " 1.121707 \n",
+ " -0.082401 \n",
+ " 0.611457 \n",
+ " 0.289225 \n",
+ " ... \n",
+ " -1.462539 \n",
+ " -0.302055 \n",
+ " 0.295551 \n",
+ " -0.058293 \n",
+ " -0.830319 \n",
+ " 0.545099 \n",
+ " -0.460271 \n",
+ " 1.121117 \n",
+ " 0.685016 \n",
+ " -0.452698 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 0.429158 \n",
+ " -0.568104 \n",
+ " 0.112739 \n",
+ " 0.352429 \n",
+ " 0.512565 \n",
+ " -0.604153 \n",
+ " 1.181846 \n",
+ " 0.067963 \n",
+ " 0.786978 \n",
+ " 0.128077 \n",
+ " ... \n",
+ " -1.226941 \n",
+ " -0.078927 \n",
+ " 0.209468 \n",
+ " 0.266113 \n",
+ " -0.762261 \n",
+ " 0.610685 \n",
+ " -0.755705 \n",
+ " 0.734550 \n",
+ " 0.592976 \n",
+ " -0.148252 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 0.411906 \n",
+ " -0.510477 \n",
+ " 0.073015 \n",
+ " 0.346871 \n",
+ " 0.512772 \n",
+ " -0.617252 \n",
+ " 1.191621 \n",
+ " 0.040103 \n",
+ " 0.722577 \n",
+ " 0.188638 \n",
+ " ... \n",
+ " -1.300554 \n",
+ " -0.150735 \n",
+ " 0.148252 \n",
+ " 0.282791 \n",
+ " -0.694712 \n",
+ " 0.556029 \n",
+ " -0.660645 \n",
+ " 0.771226 \n",
+ " 0.558996 \n",
+ " -0.000660 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 0.356793 \n",
+ " -0.530959 \n",
+ " 0.050350 \n",
+ " 0.433593 \n",
+ " 0.592601 \n",
+ " -0.573508 \n",
+ " 1.221865 \n",
+ " 0.025491 \n",
+ " 0.833164 \n",
+ " 0.214604 \n",
+ " ... \n",
+ " -1.406141 \n",
+ " -0.107165 \n",
+ " 0.200131 \n",
+ " 0.289469 \n",
+ " -0.770149 \n",
+ " 0.572746 \n",
+ " -0.776739 \n",
+ " 0.855064 \n",
+ " 0.662797 \n",
+ " -0.194417 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 0.422133 \n",
+ " -0.490610 \n",
+ " 0.044333 \n",
+ " 0.367861 \n",
+ " 0.579025 \n",
+ " -0.629409 \n",
+ " 1.139824 \n",
+ " 0.039823 \n",
+ " 0.728825 \n",
+ " 0.145327 \n",
+ " ... \n",
+ " -1.312777 \n",
+ " -0.105049 \n",
+ " 0.175286 \n",
+ " 0.336176 \n",
+ " -0.738813 \n",
+ " 0.530226 \n",
+ " -0.763357 \n",
+ " 0.764998 \n",
+ " 0.583681 \n",
+ " -0.109683 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
5 rows × 768 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 6 \\\n",
+ "0 0.374249 -0.319257 -0.007041 0.444741 0.326734 -0.791476 1.121707 \n",
+ "1 0.429158 -0.568104 0.112739 0.352429 0.512565 -0.604153 1.181846 \n",
+ "2 0.411906 -0.510477 0.073015 0.346871 0.512772 -0.617252 1.191621 \n",
+ "3 0.356793 -0.530959 0.050350 0.433593 0.592601 -0.573508 1.221865 \n",
+ "4 0.422133 -0.490610 0.044333 0.367861 0.579025 -0.629409 1.139824 \n",
+ "\n",
+ " 7 8 9 ... 758 759 760 761 \\\n",
+ "0 -0.082401 0.611457 0.289225 ... -1.462539 -0.302055 0.295551 -0.058293 \n",
+ "1 0.067963 0.786978 0.128077 ... -1.226941 -0.078927 0.209468 0.266113 \n",
+ "2 0.040103 0.722577 0.188638 ... -1.300554 -0.150735 0.148252 0.282791 \n",
+ "3 0.025491 0.833164 0.214604 ... -1.406141 -0.107165 0.200131 0.289469 \n",
+ "4 0.039823 0.728825 0.145327 ... -1.312777 -0.105049 0.175286 0.336176 \n",
+ "\n",
+ " 762 763 764 765 766 767 \n",
+ "0 -0.830319 0.545099 -0.460271 1.121117 0.685016 -0.452698 \n",
+ "1 -0.762261 0.610685 -0.755705 0.734550 0.592976 -0.148252 \n",
+ "2 -0.694712 0.556029 -0.660645 0.771226 0.558996 -0.000660 \n",
+ "3 -0.770149 0.572746 -0.776739 0.855064 0.662797 -0.194417 \n",
+ "4 -0.738813 0.530226 -0.763357 0.764998 0.583681 -0.109683 \n",
+ "\n",
+ "[5 rows x 768 columns]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " df_embeddings_test = model_smi_ted.encode(df_test_normalized['norm_smiles'])\n",
+ "df_embeddings_test.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Experiments - BBBP prediction using smi-ted latent spaces"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### XGBoost prediction using the whole Latent Space"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from xgboost import XGBClassifier\n",
+ "from sklearn.metrics import roc_auc_score"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
+ " colsample_bylevel=None, colsample_bynode=None,\n",
+ " colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
+ " gamma=None, grow_policy=None, importance_type=None,\n",
+ " interaction_constraints=None, learning_rate=0.04, max_bin=None,\n",
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
+ " max_delta_step=None, max_depth=8, max_leaves=None,\n",
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+ " multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
+ " num_parallel_tree=None, random_state=None, ...) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. XGBClassifieriFitted XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
+ " colsample_bylevel=None, colsample_bynode=None,\n",
+ " colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
+ " gamma=None, grow_policy=None, importance_type=None,\n",
+ " interaction_constraints=None, learning_rate=0.04, max_bin=None,\n",
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
+ " max_delta_step=None, max_depth=8, max_leaves=None,\n",
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+ " multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
+ " num_parallel_tree=None, random_state=None, ...) "
+ ],
+ "text/plain": [
+ "XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
+ " colsample_bylevel=None, colsample_bynode=None,\n",
+ " colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
+ " gamma=None, grow_policy=None, importance_type=None,\n",
+ " interaction_constraints=None, learning_rate=0.04, max_bin=None,\n",
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
+ " max_delta_step=None, max_depth=8, max_leaves=None,\n",
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+ " multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
+ " num_parallel_tree=None, random_state=None, ...)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xgb_predict = XGBClassifier(n_estimators=2000, learning_rate=0.04, max_depth=8)\n",
+ "xgb_predict.fit(df_embeddings_train, df_train_normalized['p_np'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# get XGBoost predictions\n",
+ "y_prob = xgb_predict.predict_proba(df_embeddings_test)[:, 1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "ROC-AUC Score: 0.9194\n"
+ ]
+ }
+ ],
+ "source": [
+ "roc_auc = roc_auc_score(df_test_normalized[\"p_np\"], y_prob)\n",
+ "print(f\"ROC-AUC Score: {roc_auc:.4f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/models/smi_ted/notebooks/smi_ted_frozen_inference_example2.ipynb b/models/smi_ted/notebooks/smi_ted_frozen_inference_example2.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..045044361900ecf0876c5e05b51f023d733197f3
--- /dev/null
+++ b/models/smi_ted/notebooks/smi_ted_frozen_inference_example2.ipynb
@@ -0,0 +1,1327 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# materials.smi-TED - INFERENCE (Regression)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install extra packages for notebook\n",
+ "%pip install seaborn xgboost"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append('../inference')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# materials.smi-ted (smi-ted)\n",
+ "from smi_ted_light.load import load_smi_ted\n",
+ "\n",
+ "# Data\n",
+ "import torch\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "# Chemistry\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem import PandasTools\n",
+ "from rdkit.Chem import Descriptors\n",
+ "PandasTools.RenderImagesInAllDataFrames(True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# function to canonicalize SMILES\n",
+ "def normalize_smiles(smi, canonical=True, isomeric=False):\n",
+ " try:\n",
+ " normalized = Chem.MolToSmiles(\n",
+ " Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric\n",
+ " )\n",
+ " except:\n",
+ " normalized = None\n",
+ " return normalized"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Import smi-ted"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Random Seed: 12345\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Using Rotation Embedding\n",
+ "Vocab size: 2393\n",
+ "[INFERENCE MODE - smi-ted-Light]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_smi_ted = load_smi_ted(\n",
+ " folder='../inference/smi_ted_light',\n",
+ " ckpt_filename='smi-ted-Light_40.pt'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Lipophilicity Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Experiments - Data Load"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_train = pd.read_csv(\"../finetune/moleculenet/lipophilicity/train.csv\")\n",
+ "df_test = pd.read_csv(\"../finetune/moleculenet/lipophilicity/test.csv\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### SMILES canonization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(3360, 3)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " smiles \n",
+ " y \n",
+ " norm_smiles \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4 \n",
+ " 0.814313 \n",
+ " Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12... \n",
+ " 0.446346 \n",
+ " COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3... \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5... \n",
+ " 1.148828 \n",
+ " CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c... \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13 \n",
+ " 0.404532 \n",
+ " O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... \n",
+ " -0.164144 \n",
+ " O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)... \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " smiles y \\\n",
+ "0 Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4 0.814313 \n",
+ "1 COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12... 0.446346 \n",
+ "2 CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5... 1.148828 \n",
+ "3 Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13 0.404532 \n",
+ "4 Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... -0.164144 \n",
+ "\n",
+ " norm_smiles \n",
+ "0 Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1 \n",
+ "1 COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3... \n",
+ "2 CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c... \n",
+ "3 O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12 \n",
+ "4 O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)... "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_train['norm_smiles'] = df_train['smiles'].apply(normalize_smiles)\n",
+ "df_train_normalized = df_train.dropna()\n",
+ "print(df_train_normalized.shape)\n",
+ "df_train_normalized.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(420, 3)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " smiles \n",
+ " y \n",
+ " norm_smiles \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " N(c1ccccc1)c2ccnc3ccccc23 \n",
+ " 0.488161 \n",
+ " c1ccc(Nc2ccnc3ccccc23)cc1 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1 \n",
+ " 0.070017 \n",
+ " Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 \n",
+ " -0.415030 \n",
+ " NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc... \n",
+ " 0.897942 \n",
+ " O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[... \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " NS(=O)(=O)c1nc2ccccc2s1 \n",
+ " -0.707731 \n",
+ " NS(=O)(=O)c1nc2ccccc2s1 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " smiles y \\\n",
+ "0 N(c1ccccc1)c2ccnc3ccccc23 0.488161 \n",
+ "1 Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1 0.070017 \n",
+ "2 NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 -0.415030 \n",
+ "3 OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc... 0.897942 \n",
+ "4 NS(=O)(=O)c1nc2ccccc2s1 -0.707731 \n",
+ "\n",
+ " norm_smiles \n",
+ "0 c1ccc(Nc2ccnc3ccccc23)cc1 \n",
+ "1 Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2 \n",
+ "2 NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1 \n",
+ "3 O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[... \n",
+ "4 NS(=O)(=O)c1nc2ccccc2s1 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_test['norm_smiles'] = df_test['smiles'].apply(normalize_smiles)\n",
+ "df_test_normalized = df_test.dropna()\n",
+ "print(df_test_normalized.shape)\n",
+ "df_test_normalized.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Embeddings extraction "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### smi-ted embeddings extraction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 33/33 [00:38<00:00, 1.15s/it]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " 2 \n",
+ " 3 \n",
+ " 4 \n",
+ " 5 \n",
+ " 6 \n",
+ " 7 \n",
+ " 8 \n",
+ " 9 \n",
+ " ... \n",
+ " 758 \n",
+ " 759 \n",
+ " 760 \n",
+ " 761 \n",
+ " 762 \n",
+ " 763 \n",
+ " 764 \n",
+ " 765 \n",
+ " 766 \n",
+ " 767 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.367646 \n",
+ " -0.504889 \n",
+ " 0.040485 \n",
+ " 0.385314 \n",
+ " 0.564923 \n",
+ " -0.684497 \n",
+ " 1.160397 \n",
+ " 0.071218 \n",
+ " 0.799428 \n",
+ " 0.181323 \n",
+ " ... \n",
+ " -1.379994 \n",
+ " -0.167221 \n",
+ " 0.104886 \n",
+ " 0.239571 \n",
+ " -0.744390 \n",
+ " 0.590423 \n",
+ " -0.808946 \n",
+ " 0.792584 \n",
+ " 0.550898 \n",
+ " -0.176831 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 0.455316 \n",
+ " -0.485554 \n",
+ " 0.062206 \n",
+ " 0.387994 \n",
+ " 0.567590 \n",
+ " -0.713285 \n",
+ " 1.144267 \n",
+ " -0.057046 \n",
+ " 0.753016 \n",
+ " 0.112180 \n",
+ " ... \n",
+ " -1.332142 \n",
+ " -0.096662 \n",
+ " 0.221944 \n",
+ " 0.327923 \n",
+ " -0.739358 \n",
+ " 0.659803 \n",
+ " -0.775723 \n",
+ " 0.745837 \n",
+ " 0.566330 \n",
+ " -0.111946 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 0.442309 \n",
+ " -0.484732 \n",
+ " 0.084945 \n",
+ " 0.384787 \n",
+ " 0.564752 \n",
+ " -0.704130 \n",
+ " 1.159491 \n",
+ " 0.021168 \n",
+ " 0.846539 \n",
+ " 0.118463 \n",
+ " ... \n",
+ " -1.324177 \n",
+ " -0.110403 \n",
+ " 0.207824 \n",
+ " 0.281665 \n",
+ " -0.780818 \n",
+ " 0.693484 \n",
+ " -0.832626 \n",
+ " 0.763095 \n",
+ " 0.532460 \n",
+ " -0.196708 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 0.527961 \n",
+ " -0.519151 \n",
+ " 0.091635 \n",
+ " 0.353518 \n",
+ " 0.421795 \n",
+ " -0.724220 \n",
+ " 1.093752 \n",
+ " 0.148574 \n",
+ " 0.804047 \n",
+ " 0.194627 \n",
+ " ... \n",
+ " -1.358414 \n",
+ " -0.111483 \n",
+ " 0.151692 \n",
+ " 0.186741 \n",
+ " -0.601867 \n",
+ " 0.641591 \n",
+ " -0.747422 \n",
+ " 0.794239 \n",
+ " 0.640765 \n",
+ " -0.239649 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 0.464432 \n",
+ " -0.511090 \n",
+ " 0.038785 \n",
+ " 0.346217 \n",
+ " 0.492919 \n",
+ " -0.619387 \n",
+ " 1.048157 \n",
+ " 0.095910 \n",
+ " 0.738604 \n",
+ " 0.119270 \n",
+ " ... \n",
+ " -1.223927 \n",
+ " -0.109863 \n",
+ " 0.151280 \n",
+ " 0.244834 \n",
+ " -0.686610 \n",
+ " 0.759327 \n",
+ " -0.756338 \n",
+ " 0.766427 \n",
+ " 0.610454 \n",
+ " -0.197345 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
5 rows × 768 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 6 \\\n",
+ "0 0.367646 -0.504889 0.040485 0.385314 0.564923 -0.684497 1.160397 \n",
+ "1 0.455316 -0.485554 0.062206 0.387994 0.567590 -0.713285 1.144267 \n",
+ "2 0.442309 -0.484732 0.084945 0.384787 0.564752 -0.704130 1.159491 \n",
+ "3 0.527961 -0.519151 0.091635 0.353518 0.421795 -0.724220 1.093752 \n",
+ "4 0.464432 -0.511090 0.038785 0.346217 0.492919 -0.619387 1.048157 \n",
+ "\n",
+ " 7 8 9 ... 758 759 760 761 \\\n",
+ "0 0.071218 0.799428 0.181323 ... -1.379994 -0.167221 0.104886 0.239571 \n",
+ "1 -0.057046 0.753016 0.112180 ... -1.332142 -0.096662 0.221944 0.327923 \n",
+ "2 0.021168 0.846539 0.118463 ... -1.324177 -0.110403 0.207824 0.281665 \n",
+ "3 0.148574 0.804047 0.194627 ... -1.358414 -0.111483 0.151692 0.186741 \n",
+ "4 0.095910 0.738604 0.119270 ... -1.223927 -0.109863 0.151280 0.244834 \n",
+ "\n",
+ " 762 763 764 765 766 767 \n",
+ "0 -0.744390 0.590423 -0.808946 0.792584 0.550898 -0.176831 \n",
+ "1 -0.739358 0.659803 -0.775723 0.745837 0.566330 -0.111946 \n",
+ "2 -0.780818 0.693484 -0.832626 0.763095 0.532460 -0.196708 \n",
+ "3 -0.601867 0.641591 -0.747422 0.794239 0.640765 -0.239649 \n",
+ "4 -0.686610 0.759327 -0.756338 0.766427 0.610454 -0.197345 \n",
+ "\n",
+ "[5 rows x 768 columns]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " df_embeddings_train = model_smi_ted.encode(df_train_normalized['norm_smiles'])\n",
+ "df_embeddings_train.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 4/4 [00:05<00:00, 1.46s/it]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " 2 \n",
+ " 3 \n",
+ " 4 \n",
+ " 5 \n",
+ " 6 \n",
+ " 7 \n",
+ " 8 \n",
+ " 9 \n",
+ " ... \n",
+ " 758 \n",
+ " 759 \n",
+ " 760 \n",
+ " 761 \n",
+ " 762 \n",
+ " 763 \n",
+ " 764 \n",
+ " 765 \n",
+ " 766 \n",
+ " 767 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.392252 \n",
+ " -0.504846 \n",
+ " 0.056791 \n",
+ " 0.356297 \n",
+ " 0.475918 \n",
+ " -0.648899 \n",
+ " 1.157862 \n",
+ " -0.022914 \n",
+ " 0.703240 \n",
+ " 0.192023 \n",
+ " ... \n",
+ " -1.208714 \n",
+ " -0.094441 \n",
+ " 0.128845 \n",
+ " 0.403995 \n",
+ " -0.782782 \n",
+ " 0.541907 \n",
+ " -0.707272 \n",
+ " 0.901041 \n",
+ " 0.629461 \n",
+ " -0.020630 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 0.387422 \n",
+ " -0.481142 \n",
+ " 0.049675 \n",
+ " 0.353058 \n",
+ " 0.601170 \n",
+ " -0.646099 \n",
+ " 1.142392 \n",
+ " 0.060092 \n",
+ " 0.763799 \n",
+ " 0.110331 \n",
+ " ... \n",
+ " -1.248282 \n",
+ " -0.139790 \n",
+ " 0.075585 \n",
+ " 0.202242 \n",
+ " -0.729794 \n",
+ " 0.705914 \n",
+ " -0.771751 \n",
+ " 0.843173 \n",
+ " 0.618850 \n",
+ " -0.213584 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 0.390975 \n",
+ " -0.510056 \n",
+ " 0.070656 \n",
+ " 0.380695 \n",
+ " 0.601486 \n",
+ " -0.595827 \n",
+ " 1.182193 \n",
+ " 0.011085 \n",
+ " 0.688093 \n",
+ " 0.056453 \n",
+ " ... \n",
+ " -1.294595 \n",
+ " -0.164846 \n",
+ " 0.194435 \n",
+ " 0.240742 \n",
+ " -0.773443 \n",
+ " 0.608631 \n",
+ " -0.747181 \n",
+ " 0.791911 \n",
+ " 0.611874 \n",
+ " -0.125455 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 0.423924 \n",
+ " -0.557325 \n",
+ " 0.083810 \n",
+ " 0.328703 \n",
+ " 0.399589 \n",
+ " -0.622818 \n",
+ " 1.079945 \n",
+ " 0.097611 \n",
+ " 0.724030 \n",
+ " 0.135976 \n",
+ " ... \n",
+ " -1.412060 \n",
+ " -0.106541 \n",
+ " 0.153314 \n",
+ " 0.209962 \n",
+ " -0.699690 \n",
+ " 0.648061 \n",
+ " -0.716241 \n",
+ " 0.757986 \n",
+ " 0.615963 \n",
+ " -0.258693 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 0.335576 \n",
+ " -0.559591 \n",
+ " 0.119437 \n",
+ " 0.364141 \n",
+ " 0.375474 \n",
+ " -0.639833 \n",
+ " 1.144707 \n",
+ " 0.077512 \n",
+ " 0.791759 \n",
+ " 0.164201 \n",
+ " ... \n",
+ " -1.279041 \n",
+ " -0.186733 \n",
+ " 0.106963 \n",
+ " 0.254949 \n",
+ " -0.651694 \n",
+ " 0.594167 \n",
+ " -0.680426 \n",
+ " 0.887482 \n",
+ " 0.651587 \n",
+ " -0.144996 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
5 rows × 768 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 6 \\\n",
+ "0 0.392252 -0.504846 0.056791 0.356297 0.475918 -0.648899 1.157862 \n",
+ "1 0.387422 -0.481142 0.049675 0.353058 0.601170 -0.646099 1.142392 \n",
+ "2 0.390975 -0.510056 0.070656 0.380695 0.601486 -0.595827 1.182193 \n",
+ "3 0.423924 -0.557325 0.083810 0.328703 0.399589 -0.622818 1.079945 \n",
+ "4 0.335576 -0.559591 0.119437 0.364141 0.375474 -0.639833 1.144707 \n",
+ "\n",
+ " 7 8 9 ... 758 759 760 761 \\\n",
+ "0 -0.022914 0.703240 0.192023 ... -1.208714 -0.094441 0.128845 0.403995 \n",
+ "1 0.060092 0.763799 0.110331 ... -1.248282 -0.139790 0.075585 0.202242 \n",
+ "2 0.011085 0.688093 0.056453 ... -1.294595 -0.164846 0.194435 0.240742 \n",
+ "3 0.097611 0.724030 0.135976 ... -1.412060 -0.106541 0.153314 0.209962 \n",
+ "4 0.077512 0.791759 0.164201 ... -1.279041 -0.186733 0.106963 0.254949 \n",
+ "\n",
+ " 762 763 764 765 766 767 \n",
+ "0 -0.782782 0.541907 -0.707272 0.901041 0.629461 -0.020630 \n",
+ "1 -0.729794 0.705914 -0.771751 0.843173 0.618850 -0.213584 \n",
+ "2 -0.773443 0.608631 -0.747181 0.791911 0.611874 -0.125455 \n",
+ "3 -0.699690 0.648061 -0.716241 0.757986 0.615963 -0.258693 \n",
+ "4 -0.651694 0.594167 -0.680426 0.887482 0.651587 -0.144996 \n",
+ "\n",
+ "[5 rows x 768 columns]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " df_embeddings_test = model_smi_ted.encode(df_test_normalized['norm_smiles'])\n",
+ "df_embeddings_test.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Experiments - Lipophilicity prediction using smi-ted latent spaces"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### XGBoost prediction using the whole Latent Space"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from xgboost import XGBRegressor\n",
+ "from sklearn.metrics import mean_squared_error"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
+ " colsample_bylevel=None, colsample_bynode=None,\n",
+ " colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
+ " gamma=None, grow_policy=None, importance_type=None,\n",
+ " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
+ " max_delta_step=None, max_depth=4, max_leaves=None,\n",
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+ " multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
+ " num_parallel_tree=None, random_state=None, ...) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. XGBRegressoriFitted XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
+ " colsample_bylevel=None, colsample_bynode=None,\n",
+ " colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
+ " gamma=None, grow_policy=None, importance_type=None,\n",
+ " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
+ " max_delta_step=None, max_depth=4, max_leaves=None,\n",
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+ " multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
+ " num_parallel_tree=None, random_state=None, ...) "
+ ],
+ "text/plain": [
+ "XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
+ " colsample_bylevel=None, colsample_bynode=None,\n",
+ " colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
+ " gamma=None, grow_policy=None, importance_type=None,\n",
+ " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
+ " max_delta_step=None, max_depth=4, max_leaves=None,\n",
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
+ " multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
+ " num_parallel_tree=None, random_state=None, ...)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xgb_predict = XGBRegressor(n_estimators=2000, learning_rate=0.05, max_depth=4)\n",
+ "xgb_predict.fit(df_embeddings_train, df_train_normalized['y'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# get XGBoost predictions\n",
+ "y_pred = xgb_predict.predict(df_embeddings_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RMSE Score: 0.6485\n"
+ ]
+ }
+ ],
+ "source": [
+ "rmse = np.sqrt(mean_squared_error(df_test_normalized[\"y\"], y_pred))\n",
+ "print(f\"RMSE Score: {rmse:.4f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/models/smi_ted/paper/smi-ted_preprint.pdf b/models/smi_ted/paper/smi-ted_preprint.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..5c6c3282b163b354a9beded102db50495605dc75
--- /dev/null
+++ b/models/smi_ted/paper/smi-ted_preprint.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75b2e2dc74c2d9a87cb19d19a4770e0db353b4ade883d8cc366b1faef4ba053f
+size 3343180
diff --git a/models/smi_ted/requirements.txt b/models/smi_ted/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c8d2ad17095979cb7654b7588808a4b8015131d4
--- /dev/null
+++ b/models/smi_ted/requirements.txt
@@ -0,0 +1,9 @@
+transformers
+torch-optimizer
+datasets
+scikit-learn
+scipy>=1.12.0
+numpy==1.26.4
+pandas==1.4.0
+tqdm>=4.66.4
+rdkit>=2024.3.5
\ No newline at end of file
diff --git a/models/smi_ted/smi_ted_light/load.py b/models/smi_ted/smi_ted_light/load.py
index e10cf1524dcd7f4c9a08460fc7960f2428fd654b..adbb9239b5a3e37af7197abbea1ced068a8f99d7 100644
--- a/models/smi_ted/smi_ted_light/load.py
+++ b/models/smi_ted/smi_ted_light/load.py
@@ -19,6 +19,13 @@ from huggingface_hub import hf_hub_download
# Data
import numpy as np
import pandas as pd
+import numpy as np
+
+# Chemistry
+from rdkit import Chem
+from rdkit.Chem import PandasTools
+from rdkit.Chem import Descriptors
+PandasTools.RenderImagesInAllDataFrames(True)
# Standard library
from functools import partial
@@ -30,6 +37,17 @@ from tqdm import tqdm
tqdm.pandas()
+# function to canonicalize SMILES
+def normalize_smiles(smi, canonical=True, isomeric=False):
+ try:
+ normalized = Chem.MolToSmiles(
+ Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
+ )
+ except:
+ normalized = None
+ return normalized
+
+
class MolTranBertTokenizer(BertTokenizer):
def __init__(self, vocab_file: str = '',
do_lower_case=False,
@@ -477,9 +495,17 @@ class Smi_ted(nn.Module):
if self.is_cuda_available:
self.encoder.cuda()
self.decoder.cuda()
+
+ # handle single str or a list of str
+ smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+ # SMILES normalization
+ smiles = smiles.apply(normalize_smiles)
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
+ smiles = smiles.dropna()
# tokenizer
- idx, mask = self.tokenize(smiles)
+ idx, mask = self.tokenize(smiles.to_list())
###########
# Encoder #
@@ -515,6 +541,30 @@ class Smi_ted(nn.Module):
# reconstruct tokens
pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
pred_ids = torch.argmax(pred_ids, axis=-1)
+
+ # replacing null SMILES with NaN values
+ for idx in null_idx:
+ true_ids = true_ids.tolist()
+ pred_ids = pred_ids.tolist()
+ true_cte = true_cte.tolist()
+ pred_cte = pred_cte.tolist()
+ true_set = true_set.tolist()
+ pred_set = pred_set.tolist()
+
+ true_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+ pred_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+ true_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+ pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+ true_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+ pred_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+
+ if len(null_idx) > 0:
+ true_ids = torch.tensor(true_ids)
+ pred_ids = torch.tensor(pred_ids)
+ true_cte = torch.tensor(true_cte)
+ pred_cte = torch.tensor(pred_cte)
+ true_set = torch.tensor(true_set)
+ pred_set = torch.tensor(pred_set)
return ((true_ids, pred_ids), # tokens
(true_cte, pred_cte), # token embeddings
@@ -548,9 +598,14 @@ class Smi_ted(nn.Module):
# handle single str or a list of str
smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
- n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
-
+
+ # SMILES normalization
+ smiles = smiles.apply(normalize_smiles)
+ null_idx = smiles[smiles.isnull()].index.to_list() # keep track of SMILES that cannot normalize
+ smiles = smiles.dropna()
+
# process in batches
+ n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
embeddings = [
self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
for batch in tqdm(np.array_split(smiles, n_split))
@@ -562,8 +617,13 @@ class Smi_ted(nn.Module):
torch.cuda.empty_cache()
gc.collect()
+ # replacing null SMILES with NaN values
+ for idx in null_idx:
+ flat_list.insert(idx, np.array([np.nan]*self.config['n_embd']))
+ flat_list = np.asarray(flat_list)
+
if return_torch:
- return torch.tensor(np.array(flat_list))
+ return torch.tensor(flat_list)
return pd.DataFrame(flat_list)
def decode(self, smiles_embeddings):
@@ -607,6 +667,7 @@ def load_smi_ted(folder="./smi_ted_light",
):
tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
model = Smi_ted(tokenizer)
+
repo_id = "ibm/materials.smi-ted"
filename = "smi-ted-Light_40.pt"
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
diff --git a/models/smi_ted/training/args.py b/models/smi_ted/training/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..282d78326279fbfb1179c880c454bafd0df8ad84
--- /dev/null
+++ b/models/smi_ted/training/args.py
@@ -0,0 +1,254 @@
+import argparse
+
+
+def get_parser(parser=None):
+ if parser is None:
+ parser = argparse.ArgumentParser()
+
+ # Model
+ #model_arg = parser.add_argument_group('Model')
+ parser.add_argument('--n_head',
+ type=int, default=8,
+ help='GPT number of heads')
+ parser.add_argument('--n_layer',
+ type=int, default=12,
+ help='GPT number of layers')
+ parser.add_argument('--q_dropout',
+ type=float, default=0.5,
+ help='Encoder layers dropout')
+ parser.add_argument('--d_dropout',
+ type=float, default=0.1,
+ help='Decoder layers dropout')
+ parser.add_argument('--n_embd',
+ type=int, default=768,
+ help='Latent vector dimensionality')
+ parser.add_argument('--fc_h',
+ type=int, default=512,
+ help='Fully connected hidden dimensionality')
+
+
+ # Train
+ #train_arg = parser.add_argument_group('Train')
+ parser.add_argument('--n_batch',
+ type=int, default=512,
+ help='Batch size')
+ parser.add_argument('--unlike_alpha',
+ type=float, default=1.0,
+ help='unlikelihood loss alpha weight')
+ parser.add_argument('--from_scratch',
+ action='store_true', default=False,
+ help='train on qm9 from scratch')
+ parser.add_argument('--unlikelihood',
+ action='store_true', default=False,
+ help='use unlikelihood loss with gpt pretrain')
+ parser.add_argument('--grad_acc',
+ type=int, default=1,
+ help='number of batches to accumulate gradients')
+ parser.add_argument('--checkpoint_every',
+ type=int, default=1000,
+ help='save checkpoint every x iterations')
+ parser.add_argument('--clip_grad',
+ type=int, default=50,
+ help='Clip gradients to this value')
+ parser.add_argument('--lr_start',
+ type=float, default=3 * 1e-4,
+ help='Initial lr value')
+ parser.add_argument('--lr_end',
+ type=float, default=3 * 1e-4,
+ help='Maximum lr weight value')
+ parser.add_argument('--lr_multiplier',
+ type=int, default=1,
+ help='lr weight multiplier')
+ parser.add_argument('--n_last',
+ type=int, default=1000,
+ help='Number of iters to smooth loss calc')
+ parser.add_argument('--n_jobs',
+ type=int, default=1,
+ help='Number of threads')
+ parser.add_argument('--accelerator',
+ type=str, default='ddp',
+ help='The accelerator backend to use (previously known as distributed_backend)')
+ parser.add_argument('--num_nodes',
+ type=int, default=1,
+ help='number of GPU nodes for distributed training')
+ parser.add_argument('--device',
+ type=str, default='cuda',
+ help='Device to run: "cpu" or "cuda:"')
+ parser.add_argument('--seed',
+ type=int, default=12345,
+ help='Seed')
+ parser.add_argument('--init_params_from',
+ type=str, default='',
+ help='Path to a ckpt used to initialize the parameters if no restart_path is provided')
+ parser.add_argument('--train_decoder_every',
+ type=int, default=10,
+ help='Optimize decoder params every n batches')
+ parser.add_argument('--lr_decoder',
+ type=float, default=1e-4,
+ help='Learning rate for decoder part')
+ parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
+ parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
+ parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
+ parser.add_argument('--save_checkpoint_path', default='/data', help='checkpoint saving path')
+ parser.add_argument('--load_checkpoint_path', default='', help='checkpoint loading path')
+
+ #common_arg = parser.add_argument_group('Common')
+ parser.add_argument('--vocab_load',
+ type=str, required=False,
+ help='Where to load the vocab')
+ parser.add_argument('--n_samples',
+ type=int, required=False,
+ help='Number of samples to sample')
+ parser.add_argument('--gen_save',
+ type=str, required=False,
+ help='Where to save the gen molecules')
+ parser.add_argument("--max_len",
+ type=int, default=100,
+ help="Max of length of SMILES")
+ parser.add_argument('--train_load',
+ type=str, required=False,
+ help='Where to load the model')
+ parser.add_argument('--val_load',
+ type=str, required=False,
+ help='Where to load the model')
+ parser.add_argument('--n_workers',
+ type=int, required=False, default=1,
+ help='Where to load the model')
+ #beam search hyper parameters
+ parser.add_argument('--beam_size', type=int, default=0,
+ help="Number of beams to generate")
+ parser.add_argument('--num_seq_returned', type=int, default=0,
+ help="number of beams to be returned (must be <= beam_size")
+ parser.add_argument('--min_len', type=int, default=1,
+ help="minimum length to be generated")
+ parser.add_argument('--nucleus_thresh', type=float, default=.9,
+ help="nucleus sampling threshold")
+ parser.add_argument('--finetune_path',
+ type=str, default="",
+ help='path to trainer file to continue training')
+ parser.add_argument('--restart_path',
+ type=str, default="",
+ help='path to trainer file to continue training')
+ parser.add_argument('--data_path',
+ type=str, default="",
+ help='path to pubchem file')
+ parser.add_argument('--pretext_size',
+ type=int, default=0,
+ help='number of k-mers to pretext')
+ parser.add_argument('--model_save_dir',
+ type=str, required=False, default='./models_dump/',
+ help='Where to save the models/log/config/vocab')
+ parser.add_argument('--model_save',
+ type=str, required=False, default='model.pt',
+ help='Where to save the model')
+ #parser.add_argument('--save_frequency',
+ # type=int, default=20,
+ # help='How often to save the model')
+ parser.add_argument('--num_epoch',
+ type=int, default=1,
+ help='number of epochs to train')
+ #parser.add_argument('--num_iter',
+ # type=int, default=-1,
+ # help='how many itersations per epoch (for unlikelihood tuning)')
+ parser.add_argument('--log_file',
+ type=str, required=False,
+ help='Where to save the log')
+ parser.add_argument('--tb_loc',
+ type=str, required=False,
+ help='Where to save the tensorflow location')
+ parser.add_argument('--config_save',
+ type=str, required=False,
+ help='Where to save the config')
+ parser.add_argument('--vocab_save',
+ type=str,
+ help='Where to save the vocab')
+
+ # resume_arg = parser.add_argument_group('Resume')
+ parser.add_argument('--debug',
+ default=False, action='store_true',
+ help='do not erase cache at end of program')
+ parser.add_argument('--fast_dev_run',
+ default=False,
+ help='This flag runs a “unit test” by running n if set to n (int) else 1 if set to True training and validation batch(es).')
+ parser.add_argument('--freeze_model',
+ default=False, action='store_true',
+ help='freeze weights of bert model during fine tuning')
+ parser.add_argument('--resume',
+ default=False, action='store_true',
+ help='Resume from a saved model')
+ parser.add_argument('--rotate',
+ default=False, action='store_true',
+ help='use rotational relative embedding')
+ parser.add_argument('--model_load',
+ type=str, required=False,
+ help='Where to load the model')
+ parser.add_argument('--root_dir',
+ type=str, required=False, default='.',
+ help='location of root dir')
+ parser.add_argument('--config_load',
+ type=str, required=False,
+ help='Where to load the config')
+ parser.add_argument('--gpus',
+ type=int, required=False, default=1,
+ help='number of gpus to use')
+ #parser.add_argument('--start_epoch',
+ # type=int, required=False, default=0,
+ # help='Where to load the config')
+
+ parser.add_argument('--model_arch',
+ type=str, required=False,
+ help='used to teack model arch in params')
+ parser.add_argument('--eval_every',
+ type=int, default=50000,
+ help='run evaluation every x iterations')
+ parser.add_argument('--num_feats',
+ type=int, required=False, default=32,
+ help='number of random reatures for FAVOR+')
+ parser.add_argument('--max_epochs',
+ type=int, required=False, default=1,
+ help='max number of epochs')
+
+ # debug() FINE TUNEING
+ # parser.add_argument('--save_dir', type=str, required=True)
+ parser.add_argument('--mode',
+ type=str, default='cls',
+ help='type of pooling to use')
+ parser.add_argument("--dataset_length", type=int, default=None, required=False)
+ parser.add_argument("--num_workers", type=int, default=0, required=False)
+ parser.add_argument("--dropout", type=float, default=0.1, required=False)
+ #parser.add_argument("--dims", type=int, nargs="*", default="", required=False)
+ parser.add_argument(
+ "--smiles_embedding",
+ type=str,
+ default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/embeddings/protein/ba_embeddings_tanh_512_2986138_2.pt",
+ )
+ # parser.add_argument("--train_pct", type=str, required=False, default="95")
+ #parser.add_argument("--aug", type=int, required=True)
+ parser.add_argument("--dataset_name", type=str, required=False, default="sol")
+ parser.add_argument("--measure_name", type=str, required=False, default="measure")
+ parser.add_argument("--smi_ted_version", type=str, required=True, default="v1")
+ #parser.add_argument("--emb_type", type=str, required=True)
+ #parser.add_argument("--checkpoints_folder", type=str, required=True)
+ #parser.add_argument("--results_dir", type=str, required=True)
+ #parser.add_argument("--patience_epochs", type=int, required=True)
+
+ parser.add_argument(
+ "--data_root",
+ type=str,
+ required=False,
+ default="/dccstor/medscan7/smallmolecule/runs/ba-predictor/small-data/affinity",
+ )
+ # parser.add_argument("--use_bn", type=int, default=0)
+ parser.add_argument("--use_linear", type=int, default=0)
+
+ parser.add_argument("--lr", type=float, default=0.001)
+ # parser.add_argument("--weight_decay", type=float, default=5e-4)
+ # parser.add_argument("--val_check_interval", type=float, default=1.0)
+ parser.add_argument("--batch_size", type=int, default=64)
+
+ return parser
+def parse_args():
+ parser = get_parser()
+ args = parser.parse_args()
+ return args
+
diff --git a/models/smi_ted/training/bert_vocab_curated.txt b/models/smi_ted/training/bert_vocab_curated.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4de43b6c87add8821aa8d2d33d58232929c86fbd
--- /dev/null
+++ b/models/smi_ted/training/bert_vocab_curated.txt
@@ -0,0 +1,2393 @@
+
+
+
+
+C
+c
+(
+)
+1
+O
+N
+2
+=
+n
+3
+[C@H]
+[C@@H]
+F
+S
+4
+Cl
+-
+o
+s
+[nH]
+#
+/
+Br
+[C@]
+[C@@]
+[N+]
+[O-]
+5
+\
+.
+I
+6
+[S@]
+[S@@]
+P
+[N-]
+[Si]
+7
+[n+]
+[2H]
+8
+[NH+]
+B
+9
+[C-]
+[Na+]
+[Cl-]
+[c-]
+[CH]
+%10
+[NH2+]
+[P+]
+[B]
+[I-]
+%11
+[CH2-]
+[O+]
+[NH3+]
+[C]
+[Br-]
+[IH2]
+[S-]
+[cH-]
+%12
+[nH+]
+[B-]
+[K+]
+[Sn]
+[Se]
+[CH-]
+[HH]
+[Y]
+[n-]
+[CH3-]
+[SiH]
+[S+]
+%13
+[SiH2]
+[Li+]
+[NH-]
+%14
+[Na]
+[CH2]
+[O-2]
+[U+2]
+[W]
+[Al]
+[P@]
+[Fe+2]
+[PH+]
+%15
+[Cl+3]
+[Zn+2]
+[Ir]
+[Mg+2]
+[Pt+2]
+[OH2+]
+[As]
+[Fe]
+[OH+]
+[Zr+2]
+[3H]
+[Ge]
+[SiH3]
+[OH-]
+[NH4+]
+[Cu+2]
+[P@@]
+p
+[Pt]
+%16
+[Ca+2]
+[Zr]
+[F-]
+[C+]
+[Ti]
+[P-]
+[V]
+[se]
+[U]
+[O]
+[Ni+2]
+[Zn]
+[Co]
+[Ni]
+[Pd+2]
+[Cu]
+%17
+[Cu+]
+[Te]
+[H+]
+[CH+]
+[Li]
+[Pd]
+[Mo]
+[Ru+2]
+[o+]
+[Re]
+[SH+]
+%18
+[Ac]
+[Cr]
+[NH2-]
+[K]
+[13CH2]
+[c]
+[Zr+4]
+[Tl]
+[13C]
+[Mn]
+[N@+]
+[Hg]
+[Rh]
+[Ti+4]
+[Sb]
+[Co+2]
+[Ag+]
+[Ru]
+%19
+[N@@+]
+[Ti+2]
+[Al+3]
+[Pb]
+[I+]
+[18F]
+[s+]
+[Rb+]
+[Ba+2]
+[H-]
+[Fe+3]
+[Ir+3]
+[13cH]
+%20
+[AlH2]
+[Au+]
+[13c]
+[SH2+]
+[Sn+2]
+[Mn+2]
+[Si-]
+[Ag]
+[N]
+[Bi]
+%21
+[In]
+[CH2+]
+[Y+3]
+[Ga]
+%22
+[Co+3]
+[Au]
+[13CH3]
+[Mg]
+[Cs+]
+[W+2]
+[Hf]
+[Zn+]
+[Se-]
+[S-2]
+[Ca]
+[pH]
+[ClH+]
+[Ti+3]
+%23
+[Ru+]
+[SH-]
+[13CH]
+[IH+]
+[Hf+4]
+[Rf]
+[OH3+]
+%24
+[Pt+4]
+[Zr+3]
+[PH3+]
+[Sr+2]
+[Cd+2]
+[Cd]
+%25
+[Os]
+[BH-]
+[Sn+4]
+[Cr+3]
+[Ru+3]
+[PH2+]
+[Rh+2]
+[V+2]
+%26
+[Gd+3]
+[Pb+2]
+[PH]
+[Hg+]
+[Mo+2]
+[AlH]
+[Sn+]
+%27
+[Pd+]
+b
+[Rh+3]
+[Hg+2]
+[15NH]
+[14C]
+%28
+[Mn+3]
+[Si+]
+[SeH]
+[13C@H]
+[NH]
+[Ga+3]
+[SiH-]
+[13C@@H]
+[Ce]
+[Au+3]
+[Bi+3]
+[15N]
+%29
+[BH3-]
+[14cH]
+[Ti+]
+[Gd]
+[cH+]
+[Cr+2]
+[Sb-]
+%30
+[Be+2]
+[Al+]
+[te]
+[11CH3]
+[Sm]
+[Pr]
+[La]
+%31
+[Al-]
+[Ta]
+[125I]
+[BH2-]
+[Nb]
+[Si@]
+%32
+[14c]
+[Sb+3]
+[Ba]
+%33
+[Os+2]
+[Si@@]
+[La+3]
+[15n]
+[15NH2]
+[Nd+3]
+%34
+[14CH2]
+[18O]
+[Nd]
+[GeH]
+[Ni+3]
+[Eu]
+[Dy+3]
+[Sc]
+%36
+[Se-2]
+[As+]
+%35
+[AsH]
+[Tb]
+[Sb+5]
+[Se+]
+[Ce+3]
+[c+]
+[In+3]
+[SnH]
+[Mo+4]
+%37
+[V+4]
+[Eu+3]
+[Hf+2]
+%38
+[Pt+]
+[p+]
+[123I]
+[Tl+]
+[Sm+3]
+%39
+[Yb+3]
+%40
+[Yb]
+[Os+]
+%41
+[10B]
+[Sc+3]
+[Al+2]
+%42
+[Sr]
+[Tb+3]
+[Po]
+[Tc]
+[PH-]
+[AlH3]
+[Ar]
+[U+4]
+[SnH2]
+[Cl+2]
+[si]
+[Fe+]
+[14CH3]
+[U+3]
+[Cl+]
+%43
+[GeH2]
+%44
+[Er+3]
+[Mo+3]
+[I+2]
+[Fe+4]
+[99Tc]
+%45
+[11C]
+%46
+[SnH3]
+[S]
+[Te+]
+[Er]
+[Lu+3]
+[11B]
+%47
+%48
+[P]
+[Tm]
+[Th]
+[Dy]
+[Pr+3]
+[Ta+5]
+[Nb+5]
+[Rb]
+[GeH3]
+[Br+2]
+%49
+[131I]
+[Fm]
+[Cs]
+[BH4-]
+[Lu]
+[15nH]
+%50
+[Ru+6]
+[b-]
+[Ho]
+[Th+4]
+[Ru+4]
+%52
+[14CH]
+%51
+[Cr+6]
+[18OH]
+[Ho+3]
+[Ce+4]
+[Bi+2]
+[Co+]
+%53
+[Yb+2]
+[Fe+6]
+[Be]
+%54
+[SH3+]
+[Np]
+[As-]
+%55
+[14C@@H]
+[Ir+2]
+[GaH3]
+[p-]
+[GeH4]
+[Sn+3]
+[Os+4]
+%56
+[14C@H]
+[sH+]
+[19F]
+[Eu+2]
+[TlH]
+%57
+[Cr+4]
+%58
+[B@@-]
+[SiH+]
+[At]
+[Am]
+[Fe+5]
+[AsH2]
+[Si+4]
+[B@-]
+[Pu]
+[SbH]
+[P-2]
+[Tm+3]
+*
+%59
+[se+]
+[IH-]
+%60
+[oH+]
+[1H]
+[15N+]
+[124I]
+[S@@+]
+[P-3]
+[H]
+[IH2+]
+[TeH]
+[Xe]
+[PH4+]
+[Cr+]
+[Cm]
+[I+3]
+%61
+[Nb+2]
+[Ru+5]
+%62
+[Ta+2]
+[Tc+4]
+[CH3+]
+[Pm]
+[Si@H]
+[No]
+%63
+[Cr+5]
+[Th+2]
+[Zn-2]
+[13C@]
+[Lr]
+%64
+[99Tc+3]
+%65
+[13C@@]
+%66
+[Fe-]
+[17O]
+[siH]
+[Sb+]
+[OH]
+[IH]
+[11CH2]
+[Cf]
+[SiH2+]
+[Gd+2]
+[In+]
+[Si@@H]
+[Mn+]
+[99Tc+4]
+[Ga-]
+%67
+[S@+]
+[Ge+4]
+[Tl+3]
+[16OH]
+%68
+[2H-]
+[Ra]
+[si-]
+[NiH2]
+[P@@H]
+[Rh+]
+[12C]
+[35S]
+[32P]
+[SiH2-]
+[AlH2+]
+[16O]
+%69
+[BiH]
+[BiH2]
+[Zn-]
+[BH]
+[Tc+3]
+[Ir+]
+[Ni+]
+%70
+[InH2]
+[InH]
+[Nb+3]
+[PbH]
+[Bi+]
+%71
+[As+3]
+%72
+[18O-]
+[68Ga+3]
+%73
+[Pa]
+[76Br]
+[Tc+5]
+[pH+]
+[64Cu+2]
+[Ru+8]
+%74
+[PH2-]
+[Si+2]
+[17OH]
+[RuH]
+[111In+3]
+[AlH+]
+%75
+%76
+[W+]
+[SbH2]
+[PoH]
+[Ru-]
+[XeH]
+[Tc+2]
+[13C-]
+[Br+]
+[Pt-2]
+[Es]
+[Cu-]
+[Mg+]
+[3HH]
+[P@H]
+[ClH2+]
+%77
+[SH]
+[Au-]
+[2HH]
+%78
+[Sn-]
+[11CH]
+[PdH2]
+0
+[Os+6]
+%79
+[Mo+]
+%80
+[al]
+[PbH2]
+[64Cu]
+[Cl]
+[12CH3]
+%81
+[Tc+7]
+[11c]
+%82
+[Li-]
+[99Tc+5]
+[He]
+[12c]
+[Kr]
+[RuH+2]
+[35Cl]
+[Pd-2]
+[GaH2]
+[4H]
+[Sg]
+[Cu-2]
+[Br+3]
+%83
+[37Cl]
+[211At]
+[IrH+2]
+[Mt]
+[Ir-2]
+[In-]
+[12cH]
+[12CH2]
+[RuH2]
+[99Tc+7]
+%84
+[15n+]
+[ClH2+2]
+[16N]
+[111In]
+[Tc+]
+[Ru-2]
+[12CH]
+[si+]
+[Tc+6]
+%85
+%86
+[90Y]
+[Pd-]
+[188Re]
+[RuH+]
+[NiH]
+[SiH3-]
+[14n]
+[CH3]
+[14N]
+[10BH2]
+%88
+%89
+%90
+[34S]
+[77Br]
+[GaH]
+[Br]
+[Ge@]
+[B@@H-]
+[CuH]
+[SiH4]
+[3H-]
+%87
+%91
+%92
+[67Cu]
+[I]
+[177Lu]
+[ReH]
+[67Ga+3]
+[Db]
+[177Lu+3]
+[AlH2-]
+[Si+3]
+[Ti-2]
+[RuH+3]
+[al+]
+[68Ga]
+[2H+]
+[B@H-]
+[WH2]
+[OsH]
+[Ir-3]
+[AlH-]
+[Bk]
+[75Se]
+[14C@]
+[Pt-]
+[N@@H+]
+[Nb-]
+[13NH2]
+%93
+[186Re]
+[Tb+4]
+[PtH]
+[IrH2]
+[Hg-2]
+[AlH3-]
+[PdH+]
+[Md]
+[RhH+2]
+[11cH]
+[Co-2]
+[15N-]
+[ZrH2]
+%94
+[Hg-]
+[127I]
+[AsH2+]
+[MoH2]
+[Te+4]
+[14C@@]
+[As+5]
+[SnH+3]
+[Ge@@]
+[6Li+]
+[WH]
+[Ne]
+[14NH2]
+[14NH]
+[12C@@H]
+[Os+7]
+[RhH]
+[Al-3]
+[SnH+]
+[15NH3+]
+[Zr+]
+[197Hg+]
+%95
+%96
+[90Y+3]
+[Os-2]
+[98Tc+5]
+[15NH3]
+[bH-]
+[33P]
+[Zr-2]
+[15O]
+[Rh-]
+[PbH3]
+[PH2]
+[Ni-]
+[CuH+]
+%97
+%98
+%99
+[Os+5]
+[PtH+]
+[ReH4]
+[16NH]
+[82Br]
+[W-]
+[18F-]
+[15NH4+]
+[Se+4]
+[SeH-]
+[SH4]
+[67Cu+2]
+[12C@H]
+[AsH3]
+[HgH]
+[10B-]
+[99Tc+6]
+[117Sn+4]
+[Te@]
+[P@+]
+[35SH]
+[SeH+]
+[Ni-2]
+[Al-2]
+[TeH2]
+[Bh]
+[99Tc+2]
+[Os+8]
+[PH-2]
+[7Li+]
+[14nH]
+[AlH+2]
+[18FH]
+[SnH4]
+[18O-2]
+[IrH]
+[13N]
+[Te@@]
+[Rh-3]
+[15NH+]
+[AsH3+]
+[SeH2]
+[AsH+]
+[CoH2]
+[16NH2]
+[AsH-]
+[203Hg+]
+[P@@+]
+[166Ho+3]
+[60Co+3]
+[13CH2-]
+[SeH2+]
+[75Br]
+[TlH2]
+[80Br]
+[siH+]
+[Ca+]
+[153Sm+3]
+[PdH]
+[225Ac]
+[13CH3-]
+[AlH4-]
+[FeH]
+[13CH-]
+[14C-]
+[11C-]
+[153Sm]
+[Re-]
+[te+]
+[13CH4]
+[ClH+2]
+[8CH2]
+[99Mo]
+[ClH3+3]
+[SbH3]
+[25Mg+2]
+[16N+]
+[SnH2+]
+[PH4]
+[11C@H]
+[122I]
+[Re-2]
+[RuH2+2]
+[ZrH]
+[Bi-]
+[Pr+]
+[Rn]
+[Fr]
+[36Cl]
+[18o]
+[YH]
+[79Br]
+[121I]
+[113In+3]
+[InH4-]
+[TaH]
+[RhH2]
+[Ta-]
+[67Ga]
+[ZnH+]
+[SnH2-]
+[OsH2]
+[16F]
+[FeH2]
+[14O]
+[PbH2+2]
+[BH2]
+[6H]
+[125Te]
+[197Hg]
+[TaH2]
+[TaH3]
+[76As]
+[Nb-2]
+[14N+]
+[125I-]
+[33S]
+[IH2+2]
+[NH2]
+[PtH2]
+[MnH]
+[19C]
+[17F]
+[1H-]
+[SnH4+2]
+[Mn-2]
+[15NH2+]
+[TiH2]
+[ReH7]
+[Cd-2]
+[Fe-3]
+[SH2]
+[17O-]
+[siH-]
+[CoH+]
+[VH]
+[10BH]
+[Ru-3]
+[13O]
+[5H]
+[CoH]
+[PH5]
+[15n-]
+[153Gd]
+[12C@]
+[11CH3-]
+[IrH3]
+[RuH3]
+[74Se]
+[Se@]
+[Hf+]
+[77Se]
+[166Ho]
+[59Fe+2]
+[203Hg]
+[18OH-]
+[8CH]
+[12C@@]
+[11CH4]
+[15C]
+[249Cf]
+[PbH4]
+[64Zn]
+[PH3]
+[99Tc+]
+[14c-]
+[149Pm]
+[IrH4]
+[Se@@]
+[13OH]
+[14CH3-]
+[28Si]
+[Rh-2]
+[Fe-2]
+[131I-]
+[51Cr]
+[62Cu+2]
+[81Br]
+[121Sb]
+[7Li]
+[89Zr+4]
+[SbH3+]
+[11C@@H]
+[98Tc]
+[59Fe+3]
+[BiH2+]
+[SbH+]
+[TiH]
+[14NH3]
+[15OH]
+[119Sn]
+[201Hg]
+[MnH+]
+[201Tl]
+[51Cr+3]
+[123I-]
+[MoH]
+[AlH6-3]
+[MnH2]
+[WH3]
+[213Bi+3]
+[SnH2+2]
+[123IH]
+[13CH+]
+[Zr-]
+[74As]
+[13C+]
+[32P+]
+[KrH]
+[SiH+2]
+[ClH3+2]
+[13NH]
+[9CH2]
+[ZrH2+2]
+[87Sr+2]
+[35s]
+[239Pu]
+[198Au]
+[241Am]
+[203Hg+2]
+[V+]
+[YH2]
+[SH5]
+[195Pt]
+[203Pb]
+[RuH4]
+[ThH2]
+[AuH]
+[66Ga+3]
+[11B-]
+[F]
+[24Na+]
+[85Sr+2]
+[201Tl+]
+[14CH4]
+[32S]
+[TeH2+]
+[ClH2+3]
+[AgH]
+[Ge@H]
+[44Ca+2]
+[Os-]
+[31P]
+[15nH+]
+[SbH4]
+[TiH+]
+[Ba+]
+[57Co+2]
+[Ta+]
+[125IH]
+[77As]
+[129I]
+[Fe-4]
+[Ta-2]
+[19O]
+[12O]
+[BiH3]
+[237Np]
+[252Cf]
+[86Y]
+[Cr-2]
+[89Y]
+[195Pt+2]
+[si+2]
+[58Fe+2]
+[Hs]
+[S@@H]
+[OsH6]
+[GdH2]
+[IH3]
+[8CH4]
+[164Dy+3]
+[47Ca+2]
+[57Co]
+[NbH2]
+[ReH2]
+[ZnH2]
+[CrH2]
+[17NH]
+[ZrH3]
+[RhH3]
+[12C-]
+[18O+]
+[Bi-2]
+[ClH4+3]
+[Ni-3]
+[Ag-]
+[111In-]
+[Mo-2]
+[55Fe+3]
+[204Hg+]
+[35Cl-]
+[211Pb]
+[75Ge]
+[8B]
+[TeH3]
+[SnH3+]
+[Zr-3]
+[28F]
+[249Bk]
+[169Yb]
+[34SH]
+[6Li]
+[94Tc]
+[197Au]
+[195Pt+4]
+[169Yb+3]
+[32Cl]
+[82Se]
+[159Gd+3]
+[213Bi]
+[CoH+2]
+[36S]
+[35P]
+[Ru-4]
+[Cr-3]
+[60Co]
+[1H+]
+[18CH2]
+[Cd-]
+[152Sm+3]
+[106Ru]
+[238Pu]
+[220Rn]
+[45Ca+2]
+[89Sr+2]
+[239Np]
+[90Sr+2]
+[137Cs+]
+[165Dy]
+[68GaH3]
+[65Zn+2]
+[89Zr]
+[BiH2+2]
+[62Cu]
+[165Dy+3]
+[238U]
+[105Rh+3]
+[70Zn]
+[12B]
+[12OH]
+[18CH]
+[17CH]
+[OsH3]
+[SbH-]
+[SH6]
+[AlH2-2]
+[42K]
+[76Br-]
+[71As]
+[NbH3]
+[ReH3]
+[OsH-]
+[WH4]
+[MoH3]
+[OsH4]
+[RuH6]
+[PtH3]
+[CuH2]
+[CoH3]
+[TiH4]
+[64Zn+2]
+[Si-2]
+[79BrH]
+[14CH2-]
+[PtH2+2]
+[Os-3]
+[29Si]
+[Ti-]
+[Se+6]
+[22Na+]
+[42K+]
+[131Cs+]
+[86Rb+]
+[134Cs+]
+[209Po]
+[208Po]
+[81Rb+]
+[203Tl+]
+[Zr-4]
+[148Sm]
+[147Sm]
+[37Cl-]
+[12CH4]
+[Ge@@H]
+[63Cu]
+[13CH2+]
+[AsH2-]
+[CeH]
+[SnH-]
+[UH]
+[9c]
+[21CH3]
+[TeH+]
+[57Co+3]
+[8BH2]
+[12BH2]
+[19BH2]
+[9BH2]
+[YbH2]
+[CrH+2]
+[208Bi]
+[152Gd]
+[61Cu]
+[115In]
+[60Co+2]
+[13NH2-]
+[120I]
+[18OH2]
+[75SeH]
+[SbH2+]
+[144Ce]
+[16n]
+[113In]
+[22nH]
+[129I-]
+[InH3]
+[32PH3]
+[234U]
+[235U]
+[59Fe]
+[82Rb+]
+[65Zn]
+[244Cm]
+[147Pm]
+[91Y]
+[237Pu]
+[231Pa]
+[253Cf]
+[127Te]
+[187Re]
+[236Np]
+[235Np]
+[72Zn]
+[253Es]
+[159Dy]
+[62Zn]
+[101Tc]
+[149Tb]
+[124I-]
+[SeH3+]
+[210Pb]
+[40K]
+[210Po]
+[214Pb]
+[218Po]
+[214Po]
+[7Be]
+[212Pb]
+[205Pb]
+[209Pb]
+[123Te]
+[202Pb]
+[72As]
+[201Pb]
+[70As]
+[73Ge]
+[200Pb]
+[198Pb]
+[66Ga]
+[73Se]
+[195Pb]
+[199Pb]
+[144Ce+3]
+[235U+2]
+[90Tc]
+[114In+3]
+[128I]
+[100Tc+]
+[82Br-]
+[191Pt+2]
+[191Pt+4]
+[193Pt+4]
+[31PH3]
+[125I+2]
+[131I+2]
+[125Te+4]
+[82Sr+2]
+[149Sm]
+[81BrH]
+[129Xe]
+[193Pt+2]
+[123I+2]
+[Cr-]
+[Co-]
+[227Th+4]
+[249Cf+3]
+[252Cf+3]
+[187Os]
+[16O-]
+[17O+]
+[16OH-]
+[98Tc+7]
+[58Co+2]
+[69Ga+3]
+[57Fe+2]
+[43K+]
+[16C]
+[52Fe+3]
+[SeH5]
+[194Pb]
+[196Pb]
+[197Pb]
+[213Pb]
+[9B]
+[19B]
+[11CH-]
+[9CH]
+[20OH]
+[25OH]
+[8cH]
+[TiH+3]
+[SnH6+3]
+[N@H+]
+[ZnH]
+[VH3]
+[52Mn+2]
+[64Ga]
+[13B]
+[216Bi]
+[117Sn+2]
+[232Th]
+[SnH+2]
+[BiH5]
+[77Kr]
+[103Cd]
+[62Ni]
+[LaH3]
+[SmH3]
+[EuH3]
+[MoH5]
+[64Ni]
+[66Zn]
+[68Zn]
+[186W]
+[FeH4]
+[MoH4]
+[HgH2]
+[15NH2-]
+[UH2]
+[204Hg]
+[GaH4-]
+[ThH4]
+[WH6]
+[PtH4]
+[VH2]
+[UH3]
+[FeH3]
+[RuH5]
+[BiH4]
+[80Br-]
+[CeH3]
+[37ClH]
+[157Gd+3]
+[205Tl]
+[203Tl]
+[62Cu+]
+[64Cu+]
+[61Cu+]
+[37SH2]
+[30Si]
+[28Al]
+[19OH2]
+[8He]
+[6He]
+[153Pm]
+[209Bi]
+[66Zn+2]
+[10CH4]
+[191Ir]
+[66Cu]
+[16O+]
+[25O]
+[10c]
+[Co-3]
+[Sn@@]
+[17OH-]
+[206Po]
+[204Po]
+[202Po]
+[201Po]
+[200Po]
+[199Po]
+[198Po]
+[197Po]
+[196Po]
+[195Po]
+[194Po]
+[193Po]
+[192Po]
+[191Po]
+[190Po]
+[217Po]
+[BiH4-]
+[TeH4]
+[222Ra]
+[62Ga]
+[39Ar]
+[144Sm]
+[58Fe]
+[153Eu]
+[85Rb]
+[171Yb]
+[172Yb]
+[114Cd]
+[51Fe]
+[142Ce]
+[207Tl]
+[92Mo]
+[115Sn]
+[140Ce]
+[202Hg]
+[180W]
+[182W]
+[183W]
+[184W]
+[96Mo]
+[47Ti]
+[111Cd]
+[143Nd]
+[145Nd]
+[126Te]
+[128Te]
+[130Te]
+[185Re]
+[97Mo]
+[98Mo]
+[183Re]
+[52V]
+[80Se]
+[87Kr]
+[137Xe]
+[196Au]
+[146Ce]
+[88Kr]
+[51Ti]
+[138Xe]
+[112Cd]
+[116Sn]
+[120Sn]
+[28SiH3]
+[35S-]
+[15NH-]
+[13CH3+]
+[34S+]
+[34s]
+[SiH4-]
+[100Tc+5]
+[NiH2+2]
+[239Th]
+[186Lu]
+[AuH3]
+[I@@-]
+[XeH2]
+[B+]
+[16CH2]
+[8C]
+[TaH5]
+[FeH4-]
+[19C@H]
+[10NH]
+[FeH6-3]
+[22CH]
+[25N]
+[25N+]
+[25N-]
+[21CH2]
+[18cH]
+[113I]
+[ScH3]
+[30PH3]
+[43Ca+2]
+[41Ca+2]
+[106Cd]
+[122Sn]
+[18CH3]
+[58Co+3]
+[98Tc+4]
+[70Ge]
+[76Ge]
+[108Cd]
+[116Cd]
+[130Xe]
+[94Mo]
+[124Sn]
+[186Os]
+[188Os]
+[190Os]
+[192Os]
+[106Pd]
+[110Pd]
+[120Te]
+[132Ba]
+[134Ba]
+[136Ba]
+[136Ce]
+[138Ce]
+[156Dy]
+[158Dy]
+[160Dy]
+[163Dy]
+[162Er]
+[164Er]
+[167Er]
+[176Hf]
+[26Mg]
+[144Nd]
+[150Nd]
+[41K]
+[46Ti]
+[48Ti]
+[49Ti]
+[50Ti]
+[170Yb]
+[173Yb]
+[91Zr]
+[92Zr]
+[96Zr]
+[34S-]
+[CuH2-]
+[38Cl]
+[25Mg]
+[51V]
+[93Nb]
+[95Mo]
+[45Sc]
+[123Sb]
+[139La]
+[9Be]
+[99Y+3]
+[99Y]
+[156Ho]
+[67Zn]
+[144Ce+4]
+[210Tl]
+[42Ca]
+[54Fe]
+[193Ir]
+[92Nb]
+[141Cs]
+[52Cr]
+[35ClH]
+[46Ca]
+[139Cs]
+[65Cu]
+[71Ga]
+[60Ni]
+[16NH3]
+[148Nd]
+[72Ge]
+[161Dy]
+[49Ca]
+[43Ca]
+[8Be]
+[48Ca]
+[44Ca]
+[120Xe]
+[80Rb]
+[215At]
+[180Re]
+[146Sm]
+[19Ne]
+[74Kr]
+[134La]
+[76Kr]
+[219Fr]
+[121Xe]
+[220Fr]
+[216At]
+[223Ac]
+[218At]
+[37Ar]
+[135I]
+[110Cd]
+[94Tc+7]
+[86Y+3]
+[135I-]
+[15O-2]
+[151Eu+3]
+[161Tb+3]
+[197Hg+2]
+[109Cd+2]
+[191Os+4]
+[170Tm+3]
+[205Bi+3]
+[233U+4]
+[126Sb+3]
+[127Sb+3]
+[132Cs+]
+[136Eu+3]
+[136Eu]
+[125Sn+4]
+[175Yb+3]
+[100Mo]
+[22Ne]
+[13c-]
+[13NH4+]
+[17C]
+[9C]
+[31S]
+[31SH]
+[133I]
+[126I]
+[36SH]
+[30S]
+[32SH]
+[19CH2]
+[19c]
+[18c]
+[15F]
+[10C]
+[RuH-]
+[62Zn+2]
+[32ClH]
+[33ClH]
+[78BrH]
+[12Li+]
+[12Li]
+[233Ra]
+[68Ge+4]
+[44Sc+3]
+[91Y+3]
+[106Ru+3]
+[PoH2]
+[AtH]
+[55Fe]
+[233U]
+[210PoH2]
+[230Th]
+[228Th]
+[222Rn]
+[35SH2]
+[227Th]
+[192Ir]
+[133Xe]
+[81Kr]
+[95Zr]
+[240Pu]
+[54Mn]
+[103Ru]
+[95Nb]
+[109Cd]
+[141Ce]
+[85Kr]
+[110Ag]
+[58Co]
+[241Pu]
+[234Th]
+[140La]
+[63Ni]
+[152Eu]
+[132IH]
+[226Rn]
+[154Eu]
+[36ClH]
+[228Ac]
+[155Eu]
+[106Rh]
+[243Am]
+[227Ac]
+[243Cm]
+[236U]
+[144Pr]
+[232U]
+[32SH2]
+[88Y]
+[82BrH]
+[135IH]
+[242Cm]
+[115Cd]
+[242Pu]
+[46Sc]
+[56Mn]
+[234Pa]
+[41Ar]
+[147Nd]
+[187W]
+[151Sm]
+[59Ni]
+[233Pa]
+[52Mn]
+[94Nb]
+[219Rn]
+[236Pu]
+[13NH3]
+[93Zr]
+[51Cr+6]
+[TlH3]
+[123Xe]
+[160Tb]
+[170Tm]
+[182Ta]
+[175Yb]
+[93Mo]
+[143Ce]
+[191Os]
+[126IH]
+[48V]
+[113Cd]
+[47Sc]
+[181Hf]
+[185W]
+[143Pr]
+[191Pt]
+[181W]
+[33PH3]
+[97Ru]
+[97Tc]
+[111Ag]
+[169Er]
+[107Pd]
+[103Ru+2]
+[34SH2]
+[137Ce]
+[242Am]
+[117SnH2]
+[57Ni]
+[239U]
+[60Cu]
+[250Cf]
+[193Au]
+[69Zn]
+[55Co]
+[139Ce]
+[127Xe]
+[159Gd]
+[56Co]
+[177Hf]
+[244Pu]
+[38ClH]
+[142Pr]
+[199Hg]
+[179Hf]
+[178Hf]
+[237U]
+[156Eu]
+[157Eu]
+[105Ru]
+[171Tm]
+[199Au]
+[155Sm]
+[80BrH]
+[108Ag]
+[128IH]
+[48Sc]
+[45Ti]
+[176Lu]
+[121SnH2]
+[148Pm]
+[57Fe]
+[10BH3]
+[96Tc]
+[133IH]
+[143Pm]
+[105Rh]
+[130IH]
+[134IH]
+[131IH]
+[71Zn]
+[105Ag]
+[97Zr]
+[235Pu]
+[231Th]
+[109Pd]
+[93Y]
+[190Ir]
+[135Xe]
+[53Mn]
+[134Ce]
+[234Np]
+[240Am]
+[246Cf]
+[240Cm]
+[241Cm]
+[226Th]
+[39ClH]
+[229Th]
+[245Cm]
+[240U]
+[240Np]
+[249Cm]
+[243Pu]
+[145Pm]
+[199Pt]
+[246Bk]
+[193Pt]
+[230U]
+[250Cm]
+[44Ti]
+[175Hf]
+[254Fm]
+[255Fm]
+[257Fm]
+[92Y]
+[188Ir]
+[171Lu]
+[257Md]
+[247Bk]
+[121IH]
+[250Bk]
+[179Lu]
+[224Ac]
+[195Hg]
+[244Am]
+[246Pu]
+[194Au]
+[252Fm]
+[173Hf]
+[246Cm]
+[135Ce]
+[49Cr]
+[248Cf]
+[247Cm]
+[248Cm]
+[174Ta]
+[176Ta]
+[154Tb]
+[172Ta]
+[177Ta]
+[175Ta]
+[180Ta]
+[158Tb]
+[115Ag]
+[189Os]
+[251Cf]
+[145Pr]
+[147Pr]
+[76BrH]
+[102Rh]
+[238Np]
+[185Os]
+[246Am]
+[233Np]
+[166Dy]
+[254Es]
+[244Cf]
+[193Os]
+[245Am]
+[245Bk]
+[239Am]
+[238Am]
+[97Nb]
+[245Pu]
+[254Cf]
+[188W]
+[250Es]
+[251Es]
+[237Am]
+[182Hf]
+[258Md]
+[232Np]
+[238Cm]
+[60Fe]
+[109Pd+2]
+[234Pu]
+[141Ce+3]
+[136Nd]
+[136Pr]
+[173Ta]
+[110Ru]
+[147Tb]
+[253Fm]
+[139Nd]
+[178Re]
+[177Re]
+[200Au]
+[182Re]
+[156Tb]
+[155Tb]
+[157Tb]
+[161Tb]
+[161Ho]
+[167Tm]
+[173Lu]
+[179Ta]
+[171Er]
+[44Sc]
+[49Sc]
+[49V]
+[51Mn]
+[90Nb]
+[88Nb]
+[88Zr]
+[36SH2]
+[174Yb]
+[178Lu]
+[179W]
+[83BrH]
+[107Cd]
+[75BrH]
+[62Co]
+[48Cr]
+[63Zn]
+[102Ag]
+[154Sm]
+[168Er]
+[65Ni]
+[137La]
+[187Ir]
+[144Pm]
+[146Pm]
+[160Gd]
+[166Yb]
+[162Dy]
+[47V]
+[141Nd]
+[141Sm]
+[166Er]
+[150Sm]
+[146Eu]
+[149Eu]
+[174Lu]
+[17NH3]
+[102Ru]
+[170Hf]
+[188Pt]
+[61Ni]
+[56Ni]
+[149Gd]
+[151Gd]
+[141Pm]
+[147Gd]
+[146Gd]
+[161Er]
+[103Ag]
+[145Eu]
+[153Tb]
+[155Dy]
+[184Re]
+[180Os]
+[182Os]
+[186Pt]
+[181Os]
+[181Re]
+[151Tb]
+[178Ta]
+[178W]
+[189Pt]
+[194Hg]
+[145Sm]
+[150Tb]
+[132La]
+[158Gd]
+[104Ag]
+[193Hg]
+[94Ru]
+[137Pr]
+[155Ho]
+[117Cd]
+[99Ru]
+[146Nd]
+[218Rn]
+[95Y]
+[79Kr]
+[120IH]
+[138Pr]
+[100Pd]
+[166Tm]
+[90Mo]
+[151Nd]
+[231U]
+[138Nd]
+[89Nb]
+[98Nb]
+[162Ho]
+[142Sm]
+[186Ta]
+[104Tc]
+[184Ta]
+[185Ta]
+[170Er]
+[107Rh]
+[131La]
+[169Lu]
+[74BrH]
+[150Pm]
+[172Tm]
+[197Pt]
+[230Pu]
+[170Lu]
+[86Zr]
+[176W]
+[177W]
+[101Pd]
+[105Pd]
+[108Pd]
+[149Nd]
+[164Ho]
+[159Ho]
+[167Ho]
+[176Yb]
+[156Sm]
+[77BrH]
+[189Re]
+[99Rh]
+[100Rh]
+[151Pm]
+[232Pa]
+[228Pa]
+[230Pa]
+[66Ni]
+[194Os]
+[135La]
+[138La]
+[141La]
+[142La]
+[195Ir]
+[96Nb]
+[157Ho]
+[183Hf]
+[162Tm]
+[172Er]
+[148Eu]
+[150Eu]
+[15CH4]
+[89Kr]
+[143La]
+[58Ni]
+[61Co]
+[158Eu]
+[165Er]
+[167Yb]
+[173Tm]
+[175Tm]
+[172Hf]
+[172Lu]
+[93Tc]
+[177Yb]
+[124IH]
+[194Ir]
+[147Eu]
+[101Mo]
+[180Hf]
+[189Ir]
+[87Y]
+[43Sc]
+[195Au]
+[112Ag]
+[84BrH]
+[106Ag]
+[109Ag]
+[101Rh]
+[162Yb]
+[228Rn]
+[139Pr]
+[94Y]
+[201Au]
+[40PH3]
+[110Ag+]
+[104Cd]
+[133Ba+2]
+[226Ac]
+[145Gd]
+[186Ir]
+[184Ir]
+[224Rn]
+[185Ir]
+[182Ir]
+[184Hf]
+[200Pt]
+[227Pa]
+[178Yb]
+[72Br-]
+[72BrH]
+[248Am]
+[238Th]
+[161Gd]
+[35S-2]
+[107Ag]
+[FeH6-4]
+[89Sr]
+[SnH3-]
+[SeH3]
+[TeH3+]
+[SbH4+]
+[AsH4+]
+[4He]
+[AsH3-]
+[1HH]
+[3H+]
+[82Rb]
+[85Sr]
+[90Sr]
+[137Cs]
+[133Ba]
+[131Cs]
+[SbH5]
+[224Ra]
+[22Na]
+[210Bi]
+[214Bi]
+[228Ra]
+[127Sb]
+[136Cs]
+[125Sb]
+[134Cs]
+[140Ba]
+[45Ca]
+[206Pb]
+[207Pb]
+[24Na]
+[86Rb]
+[212Bi]
+[208Pb]
+[124Sb]
+[204Pb]
+[44K]
+[129Te]
+[113Sn]
+[204Tl]
+[87Sr]
+[208Tl]
+[87Rb]
+[47Ca]
+[135Cs]
+[216Po]
+[137Ba]
+[207Bi]
+[212Po]
+[79Se]
+[223Ra]
+[86Sr]
+[122Sb]
+[26Al]
+[32Si]
+[126Sn]
+[225Ra]
+[114In]
+[72Ga]
+[132Te]
+[10Be]
+[125Sn]
+[73As]
+[206Bi]
+[117Sn]
+[40Ca]
+[41Ca]
+[89Rb]
+[116In]
+[129Sb]
+[91Sr]
+[71Ge]
+[139Ba]
+[69Ga]
+[120Sb]
+[121Sn]
+[123Sn]
+[131Te]
+[77Ge]
+[135Ba]
+[82Sr]
+[43K]
+[131Ba]
+[92Sr]
+[88Rb]
+[129Cs]
+[144Cs]
+[127Cs]
+[200Tl]
+[202Tl]
+[141Ba]
+[117Sb]
+[116Sb]
+[78As]
+[131Sb]
+[126Sb]
+[128Sb]
+[130Sb]
+[67Ge]
+[68Ge]
+[78Ge]
+[66Ge]
+[223Fr]
+[132Cs]
+[125Cs]
+[138Cs]
+[133Te]
+[84Rb]
+[83Rb]
+[81Rb]
+[142Ba]
+[200Bi]
+[115Sb]
+[194Tl]
+[70Se]
+[112In]
+[118Sb]
+[70Ga]
+[27Mg]
+[202Bi]
+[83Se]
+[9Li]
+[69As]
+[79Rb]
+[81Sr]
+[83Sr]
+[78Se]
+[109In]
+[29Al]
+[118Sn]
+[117In]
+[119Sb]
+[114Sn]
+[138Ba]
+[69Ge]
+[73Ga]
+[74Ge]
+[206Tl]
+[199Tl]
+[130Cs]
+[28Mg]
+[116Te]
+[112Sn]
+[126Ba]
+[211Bi]
+[81Se]
+[127Sn]
+[143Cs]
+[134Te]
+[80Sr]
+[45K]
+[215Po]
+[207Po]
+[111Sn]
+[211Po]
+[128Ba]
+[198Tl]
+[227Ra]
+[213Po]
+[220Ra]
+[128Sn]
+[203Po]
+[205Po]
+[65Ga]
+[197Tl]
+[88Sr]
+[110In]
+[31Si]
+[201Bi]
+[121Te]
+[205Bi]
+[203Bi]
+[195Tl]
+[209Tl]
+[110Sn]
+[222Fr]
+[207At]
+[119In]
+[As@]
+[129IH]
+[157Dy]
+[111IH]
+[230Ra]
+[144Pr+3]
+[SiH3+]
+[3He]
+[AsH5]
+[72Se]
+[95Tc]
+[103Pd]
+[121Sn+2]
+[211Rn]
+[38SH2]
+[127IH]
+[74Br-]
+[133I-]
+[100Tc+4]
+[100Tc]
+[36Cl-]
+[89Y+3]
+[104Rh]
+[152Sm]
+[226Ra]
+[19FH]
+[104Pd]
+[148Gd]
+[157Lu]
+[33SH2]
+[121I-]
+[17FH]
+[71Se]
+[157Sm]
+[148Tb]
+[164Dy]
+[15OH2]
+[15O+]
+[39K]
+[40Ar]
+[50Cr+3]
+[50Cr]
+[52Ti]
+[103Pd+2]
+[130Ba]
+[142Pm]
+[153Gd+3]
+[151Eu]
+[103Rh]
+[124Xe]
+[152Tb]
+[17OH2]
+[20Ne]
+[52Fe]
+[94Zr+4]
+[94Zr]
+[149Pr]
+[16OH2]
+[53Cr+6]
+[53Cr]
+[81Br-]
+[112Pd]
+[125Xe]
+[155Gd]
+[157Gd]
+[168Yb]
+[184Os]
+[166Tb]
+[221Fr]
+[212Ra]
+[75Br-]
+[79Br-]
+[113Ag]
+[23Na]
+[34Cl-]
+[34ClH]
+[38Cl-]
+[56Fe]
+[68Cu]
+[77Br-]
+[90Zr+4]
+[90Zr]
+[102Pd]
+[154Eu+3]
+[57Mn]
+[165Tm]
+[152Dy]
+[217At]
+[77se]
+[13cH-]
+[122Te]
+[156Gd]
+[124Te]
+[53Ni]
+[131Xe]
+[174Hf+4]
+[174Hf]
+[76Se]
+[168Tm]
+[167Dy]
+[154Gd]
+[95Ru]
+[210At]
+[85Br]
+[59Co]
+[122Xe]
+[27Al]
+[54Cr]
+[198Hg]
+[85Rb+]
+[214Tl]
+[229Rn]
+[218Pb]
+[218Bi]
+[167Tm+3]
+[18o+]
+[P@@H+]
+[P@H+]
+[13N+]
+[212Pb+2]
+[217Bi]
+[249Cf+2]
+[18OH3+]
+[90Sr-]
+[Cf+3]
+[200Hg]
+[86Tc]
+[141Pr+3]
+[141Pr]
+[16nH]
+[14NH4+]
+[132Xe]
+[83Kr]
+[70Zn+2]
+[137Ba+2]
+[36Ar]
+[38Ar]
+[21Ne]
+[126Xe]
+[136Xe]
+[128Xe]
+[134Xe]
+[84Kr]
+[86Kr]
+[78Kr]
+[80Kr]
+[82Kr]
+[67Zn+2]
+[65Cu+2]
+[110Te]
+[58Fe+3]
+[142Nd]
+[38K]
+[198Au+3]
+[122IH]
+[38PH3]
+[130I-]
+[40K+]
+[38K+]
+[28Mg+2]
+[208Tl+]
+[13OH2]
+[198Bi]
+[192Bi]
+[194Bi]
+[196Bi]
+[132I-]
+[83Sr+2]
+[169Er+3]
+[122I-]
+[120I-]
+[92Sr+2]
+[126I-]
+[24Mg]
+[84Sr]
+[118Pd+2]
+[118Pd]
+[AsH4]
+[127I-]
+[9C-]
+[11CH3+]
+[17B]
+[7B]
+[4HH]
+[18C-]
+[22CH3-]
+[22CH4]
+[17C-]
+[15CH3]
+[16CH3]
+[11NH3]
+[21NH3]
+[11N-]
+[11NH]
+[16CH]
+[17CH2]
+[99Ru+2]
+[181Ta+2]
+[181Ta]
+[20CH]
+[32PH2]
+[55Fe+2]
+[SH3]
+[S@H]
+[Mn-]
+[IH4]
+[ThH]
+[GaH-]
+[BiH+]
+[EuH2]
+[FeH4-3]
+[FeH6]
+[IH5]
+[NiH+]
+[SrH2]
+[VH4]
+[YH3]
+[seH+]
+
diff --git a/models/smi_ted/training/pubchem_canon_script.py b/models/smi_ted/training/pubchem_canon_script.py
new file mode 100644
index 0000000000000000000000000000000000000000..a26146bd42fc365db5d534226d77499b5e960e9c
--- /dev/null
+++ b/models/smi_ted/training/pubchem_canon_script.py
@@ -0,0 +1,71 @@
+import logging
+from dataclasses import dataclass
+import pyarrow as pa
+
+import datasets
+
+
+logger = logging.getLogger(__name__)
+
+
+FEATURES = datasets.Features(
+ {
+ "text": datasets.Value("string"),
+ }
+)
+
+
+@dataclass
+class PubChemConfig(datasets.BuilderConfig):
+ """BuilderConfig for text files."""
+
+ encoding: str = "utf-8"
+ chunksize: int = 10 << 20 # 10MB
+
+
+class PubChem(datasets.ArrowBasedBuilder):
+
+ BUILDER_CONFIG_CLASS = PubChemConfig
+
+ def _info(self):
+ return datasets.DatasetInfo(features=FEATURES)
+
+ def _split_generators(self, dl_manager):
+ """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]].
+
+ If str or List[str], then the dataset returns only the 'train' split.
+ If dict, then keys should be from the `datasets.Split` enum.
+ """
+ if not self.config.data_files:
+ raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
+ data_files = dl_manager.download_and_extract(self.config.data_files)
+ if isinstance(data_files, (str, list, tuple)):
+ files = data_files
+ if isinstance(files, str):
+ files = [files]
+ return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
+ splits = []
+ for split_name, files in data_files.items():
+ if isinstance(files, str):
+ files = [files]
+ splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+ return splits
+
+ def _generate_tables(self, files):
+
+ for file_idx, file in enumerate(files):
+ batch_idx = 0
+ with open(file, "r", encoding=self.config.encoding) as f:
+ while True:
+ batch = f.read(self.config.chunksize)
+ if not batch:
+ break
+ batch += f.readline() # finish current line
+ batch = batch.splitlines()
+ #batch = [word.split()[-1] for word in batch]
+ pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()}))
+ # Uncomment for debugging (will print the Arrow table size and elements)
+ #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
+ #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
+ yield (file_idx, batch_idx), pa_table
+ batch_idx += 1
diff --git a/models/smi_ted/training/pubchem_canon_script.py.lock b/models/smi_ted/training/pubchem_canon_script.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/smi_ted/training/pubchem_canon_zinc_final_vocab_sorted_curated.pth b/models/smi_ted/training/pubchem_canon_zinc_final_vocab_sorted_curated.pth
new file mode 100644
index 0000000000000000000000000000000000000000..ab685a3a5a7c95a8900dc186d7d72fae5cb73bc8
--- /dev/null
+++ b/models/smi_ted/training/pubchem_canon_zinc_final_vocab_sorted_curated.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc71f36557571ecf91d3e82a917113692974b0ef9dfc73bcfaaf0e2c080eaf09
+size 43635
diff --git a/models/smi_ted/training/pubchem_encoder.py b/models/smi_ted/training/pubchem_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ea38ea90fbfe9e69a93fbbadccd9180db31342b
--- /dev/null
+++ b/models/smi_ted/training/pubchem_encoder.py
@@ -0,0 +1,235 @@
+import regex as re
+import torch
+import numpy as np
+import random
+import collections
+
+class Encoder():
+
+ def __init__(self, max_length=500, add_bos=True, add_eos=True, feature_size=32):
+ self.vocab_encoder = torch.load('pubchem_canon_zinc_final_vocab_sorted_curated.pth')
+
+ self.max_length = max_length
+ self.min_length = 1
+ self.mod_length = 42
+ self.mlm_probability = .15
+ self.avg_length = 66
+ self.tail = 122
+ self.b0_cache=collections.deque()
+ self.b1_cache=collections.deque()
+ self.b2_cache=collections.deque()
+ self.b3_cache=collections.deque()
+ self.bucket0=collections.deque()
+ self.bucket1=collections.deque()
+ self.bucket2=collections.deque()
+ self.bucket3=collections.deque()
+ if feature_size == 32:
+ self.b0_max=1100
+ self.b1_max=700
+ self.b2_max=150
+ self.b3_max=50
+ else:
+ self.b0_max=1382
+ self.b1_max=871
+ self.b2_max=516
+ self.b3_max=311
+ values = list(self.vocab_encoder.values())
+ num_top = 0
+ middle_top = 0
+ bottom = 0
+ for count in values:
+ if count > 100000:
+ num_top += 1
+ if count > 50:
+ middle_top += 1
+ middle_top = middle_top - num_top
+ self.cutoffs = [num_top+4, middle_top]
+ self.char2id = {"":0, "":1, "":2, "":3}
+ self.id2char = {0:"", 1:"", 2:"", 3:""}
+ self.pad = self.char2id['']
+ self.mask = self.char2id['']
+ self.eos = self.char2id['']
+ self.bos = self.char2id['']
+ pos = 0
+ for key, value in self.vocab_encoder.items():
+ #for pos, key in enumerate(self.vocab_encoder.keys()):
+ self.char2id[key] = pos+4
+ self.id2char[pos+4] = key
+ pos += 1
+ self.char2id[""] = pos + 4
+ self.id2char[pos+4] = ""
+ self.pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
+ self.regex = re.compile(self.pattern)
+ self.add_bos = add_bos
+ self.add_eos = add_eos
+ #print(self.char2id)
+
+ def encode(self, char):
+ #if len(char) > self.max_length:
+ # char = char[:self.max_length]
+ if self.add_bos == True:
+ char = [''] + char
+ if self.add_eos == True:
+ char = char + ['']
+
+ return torch.tensor([self.char2id.get(word, self.char2id[""]) for word in char])
+
+ def encoder(self, tokens):
+ #return *map(lambda x: self.encode(x), tokens)
+ return [self.encode(mol) for mol in tokens]
+
+ def process_text(self, text):
+ #print(text)
+ #random length sequences seems to help training
+ mod_length = self.mod_length #+ random.randint(-1, 3)
+ avg_length = self.avg_length #+ random.randint(-3, 5)
+ for mol in text:
+ #fill up buckets and caches
+ if '\n' in mol['text']:
+ print('carriage return in mol')
+ raw_regex = self.regex.findall(mol['text'].strip('\n'))
+ length = len(raw_regex)
+ if length > self.min_length and length < mod_length:
+ if len(self.bucket0) < self.b0_max:
+ self.bucket0.append(raw_regex)
+ else:
+ self.b0_cache.append(raw_regex)
+ elif length >= mod_length and length < avg_length:
+ if len(self.bucket1) < self.b1_max:
+ self.bucket1.append(raw_regex)
+ else:
+ self.b1_cache.append(raw_regex)
+ elif length >= avg_length and length < self.tail:
+ if len(self.bucket2) < self.b2_max:
+ self.bucket2.append(raw_regex)
+ else:
+ self.b2_cache.append(raw_regex)
+ elif length >= self.tail and length < self.max_length:
+ if len(self.bucket3) < self.b3_max:
+ self.bucket3.append(raw_regex)
+ else:
+ self.b3_cache.append(raw_regex)
+ # elif length >= avg_length and length < self.tail:
+ # self.b2_cache.append(raw_regex)
+ # #if len(bucket2) < self.b2_max:
+ # # bucket2.append(raw_regex)
+ # #else:
+ # # self.b2_cache.append(raw_regex)
+ # elif length >= self.tail and length < self.max_length:
+ # self.b3_cache.append(raw_regex)
+ # #if len(bucket3) < self.b3_max:
+ # # bucket3.append(raw_regex)
+ # #else:
+ # # self.b3_cache.append(raw_regex)
+
+ #print('before Cache size {} {} {} {}'.format(len(self.b0_cache), len(self.b1_cache), len(self.b2_cache), len(self.b3_cache)))
+ #pour cache elements into any open bucket
+ if len(self.bucket0) < self.b0_max and len(self.b0_cache) > 0:
+ cache_size = len(self.b0_cache)
+ max_margin = self.b0_max-len(self.bucket0)
+ range0 = min(cache_size, max_margin)
+ outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))] + [self.b0_cache.pop() for i in range(range0)]
+ #self.b0_cache = collections.deque(self.b0_cache[:self.b0_max-len(bucket0)])
+ #print('0 type {}'.format(type(self.b0_cache)))
+ else:
+ outbucket0 = [self.bucket0.pop() for item in range(len(self.bucket0))]
+
+ if len(self.bucket1) < self.b1_max and len(self.b1_cache) > 0:
+ cache_size = len(self.b1_cache)
+ max_margin = self.b1_max-len(self.bucket1)
+ range1 = min(cache_size, max_margin)
+ outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))] + [self.b1_cache.pop() for i in range(range1)]
+ else:
+ outbucket1 = [self.bucket1.pop() for item in range(len(self.bucket1))]
+
+ if len(self.bucket2) < self.b2_max and len(self.b2_cache) > 0:
+ cache_size = len(self.b2_cache)
+ max_margin = self.b2_max-len(self.bucket2)
+ range2 = min(cache_size, max_margin)
+ outbucket2 = [self.bucket2.pop() for item in range(len(self.bucket2))] + [self.b2_cache.pop() for i in range(range2)]
+ else:
+ outbucket2 = [self.bucket2.pop() for item in range(len(self.bucket2))]
+
+ if len(self.bucket3) < self.b3_max and len(self.b3_cache) > 0:
+ cache_size = len(self.b3_cache)
+ max_margin = self.b3_max-len(self.bucket3)
+ range3 = min(cache_size, max_margin)
+ outbucket3 = [self.bucket3.pop() for item in range(len(self.bucket3))] + [self.b3_cache.pop() for i in range(range3)]
+ else:
+ outbucket3 = [self.bucket3.pop() for item in range(len(self.bucket3))]
+
+ # if len(self.b2_cache) > self.b2_max:
+ # cache_size = len(self.b2_cache)
+ # max_margin = self.b2_max
+ # range2 = min(cache_size, max_margin)
+ # outbucket2 = [self.b2_cache.pop() for i in range(range2)]
+ # else:
+ # outbucket2=[]
+
+ # if len(self.b3_cache) > self.b3_max:
+ # cache_size = len(self.b3_cache)
+ # max_margin = self.b3_max
+ # range3 = min(cache_size, max_margin)
+ # outbucket3 = [self.b3_cache.pop() for i in range(range3)]
+ # else:
+ # outbucket3 = []
+
+ return outbucket0, outbucket1, outbucket2, outbucket3
+
+ def mask_tokens( self, inputs, special_tokens_mask= None):
+ """
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
+ """
+ labels = inputs.clone()
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
+ probability_matrix = torch.full(labels.size(), self.mlm_probability)
+ if special_tokens_mask is None:
+ special_tokens_mask = [
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
+ ]
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
+ else:
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
+ #special_tokens_mask = special_tokens_mask.bool()
+
+ #print(special_tokens_mask.size())
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+ masked_indices = torch.bernoulli(probability_matrix).bool()
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
+
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+ indices_replaced = torch.bernoulli(torch.full(labels.size(), 0.8)).bool() & masked_indices
+ inputs[indices_replaced] = self.mask
+
+ # 10% of the time, we replace masked input tokens with random word
+ indices_random = torch.bernoulli(torch.full(labels.size(), 0.5)).bool() & masked_indices & ~indices_replaced
+ random_words = torch.randint(len(self.char2id.keys()), labels.size(), dtype=torch.long)
+ inputs[indices_random] = random_words[indices_random]
+
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+ return inputs, labels
+ def pack_tensors(self, tokens):
+ array_ids = self.encoder(tokens)
+ array = torch.nn.utils.rnn.pad_sequence(array_ids, batch_first=True, padding_value=self.pad)
+ lengths = (array!=self.pad).sum(dim=-1)
+ #Bert tokenization
+ special_token_mask = [list(map(lambda x: 1 if x in [self.bos, self.eos, self.pad] else 0, stuff)) for stuff in array.tolist()]
+ masked_array, masked_labels = self.mask_tokens(array, special_token_mask)
+ return masked_array, masked_labels, array_ids, lengths
+ def process(self, text):
+ arrays = []
+ lengths = []
+ targets = []
+ arrays_ids = []
+ for tokens in self.process_text(text):
+ if len(tokens) > 0:
+ array, target, array_ids, lgt = self.pack_tensors(tokens)
+ arrays.append(array)
+ targets.append(target)
+ arrays_ids.append(array_ids)
+ lengths.append(lgt)
+ return arrays, targets, arrays_ids, lengths
+
+if __name__ == '__main__':
+
+ text_encoder = Encoder()
diff --git a/models/smi_ted/training/pubchem_script.py b/models/smi_ted/training/pubchem_script.py
new file mode 100644
index 0000000000000000000000000000000000000000..2164c3b7abf504a145f669fdb5a0d13fef25a14d
--- /dev/null
+++ b/models/smi_ted/training/pubchem_script.py
@@ -0,0 +1,71 @@
+import logging
+from dataclasses import dataclass
+import pyarrow as pa
+
+import datasets
+
+
+logger = logging.getLogger(__name__)
+
+
+FEATURES = datasets.Features(
+ {
+ "text": datasets.Value("string"),
+ }
+)
+
+
+@dataclass
+class PubChemConfig(datasets.BuilderConfig):
+ """BuilderConfig for text files."""
+
+ encoding: str = "utf-8"
+ chunksize: int = 10 << 20 # 10MB
+
+
+class PubChem(datasets.ArrowBasedBuilder):
+
+ BUILDER_CONFIG_CLASS = PubChemConfig
+
+ def _info(self):
+ return datasets.DatasetInfo(features=FEATURES)
+
+ def _split_generators(self, dl_manager):
+ """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]].
+
+ If str or List[str], then the dataset returns only the 'train' split.
+ If dict, then keys should be from the `datasets.Split` enum.
+ """
+ if not self.config.data_files:
+ raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
+ data_files = dl_manager.download_and_extract(self.config.data_files)
+ if isinstance(data_files, (str, list, tuple)):
+ files = data_files
+ if isinstance(files, str):
+ files = [files]
+ return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
+ splits = []
+ for split_name, files in data_files.items():
+ if isinstance(files, str):
+ files = [files]
+ splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+ return splits
+
+ def _generate_tables(self, files):
+
+ for file_idx, file in enumerate(files):
+ batch_idx = 0
+ with open(file, "r", encoding=self.config.encoding) as f:
+ while True:
+ batch = f.read(self.config.chunksize)
+ if not batch:
+ break
+ batch += f.readline() # finish current line
+ batch = batch.splitlines()
+ batch = [word.split()[-1] for word in batch]
+ pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()}))
+ # Uncomment for debugging (will print the Arrow table size and elements)
+ #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
+ #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
+ yield (file_idx, batch_idx), pa_table
+ batch_idx += 1
diff --git a/models/smi_ted/training/pubchem_script.py.lock b/models/smi_ted/training/pubchem_script.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/smi_ted/training/run_model_large_training.sh b/models/smi_ted/training/run_model_large_training.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6ccd9eebe2015a6ebed6e5ca485b9966b9baadfd
--- /dev/null
+++ b/models/smi_ted/training/run_model_large_training.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+torchrun \
+ --standalone \
+ --nnodes=1 \
+ --nproc_per_node=1 \
+ train_model_D.py \
+ --device cuda \
+ --n_batch 48 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.2 \
+ --lr_start 3e-5 \
+ --lr_multiplier 4 \
+ --lr_decoder 3e-5 \
+ --n_workers 1 \
+ --max_epochs 51 \
+ --gpu -1 \
+ --num_nodes 1 \
+ --num_feats 32 \
+ --root_dir . \
+ --checkpoint_every 10000 \
+ --grad_acc 1 \
+ --train_load 'pubchem' \
+ --smi_ted_version 'v2' \
+ --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \
+ --save_checkpoint_path './large_checkpoints' \
+ --load_checkpoint_path '' \
+ --rotate \
+ --debug \
+ --model_arch 'BERT__both_rotate' \
\ No newline at end of file
diff --git a/models/smi_ted/training/run_model_light_training.sh b/models/smi_ted/training/run_model_light_training.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6abb952762608c842c31a711ee3551be38b33966
--- /dev/null
+++ b/models/smi_ted/training/run_model_light_training.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+torchrun \
+ --standalone \
+ --nnodes=1 \
+ --nproc_per_node=1 \
+ train_model_ED.py \
+ --device cuda \
+ --n_batch 288 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.2 \
+ --lr_start 3e-5 \
+ --lr_multiplier 4 \
+ --lr_decoder 3e-5 \
+ --n_workers 1 \
+ --max_epochs 51 \
+ --gpu -1 \
+ --num_nodes 1 \
+ --num_feats 32 \
+ --root_dir . \
+ --checkpoint_every 10000 \
+ --grad_acc 1 \
+ --train_load 'pubchem' \
+ --smi_ted_version 'v1' \
+ --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \
+ --save_checkpoint_path './light_checkpoints' \
+ --load_checkpoint_path '' \
+ --rotate \
+ --debug \
+ --model_arch 'BERT__both_rotate' \
\ No newline at end of file
diff --git a/models/smi_ted/training/send_job_large.slurm b/models/smi_ted/training/send_job_large.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..ede940a075d76f3a23a3dfae0865e649db008417
--- /dev/null
+++ b/models/smi_ted/training/send_job_large.slurm
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+# Example of running python script in a batch mode
+
+#SBATCH -J smi-ted-train
+#SBATCH -t 30:00:00
+#SBATCH -o output_smi_ted_large_epoch40_%j.out
+#SBATCH --mem=64G
+#SBATCH --nodes=10
+#SBATCH --ntasks=10
+#SBATCH --gpus-per-task=5
+#SBATCH --cpus-per-task=20
+
+nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
+nodes_array=($nodes)
+head_node=${nodes_array[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+
+echo Node IP: $head_node_ip
+export LOGLEVEL=INFO
+
+# Load software
+# module load anaconda3
+source /home/.bashrc
+conda activate smi-ted-env
+
+# Run python script
+srun torchrun \
+ --nnodes 10 \
+ --nproc_per_node 5 \
+ --rdzv_id $RANDOM \
+ --rdzv_backend c10d \
+ --rdzv_endpoint $head_node_ip:29500 \
+ train_model_D.py \
+ --device cuda \
+ --n_batch 48 \
+ --n_layer 24 \
+ --n_head 16 \
+ --n_embd 1024 \
+ --max_len 202 \
+ --d_dropout 0.2 \
+ --lr_start 3e-5 \
+ --lr_multiplier 4 \
+ --lr_decoder 3e-5 \
+ --n_workers 20 \
+ --max_epochs 51 \
+ --gpu -1 \
+ --num_nodes 1 \
+ --num_feats 32 \
+ --root_dir . \
+ --checkpoint_every 10000 \
+ --grad_acc 1 \
+ --train_load 'pubchem' \
+ --smi_ted_version 'v2' \
+ --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \
+ --save_checkpoint_path './large_checkpoints' \
+ --load_checkpoint_path '' \
+ --rotate \
+ --debug \
+ --model_arch 'BERT__both_rotate' \
\ No newline at end of file
diff --git a/models/smi_ted/training/send_job_light.slurm b/models/smi_ted/training/send_job_light.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..5fb4f24f99ce91c4f618fcb0312272669e506fea
--- /dev/null
+++ b/models/smi_ted/training/send_job_light.slurm
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+# Example of running python script in a batch mode
+
+#SBATCH -J smi-ted-train
+#SBATCH -t 6:00:00
+#SBATCH -o output_smi_ted_light_epoch50_%j.out
+#SBATCH --mem=64G
+#SBATCH --nodes=6
+#SBATCH --ntasks=6
+#SBATCH --gpus-per-task=4
+#SBATCH --cpus-per-task=12
+
+nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
+nodes_array=($nodes)
+head_node=${nodes_array[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+
+echo Node IP: $head_node_ip
+export LOGLEVEL=INFO
+
+# Load software
+# module load anaconda3
+source /home/.bashrc
+conda activate smi-ted-env
+
+# Run python script
+srun torchrun \
+ --nnodes 6 \
+ --nproc_per_node 4 \
+ --rdzv_id $RANDOM \
+ --rdzv_backend c10d \
+ --rdzv_endpoint $head_node_ip:29500 \
+ train_model_ED.py \
+ --device cuda \
+ --n_batch 288 \
+ --n_layer 12 \
+ --n_head 12 \
+ --n_embd 768 \
+ --max_len 202 \
+ --d_dropout 0.2 \
+ --lr_start 3e-5 \
+ --lr_multiplier 4 \
+ --lr_decoder 3e-5 \
+ --n_workers 12 \
+ --max_epochs 51 \
+ --gpu -1 \
+ --num_nodes 1 \
+ --num_feats 32 \
+ --root_dir . \
+ --checkpoint_every 10000 \
+ --grad_acc 1 \
+ --train_load 'pubchem' \
+ --smi_ted_version 'v1' \
+ --data_root './pubchem/pubchem_rd-canonical_smiles.smi' \
+ --save_checkpoint_path './light_checkpoints' \
+ --load_checkpoint_path '' \
+ --rotate \
+ --debug \
+ --model_arch 'BERT__both_rotate' \
\ No newline at end of file
diff --git a/models/smi_ted/training/smi_ted_large/load.py b/models/smi_ted/training/smi_ted_large/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..febd043c4d4fb5fc1664e13315d0e758365a383f
--- /dev/null
+++ b/models/smi_ted/training/smi_ted_large/load.py
@@ -0,0 +1,382 @@
+PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+# Transformers
+from fast_transformers.attention import AttentionLayer
+from fast_transformers.events import EventDispatcher, QKVEvent
+from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer
+from fast_transformers.builders.base import BaseBuilder
+from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder
+from fast_transformers.builders.attention_builders import AttentionBuilder
+from fast_transformers.feature_maps import GeneralizedRandomFeatures
+from fast_transformers.masking import LengthMask
+
+from transformers import BertTokenizer
+
+# Data
+import numpy as np
+
+# Standard library
+from functools import partial
+import regex as re
+import random
+
+
+class MolTranBertTokenizer(BertTokenizer):
+ def __init__(self, vocab_file: str = '',
+ do_lower_case=False,
+ unk_token='',
+ sep_token='',
+ pad_token='',
+ cls_token='',
+ mask_token='',
+ **kwargs):
+ super().__init__(vocab_file,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs)
+
+ self.regex_tokenizer = re.compile(PATTERN)
+ self.wordpiece_tokenizer = None
+ self.basic_tokenizer = None
+
+ def _tokenize(self, text):
+ split_tokens = self.regex_tokenizer.findall(text)
+ return split_tokens
+
+ def convert_idx_to_tokens(self, idx_tensor):
+ tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
+ return tokens
+
+ def convert_tokens_to_string(self, tokens):
+ stopwords = ['', '']
+ clean_tokens = [word for word in tokens if word not in stopwords]
+ out_string = ''.join(clean_tokens)
+ return out_string
+
+## Transformer layers
+
+class RotaryEmbedding(torch.nn.Module):
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = 0
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self.cos_cached = emb.cos()[None,:, None, :]
+ self.sin_cached = emb.sin()[None,:, None, :]
+
+ return self.cos_cached, self.sin_cached
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+@torch.jit.script
+def apply_rotary_pos_emb(q, k, cos, sin):
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+class RotateAttentionLayer(AttentionLayer):
+ """Rotate attention layer inherits from fast_transformer attention layer.
+ The only thing added is an Embedding encoding, for more information
+ on the attention layer see the fast_transformers code
+ """
+ def __init__(self, attention, d_model, n_heads, d_keys=None,
+ d_values=None, event_dispatcher=""):
+ super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys,
+ d_values=d_values, event_dispatcher=event_dispatcher)
+
+ self.rotaryemb = RotaryEmbedding(d_keys)
+ print('Using Rotation Embedding')
+
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
+ key_lengths):
+ """
+ Using the same frame work as the fast_Transformers attention layer
+ but injecting rotary information to the queries and the keys
+ after the keys and queries are projected.
+ In the argument description we make use of the following sizes
+ - N: the batch size
+ - L: The maximum length of the queries
+ - S: The maximum length of the keys (the actual length per sequence
+ is given by the length mask)
+ - D: The input feature dimensionality passed in the constructor as
+ 'd_model'
+ Arguments
+ ---------
+ queries: (N, L, D) The tensor containing the queries
+ keys: (N, S, D) The tensor containing the keys
+ values: (N, S, D) The tensor containing the values
+ attn_mask: An implementation of BaseMask that encodes where each
+ query can attend to
+ query_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ key_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ Returns
+ -------
+ The new value for each query as a tensor of shape (N, L, D).
+ """
+ # Extract the dimensions into local variables
+ N, L, _ = queries.shape
+ _, S, _ = keys.shape
+ H = self.n_heads
+
+ # Project the queries/keys/values
+ queries = self.query_projection(queries).view(N, L, H, -1)
+ keys = self.key_projection(keys).view(N, S, H, -1)
+ cos, sin = self.rotaryemb(queries)
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+ values = self.value_projection(values).view(N, S, H, -1)
+ # Let the world know of the qkv
+ self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+
+ # Compute the attention
+ new_values = self.inner_attention(
+ queries,
+ keys,
+ values,
+ attn_mask,
+ query_lengths,
+ key_lengths
+ ).view(N, L, -1)
+
+ # Project the output and return
+ return self.out_projection(new_values)
+
+class RotateEncoderBuilder(BaseTransformerEncoderBuilder):
+ """Build a batch transformer encoder with Relative Rotary embeddings
+ for training or processing of sequences all elements at a time.
+ Example usage:
+ builder = RotateEncoderBuilder()
+ builder.n_layers = 12
+ builder.n_heads = 8
+ builder.feed_forward_dimensions = 1024
+ builder.query_dimensions = 64
+ builder.value_dimensions = 64
+ builder.dropout = 0.1
+ builder.attention_dropout = 0.1
+ builder.attention_type = "linear"
+ transformer = builder.get()
+ """
+ def _get_attention_builder(self):
+ """Return an instance of the appropriate attention builder."""
+ return AttentionBuilder()
+
+ def _get_attention_layer_class(self):
+ """Return the class for the layer that projects queries keys and
+ values."""
+ return RotateAttentionLayer
+
+ def _get_encoder_class(self):
+ """Return the class for the transformer encoder."""
+ return TransformerEncoder
+
+ def _get_encoder_layer_class(self):
+ """Return the class for the transformer encoder layer."""
+ return TransformerEncoderLayer
+
+
+class AutoEncoderLayer(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.encoder = self.Encoder(feature_size, latent_size)
+ self.decoder = self.Decoder(feature_size, latent_size)
+
+ class Encoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(feature_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.lat = nn.Linear(latent_size, latent_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.lat.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.lat(x)
+ return x # -> (N, D)
+
+ class Decoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(latent_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.rec = nn.Linear(latent_size, feature_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.rec.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.rec(x)
+ return x # -> (N, L*D)
+
+
+class LangLayer(nn.Module):
+ def __init__(self, n_embd, n_vocab):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.embed = nn.Linear(n_embd, n_embd)
+ self.ln_f = nn.LayerNorm(n_embd)
+ self.head = nn.Linear(n_embd, n_vocab, bias=False)
+ def forward(self, tensor):
+ if self.is_cuda_available:
+ self.embed.cuda()
+ self.ln_f.cuda()
+ self.head.cuda()
+ tensor = tensor.cuda()
+ tensor = self.embed(tensor)
+ tensor = F.gelu(tensor)
+ tensor = self.ln_f(tensor)
+ tensor = self.head(tensor)
+ return tensor
+
+
+class MoLEncoder(nn.Module):
+
+ def __init__(self, config, n_vocab):
+ super(MoLEncoder, self).__init__()
+
+ # embeddings
+ self.tok_emb = nn.Embedding(n_vocab, config.n_embd)
+ self.drop = nn.Dropout(config.d_dropout)
+
+ # transformer
+ builder = RotateEncoderBuilder.from_kwargs(
+ n_layers=config.n_layer,
+ n_heads=config.n_head,
+ query_dimensions=config.n_embd//config.n_head,
+ value_dimensions=config.n_embd//config.n_head,
+ feed_forward_dimensions=None,
+ attention_type='linear',
+ # unless we do deterministic_eval here, we will have random outputs
+ feature_map=partial(GeneralizedRandomFeatures,
+ n_dims=config.num_feats,
+ deterministic_eval=False),
+ activation='gelu'
+ )
+ self.blocks = builder.get()
+
+ # classification
+ self.lang_model = LangLayer(config.n_embd, n_vocab)
+
+ def forward(self, idx, mask=None, inference=False):
+ if not inference:
+ x = self.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.drop(x)
+
+ #masking of the length of the inputs its handled in the Masked language part of the code
+ #do not attempt to handle it in the forward of the transformer
+ x = self.blocks(x)
+ logits = self.lang_model(x)
+
+ return logits
+ else:
+ x = self.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.drop(x)
+
+ #masking of the length of the inputs its handled in the Masked language part of the code
+ #do not attempt to handle it in the forward of the transformer
+ x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1]))
+
+ # mean pooling
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+ true_set = sum_embeddings / sum_mask
+
+ return true_set, token_embeddings
+
+
+class MoLDecoder(nn.Module):
+
+ def __init__(self, n_vocab, max_len, n_embd, n_gpu=None):
+ super(MoLDecoder, self).__init__()
+
+ self.max_len = max_len
+ self.n_embd = n_embd
+ self.n_gpu = n_gpu
+ self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd)
+ self.lang_model = LangLayer(n_embd, n_vocab)
+
+ def forward(self, token_embeddings):
+ pred_set = self.autoencoder.encoder(token_embeddings) # (N, D)
+ pred_cte = self.autoencoder.decoder(pred_set) # (N, L*D)
+ pred_ids = self.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
+ return pred_set, pred_ids
+
+
+class Smi_ted(nn.Module):
+ """materials.smi-ted-Large 738M Parameters"""
+
+ def __init__(self, config, vocab):
+ super(Smi_ted, self).__init__()
+
+ self.config = config
+ self.padding_idx = 2
+ self.is_cuda_available = torch.cuda.is_available()
+ n_vocab = len(vocab.keys())
+ print(n_vocab, config.n_embd)
+
+ self.encoder = MoLEncoder(config, n_vocab)
+ self.decoder = MoLDecoder(n_vocab, config.max_len, config.n_embd)
+
+ self._set_seed(config.seed)
+ print('Vocab size:', n_vocab)
+ print(f'[PRE-TRAINING MODE - {str(self)}]')
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_seed(self, value):
+ print('Random Seed:', value)
+ random.seed(value)
+ torch.manual_seed(value)
+ torch.cuda.manual_seed(value)
+ torch.cuda.manual_seed_all(value)
+ np.random.seed(value)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def __str__(self):
+ return 'smi-ted-Large'
\ No newline at end of file
diff --git a/models/smi_ted/training/smi_ted_light/load.py b/models/smi_ted/training/smi_ted_light/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..61598cbb1c4ff0d8f1aca5917a3dbd7cddcdf735
--- /dev/null
+++ b/models/smi_ted/training/smi_ted_light/load.py
@@ -0,0 +1,382 @@
+PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+# Transformers
+from fast_transformers.attention import AttentionLayer
+from fast_transformers.events import EventDispatcher, QKVEvent
+from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer
+from fast_transformers.builders.base import BaseBuilder
+from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder
+from fast_transformers.builders.attention_builders import AttentionBuilder
+from fast_transformers.feature_maps import GeneralizedRandomFeatures
+from fast_transformers.masking import LengthMask
+
+from transformers import BertTokenizer
+
+# Data
+import numpy as np
+
+# Standard library
+from functools import partial
+import regex as re
+import random
+
+
+class MolTranBertTokenizer(BertTokenizer):
+ def __init__(self, vocab_file: str = '',
+ do_lower_case=False,
+ unk_token='',
+ sep_token='',
+ pad_token='',
+ cls_token='',
+ mask_token='',
+ **kwargs):
+ super().__init__(vocab_file,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs)
+
+ self.regex_tokenizer = re.compile(PATTERN)
+ self.wordpiece_tokenizer = None
+ self.basic_tokenizer = None
+
+ def _tokenize(self, text):
+ split_tokens = self.regex_tokenizer.findall(text)
+ return split_tokens
+
+ def convert_idx_to_tokens(self, idx_tensor):
+ tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
+ return tokens
+
+ def convert_tokens_to_string(self, tokens):
+ stopwords = ['', '']
+ clean_tokens = [word for word in tokens if word not in stopwords]
+ out_string = ''.join(clean_tokens)
+ return out_string
+
+## Transformer layers
+
+class RotaryEmbedding(torch.nn.Module):
+
+ def __init__(self, dim, base=10000):
+ super().__init__()
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+ self.seq_len_cached = 0
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_dim=1):
+ seq_len = x.shape[seq_dim]
+ if seq_len != self.seq_len_cached:
+ self.seq_len_cached = seq_len
+
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self.cos_cached = emb.cos()[None,:, None, :]
+ self.sin_cached = emb.sin()[None,:, None, :]
+
+ return self.cos_cached, self.sin_cached
+
+def rotate_half(x):
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+@torch.jit.script
+def apply_rotary_pos_emb(q, k, cos, sin):
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+class RotateAttentionLayer(AttentionLayer):
+ """Rotate attention layer inherits from fast_transformer attention layer.
+ The only thing added is an Embedding encoding, for more information
+ on the attention layer see the fast_transformers code
+ """
+ def __init__(self, attention, d_model, n_heads, d_keys=None,
+ d_values=None, event_dispatcher=""):
+ super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys,
+ d_values=d_values, event_dispatcher=event_dispatcher)
+
+ self.rotaryemb = RotaryEmbedding(d_keys)
+ print('Using Rotation Embedding')
+
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
+ key_lengths):
+ """
+ Using the same frame work as the fast_Transformers attention layer
+ but injecting rotary information to the queries and the keys
+ after the keys and queries are projected.
+ In the argument description we make use of the following sizes
+ - N: the batch size
+ - L: The maximum length of the queries
+ - S: The maximum length of the keys (the actual length per sequence
+ is given by the length mask)
+ - D: The input feature dimensionality passed in the constructor as
+ 'd_model'
+ Arguments
+ ---------
+ queries: (N, L, D) The tensor containing the queries
+ keys: (N, S, D) The tensor containing the keys
+ values: (N, S, D) The tensor containing the values
+ attn_mask: An implementation of BaseMask that encodes where each
+ query can attend to
+ query_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ key_lengths: An implementation of BaseMask that encodes how
+ many queries each sequence in the batch consists of
+ Returns
+ -------
+ The new value for each query as a tensor of shape (N, L, D).
+ """
+ # Extract the dimensions into local variables
+ N, L, _ = queries.shape
+ _, S, _ = keys.shape
+ H = self.n_heads
+
+ # Project the queries/keys/values
+ queries = self.query_projection(queries).view(N, L, H, -1)
+ keys = self.key_projection(keys).view(N, S, H, -1)
+ cos, sin = self.rotaryemb(queries)
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+ values = self.value_projection(values).view(N, S, H, -1)
+ # Let the world know of the qkv
+ self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+
+ # Compute the attention
+ new_values = self.inner_attention(
+ queries,
+ keys,
+ values,
+ attn_mask,
+ query_lengths,
+ key_lengths
+ ).view(N, L, -1)
+
+ # Project the output and return
+ return self.out_projection(new_values)
+
+class RotateEncoderBuilder(BaseTransformerEncoderBuilder):
+ """Build a batch transformer encoder with Relative Rotary embeddings
+ for training or processing of sequences all elements at a time.
+ Example usage:
+ builder = RotateEncoderBuilder()
+ builder.n_layers = 12
+ builder.n_heads = 8
+ builder.feed_forward_dimensions = 1024
+ builder.query_dimensions = 64
+ builder.value_dimensions = 64
+ builder.dropout = 0.1
+ builder.attention_dropout = 0.1
+ builder.attention_type = "linear"
+ transformer = builder.get()
+ """
+ def _get_attention_builder(self):
+ """Return an instance of the appropriate attention builder."""
+ return AttentionBuilder()
+
+ def _get_attention_layer_class(self):
+ """Return the class for the layer that projects queries keys and
+ values."""
+ return RotateAttentionLayer
+
+ def _get_encoder_class(self):
+ """Return the class for the transformer encoder."""
+ return TransformerEncoder
+
+ def _get_encoder_layer_class(self):
+ """Return the class for the transformer encoder layer."""
+ return TransformerEncoderLayer
+
+
+class AutoEncoderLayer(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.encoder = self.Encoder(feature_size, latent_size)
+ self.decoder = self.Decoder(feature_size, latent_size)
+
+ class Encoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(feature_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.lat = nn.Linear(latent_size, latent_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.lat.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.lat(x)
+ return x # -> (N, D)
+
+ class Decoder(nn.Module):
+
+ def __init__(self, feature_size, latent_size):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.fc1 = nn.Linear(latent_size, latent_size)
+ self.ln_f = nn.LayerNorm(latent_size)
+ self.rec = nn.Linear(latent_size, feature_size, bias=False)
+
+ def forward(self, x):
+ if self.is_cuda_available:
+ self.fc1.cuda()
+ self.ln_f.cuda()
+ self.rec.cuda()
+ x = x.cuda()
+ x = F.gelu(self.fc1(x))
+ x = self.ln_f(x)
+ x = self.rec(x)
+ return x # -> (N, L*D)
+
+
+class LangLayer(nn.Module):
+ def __init__(self, n_embd, n_vocab):
+ super().__init__()
+ self.is_cuda_available = torch.cuda.is_available()
+ self.embed = nn.Linear(n_embd, n_embd)
+ self.ln_f = nn.LayerNorm(n_embd)
+ self.head = nn.Linear(n_embd, n_vocab, bias=False)
+ def forward(self, tensor):
+ if self.is_cuda_available:
+ self.embed.cuda()
+ self.ln_f.cuda()
+ self.head.cuda()
+ tensor = tensor.cuda()
+ tensor = self.embed(tensor)
+ tensor = F.gelu(tensor)
+ tensor = self.ln_f(tensor)
+ tensor = self.head(tensor)
+ return tensor
+
+
+class MoLEncoder(nn.Module):
+
+ def __init__(self, config, n_vocab):
+ super(MoLEncoder, self).__init__()
+
+ # embeddings
+ self.tok_emb = nn.Embedding(n_vocab, config.n_embd)
+ self.drop = nn.Dropout(config.d_dropout)
+
+ # transformer
+ builder = RotateEncoderBuilder.from_kwargs(
+ n_layers=config.n_layer,
+ n_heads=config.n_head,
+ query_dimensions=config.n_embd//config.n_head,
+ value_dimensions=config.n_embd//config.n_head,
+ feed_forward_dimensions=config.n_embd,
+ attention_type='linear',
+ # unless we do deterministic_eval here, we will have random outputs
+ feature_map=partial(GeneralizedRandomFeatures,
+ n_dims=config.num_feats,
+ deterministic_eval=False),
+ activation='gelu'
+ )
+ self.blocks = builder.get()
+
+ # classification
+ self.lang_model = LangLayer(config.n_embd, n_vocab)
+
+ def forward(self, idx, mask=None, inference=False):
+ if not inference:
+ x = self.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.drop(x)
+
+ #masking of the length of the inputs its handled in the Masked language part of the code
+ #do not attempt to handle it in the forward of the transformer
+ x = self.blocks(x)
+ logits = self.lang_model(x)
+
+ return logits
+ else:
+ x = self.tok_emb(idx) # each index maps to a (learnable) vector
+ x = self.drop(x)
+
+ #masking of the length of the inputs its handled in the Masked language part of the code
+ #do not attempt to handle it in the forward of the transformer
+ x = self.blocks(x, length_mask=LengthMask(mask.sum(-1), max_len=idx.shape[1]))
+
+ # mean pooling
+ token_embeddings = x
+ input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+ true_set = sum_embeddings / sum_mask
+
+ return true_set, token_embeddings
+
+
+class MoLDecoder(nn.Module):
+
+ def __init__(self, n_vocab, max_len, n_embd, n_gpu=None):
+ super(MoLDecoder, self).__init__()
+
+ self.max_len = max_len
+ self.n_embd = n_embd
+ self.n_gpu = n_gpu
+ self.autoencoder = AutoEncoderLayer(n_embd*max_len, n_embd)
+ self.lang_model = LangLayer(n_embd, n_vocab)
+
+ def forward(self, token_embeddings):
+ pred_set = self.autoencoder.encoder(token_embeddings) # (N, D)
+ pred_cte = self.autoencoder.decoder(pred_set) # (N, L*D)
+ pred_ids = self.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
+ return pred_set, pred_ids
+
+
+class Smi_ted(nn.Module):
+ """materials.smi-ted-Light 289M Parameters"""
+
+ def __init__(self, config, vocab):
+ super(Smi_ted, self).__init__()
+
+ self.config = config
+ self.padding_idx = 2
+ self.is_cuda_available = torch.cuda.is_available()
+ n_vocab = len(vocab.keys())
+ print(n_vocab, config.n_embd)
+
+ self.encoder = MoLEncoder(config, n_vocab)
+ self.decoder = MoLDecoder(n_vocab, config.max_len, config.n_embd)
+
+ self._set_seed(config.seed)
+ print('Vocab size:', n_vocab)
+ print(f'[PRE-TRAINING MODE - {str(self)}]')
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_seed(self, value):
+ print('Random Seed:', value)
+ random.seed(value)
+ torch.manual_seed(value)
+ torch.cuda.manual_seed(value)
+ torch.cuda.manual_seed_all(value)
+ np.random.seed(value)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def __str__(self):
+ return 'smi-ted-Light'
\ No newline at end of file
diff --git a/models/smi_ted/training/train_model_D.py b/models/smi_ted/training/train_model_D.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d5b5ab3300694f1c80cbe1e20187e3dcacd60b7
--- /dev/null
+++ b/models/smi_ted/training/train_model_D.py
@@ -0,0 +1,98 @@
+# This code uses the decoder loss directly.
+#
+#
+
+# Deep learning
+import torch
+from torch_optimizer.lamb import Lamb
+from trainer import TrainerDirectDecoder
+
+# Parallel
+from torch.utils.data.distributed import DistributedSampler
+from torch.distributed import init_process_group, destroy_process_group
+
+# Data
+from utils import MoleculeModule, get_optim_groups
+from torch.utils.data import DataLoader
+
+# Standard library
+import os
+import args
+
+
+def ddp_setup():
+ init_process_group(backend="nccl")
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+
+def load_train_objs(config):
+ # load data
+ train_loader = MoleculeModule(
+ config.max_len,
+ config.train_load,
+ config.data_root
+ )
+ train_loader.setup()
+
+ loader = DataLoader(
+ train_loader.pubchem,
+ batch_size=config.n_batch,
+ pin_memory=True,
+ shuffle=False,
+ collate_fn=train_loader.text_encoder.process,
+ sampler=DistributedSampler(train_loader.pubchem),
+ num_workers=config.n_workers
+ )
+
+ # load model
+ if config.smi_ted_version == 'v1':
+ from smi_ted_light.load import Smi_ted
+ elif config.smi_ted_version == 'v2':
+ from smi_ted_large.load import Smi_ted
+
+ model = Smi_ted(config, train_loader.get_vocab()).to('cuda')
+ model.apply(model._init_weights)
+
+ # load optimizer
+ optim_groups = get_optim_groups(model)
+ optimizer = torch.optim.AdamW(optim_groups, lr=config.lr_decoder, betas=(0.9, 0.99), fused=True)
+
+ return loader, model, optimizer
+
+
+def main(
+ config,
+ save_every: int,
+ total_epochs: int,
+ save_checkpoint_path: str,
+ load_checkpoint_path: str
+ ):
+ ddp_setup()
+
+ # training objects
+ train_data, model, optimizer = load_train_objs(config)
+
+ # init trainer
+ trainer = TrainerDirectDecoder(
+ model,
+ train_data,
+ optimizer,
+ save_every,
+ save_checkpoint_path,
+ load_checkpoint_path,
+ config
+ )
+ trainer.train(total_epochs)
+ destroy_process_group()
+
+
+if __name__ == '__main__':
+ parser = args.get_parser()
+ args = parser.parse_args()
+ main(
+ args,
+ args.checkpoint_every,
+ args.max_epochs,
+ save_checkpoint_path=args.save_checkpoint_path,
+ load_checkpoint_path=args.load_checkpoint_path,
+ )
diff --git a/models/smi_ted/training/train_model_ED.py b/models/smi_ted/training/train_model_ED.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca83d9caaaa673b93cc83542d479e319b9da2879
--- /dev/null
+++ b/models/smi_ted/training/train_model_ED.py
@@ -0,0 +1,100 @@
+# This code uses both encoder and decoder losses.
+#
+#
+
+# Deep learning
+import torch
+from torch_optimizer.lamb import Lamb
+from trainer import TrainerEncoderDecoder
+
+# Parallel
+from torch.utils.data.distributed import DistributedSampler
+from torch.distributed import init_process_group, destroy_process_group
+
+# Data
+from utils import MoleculeModule, get_optim_groups
+from torch.utils.data import DataLoader
+
+# Standard library
+import os
+import args
+
+
+def ddp_setup():
+ init_process_group(backend="nccl")
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+
+def load_train_objs(config):
+ # load data
+ train_loader = MoleculeModule(
+ config.max_len,
+ config.train_load,
+ config.data_root
+ )
+ train_loader.setup()
+
+ loader = DataLoader(
+ train_loader.pubchem,
+ batch_size=config.n_batch,
+ pin_memory=True,
+ shuffle=False,
+ collate_fn=train_loader.text_encoder.process,
+ sampler=DistributedSampler(train_loader.pubchem),
+ num_workers=config.n_workers
+ )
+
+ # load model
+ if config.smi_ted_version == 'v1':
+ from smi_ted_light.load import Smi_ted
+ elif config.smi_ted_version == 'v2':
+ from smi_ted_large.load import Smi_ted
+
+ model = Smi_ted(config, train_loader.get_vocab())
+ model.apply(model._init_weights)
+
+ # load optimizer
+ optim_groupsE = get_optim_groups(model.encoder)
+ optim_groupsD = get_optim_groups(model.decoder)
+ optimizerE = Lamb(optim_groupsE, lr=config.lr_start*config.lr_multiplier, betas=(0.9, 0.99))
+ optimizerD = torch.optim.Adam(optim_groupsD, lr=config.lr_decoder, betas=(0.9, 0.99))
+
+ return loader, model, (optimizerE, optimizerD)
+
+
+def main(
+ config,
+ save_every: int,
+ total_epochs: int,
+ save_checkpoint_path: str,
+ load_checkpoint_path: str
+ ):
+ ddp_setup()
+
+ # training objects
+ train_data, model, optimizers = load_train_objs(config)
+
+ # init trainer
+ trainer = TrainerEncoderDecoder(
+ model,
+ train_data,
+ optimizers,
+ save_every,
+ save_checkpoint_path,
+ load_checkpoint_path,
+ config
+ )
+ trainer.train(total_epochs)
+ destroy_process_group()
+
+
+if __name__ == '__main__':
+ parser = args.get_parser()
+ args = parser.parse_args()
+ main(
+ args,
+ args.checkpoint_every,
+ args.max_epochs,
+ save_checkpoint_path=args.save_checkpoint_path,
+ load_checkpoint_path=args.load_checkpoint_path,
+ )
diff --git a/models/smi_ted/training/trainer.py b/models/smi_ted/training/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffff34bb51c0896e7bf7f1dffe047ae15bbcdb41
--- /dev/null
+++ b/models/smi_ted/training/trainer.py
@@ -0,0 +1,454 @@
+# Deep learning
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch.utils.data import DataLoader
+from torch.nn.parallel import DistributedDataParallel as DDP
+from fast_transformers.masking import LengthMask
+
+# Standard library
+from tqdm import tqdm
+import pandas as pd
+import numpy as np
+import random
+import os
+
+
+class Trainer:
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ train_data: DataLoader,
+ optimizer: torch.optim.Optimizer,
+ save_every: int,
+ save_checkpoint_path: str,
+ load_checkpoint_path: str,
+ config,
+ ) -> None:
+ self.local_rank = int(os.environ["LOCAL_RANK"])
+ self.global_rank = int(os.environ["RANK"])
+ self.model = model.to(self.local_rank)
+ self.train_data = train_data
+ self.optimizer = optimizer
+ self.save_every = save_every
+ self.epochs_run = 0
+ self.last_batch_idx = -1
+ self.save_checkpoint_path = save_checkpoint_path
+ self.config = config
+
+ if os.path.exists(load_checkpoint_path):
+ print(f"Loading checkpoint at {load_checkpoint_path}...")
+ self._load_checkpoint(load_checkpoint_path)
+
+ self.model = DDP(self.model, device_ids=[self.local_rank])
+
+ def _load_checkpoint(self, checkpoint_path):
+ opt_dict = None
+ loc = f"cuda:{self.local_rank}"
+ ckpt_dict = torch.load(checkpoint_path, map_location=loc)
+ if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')):
+ opt_dict = torch.load(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt'), map_location=loc)
+
+ self.model.load_state_dict(ckpt_dict["MODEL_STATE"])
+ if opt_dict is not None:
+ self.optimizer.load_state_dict(opt_dict["OPTIMIZER_STATE"])
+ print('Optimizer states restored!')
+
+ self.last_batch_idx = ckpt_dict["last_batch_idx"] if 'last_batch_idx' in ckpt_dict else -1
+ self.epochs_run = ckpt_dict["EPOCHS_RUN"] + 1 if self.last_batch_idx == -1 else ckpt_dict["EPOCHS_RUN"]
+
+ # load RNG states each time the model and states are loaded from checkpoint
+ if 'rng' in ckpt_dict:
+ rng = ckpt_dict['rng']
+ for key, value in rng.items():
+ if key =='torch_state':
+ torch.set_rng_state(value.cpu())
+ elif key =='cuda_state':
+ torch.cuda.set_rng_state(value.cpu())
+ elif key =='numpy_state':
+ np.random.set_state(value)
+ elif key =='python_state':
+ random.setstate(value)
+ else:
+ print('unrecognized state')
+
+ print(f"Resuming training from checkpoint at Epoch {self.epochs_run}.")
+
+ def _save_checkpoint(self, epoch, config, last_idx):
+ # save RNG states each time the model and states are saved
+ out_dict = dict()
+ out_dict['torch_state'] = torch.get_rng_state()
+ out_dict['cuda_state'] = torch.cuda.get_rng_state()
+ if np:
+ out_dict['numpy_state'] = np.random.get_state()
+ if random:
+ out_dict['python_state'] = random.getstate()
+
+ # model states
+ ckpt_dict = {
+ "MODEL_STATE": self.model.module.state_dict(),
+ "EPOCHS_RUN": epoch,
+ "hparams": vars(config),
+ "last_batch_idx": last_idx,
+ "rng": out_dict
+ }
+
+ # optimizer states
+ opt_dict = {
+ "OPTIMIZER_STATE": self.optimizer.state_dict(),
+ }
+
+ if last_idx == -1:
+ filename = f'{str(self.model.module)}_{epoch}.pt'
+ else:
+ filename = f'{str(self.model.module)}_{last_idx}_{epoch}.pt'
+
+ torch.save(ckpt_dict, os.path.join(self.save_checkpoint_path, filename))
+ torch.save(opt_dict, os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt'))
+
+ print(f"Epoch {epoch} | Training checkpoint saved at {os.path.join(self.save_checkpoint_path, filename)}.")
+
+ def train(self, max_epochs: int):
+ for epoch in range(self.epochs_run, max_epochs):
+ self._run_epoch(epoch)
+ if self.local_rank == 0:
+ self._save_checkpoint(epoch, self.config, last_idx=-1)
+
+ def _run_epoch(self, epoch):
+ print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {self.config.n_batch} | Steps: {len(self.train_data)} | Last batch: {self.last_batch_idx}")
+ self.train_data.sampler.set_epoch(epoch)
+ loss_list = pd.Series()
+
+ for idx, data in enumerate(tqdm(self.train_data)):
+ # skip batches
+ if idx <= self.last_batch_idx:
+ continue
+
+ # run batch
+ bucket_idx_masked = data[0]
+ bucket_targets = data[1]
+ bucket_idx_not_masked = data[2]
+ loss = self._run_batch(bucket_idx_masked, bucket_targets, bucket_idx_not_masked)
+ torch.cuda.empty_cache()
+
+ # track loss
+ if self.local_rank == 0:
+ loss_list = pd.concat([loss_list, pd.Series([loss])], axis=0)
+
+ # checkpoint
+ if self.local_rank == 0 and idx % self.save_every == 0 and idx != 0:
+ self._save_checkpoint(epoch, self.config, idx)
+ # WARN: due to job limit time - save loss for each iter
+ loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_{idx}_epoch{epoch}.csv'), index=False)
+ loss_list = pd.Series()
+
+ self.last_batch_idx = -1
+
+ if self.local_rank == 0:
+ loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False)
+
+ def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
+ raise NotImplementedError
+
+
+class TrainerEncoderDecoder(Trainer):
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ train_data: DataLoader,
+ optimizer: torch.optim.Optimizer,
+ save_every: int,
+ save_checkpoint_path: str,
+ load_checkpoint_path: str,
+ config,
+ ) -> None:
+ super().__init__(model, train_data, optimizer, save_every, save_checkpoint_path, load_checkpoint_path, config)
+ self.criterionC = nn.CrossEntropyLoss(ignore_index=-100)
+ self.criterionR = nn.MSELoss()
+
+ self.optimE = self.optimizer[0]
+ self.optimD = self.optimizer[1]
+
+ self.ngpus_per_node = torch.cuda.device_count()
+ self.total_batches = len(self.train_data)
+ self.batch_thresh = int(self.total_batches - (self.total_batches * 0.05 * self.ngpus_per_node))
+ print('batch_thresh:', self.batch_thresh)
+
+ def _load_checkpoint(self, checkpoint_path):
+ opt_dict = None
+ loc = f"cuda:{self.local_rank}"
+ ckpt_dict = torch.load(checkpoint_path, map_location=loc)
+ if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')):
+ opt_dict = torch.load(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt'), map_location=loc)
+
+ self.model.load_state_dict(ckpt_dict["MODEL_STATE"])
+ if opt_dict is not None:
+ self.optimizer[0].load_state_dict(opt_dict["OPTIMIZER_STATE_ENCODER"])
+ self.optimizer[1].load_state_dict(opt_dict["OPTIMIZER_STATE_DECODER"])
+ print('Optimizer states restored!')
+
+ self.last_batch_idx = ckpt_dict["last_batch_idx"] if 'last_batch_idx' in ckpt_dict else -1
+ self.epochs_run = ckpt_dict["EPOCHS_RUN"] + 1 if self.last_batch_idx == -1 else ckpt_dict["EPOCHS_RUN"]
+
+ # load RNG states each time the model and states are loaded from checkpoint
+ if 'rng' in ckpt_dict:
+ rng = ckpt_dict['rng']
+ for key, value in rng.items():
+ if key =='torch_state':
+ torch.set_rng_state(value.cpu())
+ elif key =='cuda_state':
+ torch.cuda.set_rng_state(value.cpu())
+ elif key =='numpy_state':
+ np.random.set_state(value)
+ elif key =='python_state':
+ random.setstate(value)
+ else:
+ print('unrecognized state')
+
+ print(f"Resuming training from checkpoint at Epoch {self.epochs_run}.")
+
+ def _save_checkpoint(self, epoch, config, last_idx):
+ # save RNG states each time the model and states are saved
+ out_dict = dict()
+ out_dict['torch_state'] = torch.get_rng_state()
+ out_dict['cuda_state'] = torch.cuda.get_rng_state()
+ if np:
+ out_dict['numpy_state'] = np.random.get_state()
+ if random:
+ out_dict['python_state'] = random.getstate()
+
+ # model states
+ ckpt_dict = {
+ "MODEL_STATE": self.model.module.state_dict(),
+ "EPOCHS_RUN": epoch,
+ "hparams": vars(config),
+ "last_batch_idx": last_idx,
+ "rng": out_dict
+ }
+
+ # optimizer states
+ opt_dict = {
+ "OPTIMIZER_STATE_ENCODER": self.optimizer[0].state_dict(),
+ "OPTIMIZER_STATE_DECODER": self.optimizer[1].state_dict(),
+ }
+
+ if last_idx == -1:
+ filename = f'{str(self.model.module)}_{epoch}.pt'
+ else:
+ filename = f'{str(self.model.module)}_{last_idx}_{epoch}.pt'
+
+ torch.save(ckpt_dict, os.path.join(self.save_checkpoint_path, filename))
+ torch.save(opt_dict, os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt'))
+
+ print(f"Epoch {epoch} | Training checkpoint saved at {os.path.join(self.save_checkpoint_path, filename)}.")
+
+ def _run_epoch(self, epoch):
+ print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {self.config.n_batch} | Steps: {len(self.train_data)}")
+ self.train_data.sampler.set_epoch(epoch)
+ loss_list = pd.DataFrame()
+
+ for idx, data in enumerate(tqdm(self.train_data)):
+ bucket_idx_masked = data[0]
+ bucket_targets = data[1]
+ bucket_idx_not_masked = data[2]
+ lossE, lossD = self._run_batch(idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked)
+ torch.cuda.empty_cache()
+
+ if self.local_rank == 0:
+ df = pd.DataFrame({
+ 'lossE': [lossE.cpu().item()],
+ 'lossD': [lossD.cpu().item()],
+ })
+ loss_list = pd.concat([loss_list, df], axis=0)
+
+ if self.local_rank == 0:
+ loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False)
+
+ def custom(self, module):
+ def custom_forward(*inputs):
+ inputs = module(inputs[0])
+ return inputs
+ return custom_forward
+
+ def _run_batch(self, batch_idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
+ self.optimE.zero_grad(set_to_none=True)
+ self.optimD.zero_grad(set_to_none=True)
+
+ can_train_encoder = (batch_idx + 1) <= self.batch_thresh
+ can_train_decoder = (batch_idx + 1) > self.batch_thresh
+
+ padding_idx = 2
+ errorE = torch.zeros(1).to(self.local_rank)
+ errorD = torch.zeros(1).to(self.local_rank)
+ errorE_tmp = .0
+ errorD_tmp = .0
+
+ for chunk in range(len(bucket_idx_masked)):
+ idx_masked = bucket_idx_masked[chunk].to(self.local_rank)
+ targets = bucket_targets[chunk].to(self.local_rank)
+ idx_not_masked = bucket_idx_not_masked[chunk]
+ idx_not_masked = list(map(lambda x: F.pad(x, pad=(0, self.config.max_len - x.shape[0]), value=2).unsqueeze(0), idx_not_masked))
+ idx_not_masked = torch.cat(idx_not_masked, dim=0).to(self.local_rank)
+ mask = (idx_masked != padding_idx)
+
+ ###########
+ # Encoder #
+ ###########
+ if can_train_encoder:
+ for param in self.model.module.encoder.parameters():
+ param.requires_grad = True
+ for param in self.model.module.decoder.parameters():
+ param.requires_grad = False
+
+ # encoder forward
+ x = self.model.module.encoder.tok_emb(idx_masked)
+ x = self.model.module.encoder.drop(x)
+ x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x)
+ logits = self.model.module.encoder.lang_model(x)
+
+ # loss function
+ logits = logits.view(-1, logits.size(-1))
+ targets = targets.view(-1)
+ errorE_tmp = self.criterionC(logits, targets) / len(bucket_idx_masked)
+
+ if chunk < len(bucket_idx_masked)-1:
+ errorE_tmp.backward()
+ errorE += errorE_tmp.detach()
+ else:
+ errorE += errorE_tmp
+
+
+ ###########
+ # Decoder #
+ ###########
+ if can_train_decoder:
+ for param in self.model.module.encoder.parameters():
+ param.requires_grad = False
+ for param in self.model.module.decoder.parameters():
+ param.requires_grad = True
+
+ self.model.module.encoder.eval()
+
+ # encoder forward
+ with torch.no_grad():
+ true_set, true_cte = self.model.module.encoder(idx_masked, mask=mask, inference=True)
+
+ # add padding
+ input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float()
+ mask_embeddings = (true_cte * input_mask_expanded)
+ true_cte = F.pad(mask_embeddings, pad=(0, 0, 0, self.config.max_len - mask_embeddings.shape[1]), value=0)
+ true_cte = true_cte.view(-1, self.config.max_len*self.config.n_embd)
+
+ # decoder forward
+ pred_set, pred_ids = self.model.module.decoder(true_cte)
+
+ # losses
+ pred_ids = pred_ids.view(-1, pred_ids.size(-1))
+ true_ids = idx_not_masked.view(-1)
+
+ error_ids = self.criterionC(pred_ids, true_ids) / len(bucket_idx_masked)
+ error_set = self.criterionR(pred_set, true_set) / len(bucket_idx_masked)
+ errorD_tmp = error_ids + error_set
+
+ if chunk < len(bucket_idx_masked)-1:
+ errorD_tmp.backward()
+ errorD += errorD_tmp.detach()
+ else:
+ errorD += errorD_tmp
+
+ if can_train_decoder:
+ errorD.backward()
+ self.optimD.step()
+ elif can_train_encoder:
+ errorE.backward()
+ self.optimE.step()
+
+ if self.local_rank == 0:
+ print(f'LossE: {errorE.item()} | LossD: {errorD.item()}')
+ return errorE, errorD
+
+
+class TrainerDirectDecoder(Trainer):
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ train_data: DataLoader,
+ optimizer: torch.optim.Optimizer,
+ save_every: int,
+ save_checkpoint_path: str,
+ load_checkpoint_path: str,
+ config,
+ ) -> None:
+ super().__init__(model, train_data, optimizer, save_every, save_checkpoint_path, load_checkpoint_path, config)
+ self.criterionC = nn.CrossEntropyLoss(ignore_index=-100)
+ self.criterionR = nn.MSELoss()
+
+ def custom(self, module):
+ def custom_forward(*inputs):
+ inputs = module(inputs[0], length_mask=inputs[1])
+ return inputs
+ return custom_forward
+
+ def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
+ padding_idx = 2
+ error = torch.zeros(1).to(self.local_rank)
+ error_tmp = .0
+ self.optimizer.zero_grad(set_to_none=True)
+
+ for chunk in range(len(bucket_idx_masked)):
+ idx_masked = bucket_idx_masked[chunk].to(self.local_rank)
+ targets = bucket_targets[chunk].to(self.local_rank)
+ idx_not_masked = bucket_idx_not_masked[chunk]
+ idx_not_masked = list(map(lambda x: F.pad(x, pad=(0, self.config.max_len - x.shape[0]), value=2).unsqueeze(0), idx_not_masked))
+ idx_not_masked = torch.cat(idx_not_masked, dim=0).to(self.local_rank)
+ mask = (idx_masked != padding_idx)
+
+ # encoder forward
+ x = self.model.module.encoder.tok_emb(idx_masked)
+ x = self.model.module.encoder.drop(x)
+ x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x, LengthMask(mask.sum(-1), max_len=idx_masked.shape[1]))
+
+ # mean pooling
+ input_masked_expanded = mask.unsqueeze(-1).expand(x.size()).float()
+ sum_embeddings = torch.sum(x*input_masked_expanded, 1)
+ sum_mask = torch.clamp(input_masked_expanded.sum(1), min=1e-9)
+ true_set = sum_embeddings/sum_mask
+ true_cte = x
+ del x
+ torch.cuda.empty_cache()
+
+ # add padding
+ input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float()
+ mask_embeddings = (true_cte * input_mask_expanded)
+ true_cte = F.pad(mask_embeddings, pad=(0, 0, 0, self.config.max_len - mask_embeddings.shape[1]), value=0)
+ true_cte = true_cte.view(-1, self.config.max_len*self.config.n_embd)
+
+ # decoder forward
+ pred_set, pred_ids = self.model.module.decoder(true_cte)
+
+ # losses
+ pred_ids = pred_ids.view(-1, pred_ids.size(-1))
+ true_ids = idx_not_masked.view(-1)
+
+ error_ids = self.criterionC(pred_ids, true_ids) / len(bucket_idx_masked)
+ error_set = self.criterionR(pred_set, true_set) / len(bucket_idx_masked)
+ error_tmp = error_ids + error_set
+
+ if chunk < len(bucket_idx_masked)-1:
+ error_tmp.backward()
+ error += error_tmp.detach()
+ else:
+ error += error_tmp
+
+ torch.cuda.empty_cache()
+
+ error.backward()
+ self.optimizer.step()
+
+ if self.local_rank == 0:
+ print(f'Loss: {error.item()}')
+ return error.item()
diff --git a/models/smi_ted/training/utils.py b/models/smi_ted/training/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0634255d5452391eeedecdbf007da92654fca8e
--- /dev/null
+++ b/models/smi_ted/training/utils.py
@@ -0,0 +1,96 @@
+# Deep learning
+import torch
+
+# Data
+from pubchem_encoder import Encoder
+from datasets import load_dataset
+
+# Standard library
+import os
+import getpass
+import glob
+
+
+class MoleculeModule:
+ def __init__(self, max_len, dataset, data_path):
+ super().__init__()
+ self.dataset = dataset
+ self.data_path = data_path
+ self.text_encoder = Encoder(max_len)
+
+ def prepare_data(self):
+ pass
+
+ def get_vocab(self):
+ #using home made tokenizer, should look into existing tokenizer
+ return self.text_encoder.char2id
+
+ def get_cache(self):
+ return self.cache_files
+
+ def setup(self, stage=None):
+ #using huggingface dataloader
+ # create cache in tmp directory of locale mabchine under the current users name to prevent locking issues
+ pubchem_path = {'train': self.data_path}
+ if 'canonical' in pubchem_path['train'].lower():
+ pubchem_script = './pubchem_canon_script.py'
+ else:
+ pubchem_script = './pubchem_script.py'
+ zinc_path = './data/ZINC'
+ global dataset_dict
+ if 'ZINC' in self.dataset or 'zinc' in self.dataset:
+ zinc_files = [f for f in glob.glob(os.path.join(zinc_path,'*.smi'))]
+ for zfile in zinc_files:
+ print(zfile)
+ self.dataset = {'train': zinc_files}
+ dataset_dict = load_dataset('./zinc_script.py', data_files=self.dataset, cache_dir=os.path.join('/tmp',getpass.getuser(), 'zinc'),split='train', trust_remote_code=True)
+
+ elif 'pubchem' in self.dataset:
+ dataset_dict = load_dataset(pubchem_script, data_files=pubchem_path, cache_dir=os.path.join('/tmp',getpass.getuser(), 'pubchem'), split='train')
+ elif 'both' in self.dataset or 'Both' in self.dataset or 'BOTH' in self.dataset:
+ dataset_dict_pubchem = load_dataset(pubchem_script, data_files=pubchem_path, cache_dir=os.path.join('/tmp',getpass.getuser(), 'pubchem'),split='train', trust_remote_code=True)
+ zinc_files = [f for f in glob.glob(os.path.join(zinc_path,'*.smi'))]
+ for zfile in zinc_files:
+ print(zfile)
+ self.dataset = {'train': zinc_files}
+ dataset_dict_zinc = load_dataset('./zinc_script.py', data_files=self.dataset, cache_dir=os.path.join('/tmp',getpass.getuser(), 'zinc'),split='train', trust_remote_code=True)
+ dataset_dict = concatenate_datasets([dataset_dict_zinc, dataset_dict_pubchem])
+ self.pubchem= dataset_dict
+ print(dataset_dict.cache_files)
+ self.cache_files = []
+
+ for cache in dataset_dict.cache_files:
+ tmp = '/'.join(cache['filename'].split('/')[:4])
+ self.cache_files.append(tmp)
+
+
+def get_optim_groups(module):
+ # setup optimizer
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear,)
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in module.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in module.named_parameters()}
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+ ]
+
+ return optim_groups
\ No newline at end of file
diff --git a/representation/.gitattributes b/representation/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..3307846513fc22ee6369f98b69102d089545bdd5
--- /dev/null
+++ b/representation/.gitattributes
@@ -0,0 +1,9 @@
+bace_mhg.pkl filter=lfs diff=lfs merge=lfs -text
+esol_mhg.pkl filter=lfs diff=lfs merge=lfs -text
+esol_mol-xl.pkl filter=lfs diff=lfs merge=lfs -text
+bace_smi-ted.pkl filter=lfs diff=lfs merge=lfs -text
+esol_bart.pkl filter=lfs diff=lfs merge=lfs -text
+esol_smi-ted.pkl filter=lfs diff=lfs merge=lfs -text
+bace_MorganFingerprint.pkl filter=lfs diff=lfs merge=lfs -text
+bace_bart.pkl filter=lfs diff=lfs merge=lfs -text
+bace_mol-xl.pkl filter=lfs diff=lfs merge=lfs -text
diff --git a/representation/bace_MorganFingerprint.pkl b/representation/bace_MorganFingerprint.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..99011cc1c280a96acbbcfebfa10bbddff207a6b7
--- /dev/null
+++ b/representation/bace_MorganFingerprint.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b64ea75a5ac268dd9fd69c5ec85dd32358a15e2f9069a78c0c3c3148cf54f257
+size 11161440
diff --git a/representation/esol_MorganFingerprint.pkl b/representation/esol_MorganFingerprint.pkl
new file mode 100755
index 0000000000000000000000000000000000000000..6b1d2fafdec4fe4c97710f8a7b395c03b42703c8
--- /dev/null
+++ b/representation/esol_MorganFingerprint.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f112772073dc4603631cf5845f9f13a10abb5c611687adde13a30a3537c00d3b
+size 7889672
diff --git a/requirements.txt b/requirements.txt
index 17053f0ff3b341f0fb5e925ef1da3112d2b84e6a..4f4bc906f62f3678e53f612462949f36e6b0e802 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,4 @@ umap-learn
torch-optimizer
tqdm>=4.66.4
pandas==2.2.3
+mordred