# ipd's picture
# edit mhg checkpoint
# 32b373a
# raw
# history blame
# 3.12 kB
# -*- coding:utf-8 -*-
# Rhizome
# Version beta 0.0, August 2023
# Property of IBM Research, Accelerated Discovery
#
import os
import pickle
import sys
from rdkit import Chem
import torch
from torch_geometric.utils.smiles import from_smiles
from typing import Any, Dict, List, Optional, Union
from typing_extensions import Self
from .graph_grammar.io.smi import hg_to_mol
from .models.mhgvae import GrammarGINVAE
from huggingface_hub import hf_hub_download
class PretrainedModelWrapper:
    """Inference wrapper around a ``GrammarGINVAE`` restored from a checkpoint dict.

    Exposes ``encode`` (SMILES strings -> latent tensors), ``decode``
    (latent tensors -> SMILES strings, or ``None`` when decoding fails) and
    ``to`` for moving the underlying model between devices.
    """

    # Quoted so the annotation is not evaluated while the class body executes.
    model: "GrammarGINVAE"

    def __init__(self, model_dict: Dict[str, Any]) -> None:
        """Build the model from ``model_dict`` and load its trained weights.

        Args:
            model_dict: Checkpoint dictionary. The keys read here are
                ``'gnn_params'`` (itself holding ``'encoder_params'``,
                ``'decoder_params'`` and ``'prod_rule_embed_params'``),
                ``'num_features'``, ``'num_edge_features'``, ``'hrg'``,
                ``'max_length'`` and ``'model_state_dict'``.
        """
        json_params = model_dict['gnn_params']
        # Work on a shallow copy so the caller's checkpoint dict is not
        # modified when the feature sizes are injected below.
        encoder_params = dict(json_params['encoder_params'])
        encoder_params['node_feature_size'] = model_dict['num_features']
        encoder_params['edge_feature_size'] = model_dict['num_edge_features']
        # rank and batch_size are fixed here; their meaning is defined by
        # GrammarGINVAE, which is not visible in this file.
        self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
                                   decoder_params=json_params['decoder_params'],
                                   prod_rule_embed_params=json_params["prod_rule_embed_params"],
                                   batch_size=512, max_len=model_dict['max_length'])
        self.model.load_state_dict(model_dict['model_state_dict'])
        # The wrapper is used for inference only, so switch to eval mode.
        self.model.eval()

    def to(self, device: Union[str, int, torch.device]) -> Self:
        """Move the wrapped model to ``device`` and return ``self``.

        Args:
            device: A ``torch.device``, a device string, or an integer index.
                A string is passed to ``torch.device`` as is. An integer is
                passed to ``torch.device`` directly when CUDA is available;
                otherwise it is used as an MPS index when MPS is available,
                and the model falls back to the CPU when neither exists.

        Returns:
            This wrapper, so calls can be chained.
        """
        if not isinstance(device, torch.device):
            if isinstance(device, str) or torch.cuda.is_available():
                device = torch.device(device)
            else:
                # Integer index without CUDA: prefer MPS, else run on the CPU
                # rather than failing on a backend that is not present.
                mps_backend = getattr(torch.backends, "mps", None)
                if mps_backend is not None and mps_backend.is_available():
                    device = torch.device("mps", device)
                else:
                    device = torch.device("cpu")
        self.model = self.model.to(device)
        return self

    def encode(self, data: List[str]) -> List[torch.Tensor]:
        """Embed each SMILES string with the model's graph encoder.

        Args:
            data: SMILES strings; each one is converted to a graph with
                ``from_smiles``.

        Returns:
            One entry per input, in order: the first element of what
            ``self.model.graph_embed`` returns for that graph.
        """
        output: List[torch.Tensor] = []
        if not data:
            return output
        # The graph must live on the same device as the model weights; the
        # device is the same for every input, so look it up once.
        device = next(self.model.parameters()).device
        for smiles in data:
            graph = from_smiles(smiles).to(device)
            ltvec = self.model.graph_embed(graph.x, graph.edge_index, graph.edge_attr, graph.batch)
            output.append(ltvec[0])
        return output

    def decode(self, data: List[torch.Tensor]) -> List[Optional[str]]:
        """Decode latent tensors back into SMILES strings.

        Args:
            data: Latent tensors; each is given a leading batch dimension of
                size one before being passed to the model.

        Returns:
            One entry per input, in order: the SMILES string of the decoded
            molecule, or ``None`` when the model's decode flag for that input
            is falsy.
        """
        output: List[Optional[str]] = []
        for latent in data:
            mu, logvar = self.model.get_mean_var(latent.unsqueeze(0))
            # NOTE(review): reparameterize presumably samples around mu, which
            # would make decoding non-deterministic -- confirm in GrammarGINVAE.
            z = self.model.reparameterize(mu, logvar)
            flags, _, hgs = self.model.decode(z)
            if flags[0]:
                reconstructed_mol, _ = hg_to_mol(hgs[0], True)
                output.append(Chem.MolToSmiles(reconstructed_mol))
            else:
                output.append(None)
        return output
def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
        PretrainedModelWrapper]:
    """Download the pretrained checkpoint from the Hugging Face Hub and wrap it.

    Args:
        model_name: Currently unused; the checkpoint is always fetched from
            the hard-coded Hub repository and filename below. The parameter
            is kept so existing callers keep working.

    Returns:
        A ``PretrainedModelWrapper`` built from the downloaded checkpoint.
    """
    repo_id = "ibm/materials.mhg-ged"
    # Hub filename of the checkpoint (previously
    # "mhggnn_pretrained_model_0724_2023.pickle").
    filename = "pytorch_model.bin"
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "rb") as f:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from a trusted repository.
        model_dict = torch.load(f)
    return PretrainedModelWrapper(model_dict)