File size: 3,101 Bytes
85ec4af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# -*- coding:utf-8 -*-
# Rhizome
# Version beta 0.0, August 2023
# Property of IBM Research, Accelerated Discovery
#

import os
import pickle
import sys

from rdkit import Chem
import torch
from torch_geometric.utils.smiles import from_smiles

from typing import Any, Dict, List, Optional, Union
from typing_extensions import Self

from .graph_grammar.io.smi import hg_to_mol
from .models.mhgvae import GrammarGINVAE
from huggingface_hub import hf_hub_download


class PretrainedModelWrapper:
    """Inference wrapper around a pretrained ``GrammarGINVAE`` model.

    The wrapper rebuilds the network from a checkpoint dictionary, switches it
    to evaluation mode, and exposes helpers to move it between devices, embed
    SMILES strings, and decode embeddings back into SMILES strings.
    """

    model: GrammarGINVAE

    def __init__(self, model_dict: Dict[str, Any]) -> None:
        """Build the model from a checkpoint dictionary and load its weights.

        Args:
            model_dict: Checkpoint contents. The keys read here are
                ``'gnn_params'`` (holding ``'encoder_params'``,
                ``'decoder_params'`` and ``'prod_rule_embed_params'``),
                ``'num_features'``, ``'num_edge_features'``, ``'hrg'``,
                ``'max_length'`` and ``'model_state_dict'``.
        """
        json_params = model_dict['gnn_params']
        # Shallow-copy before adding the feature sizes so the caller's
        # checkpoint dictionary is not modified as a side effect.
        encoder_params = dict(json_params['encoder_params'])
        encoder_params['node_feature_size'] = model_dict['num_features']
        encoder_params['edge_feature_size'] = model_dict['num_edge_features']
        self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
                                   decoder_params=json_params['decoder_params'],
                                   prod_rule_embed_params=json_params["prod_rule_embed_params"],
                                   batch_size=512, max_len=model_dict['max_length'])
        self.model.load_state_dict(model_dict['model_state_dict'])

        # The wrapper is used for inference only.
        self.model.eval()

    def to(self, device: Union[str, int, torch.device]) -> Self:
        """Move the wrapped model to ``device`` and return the wrapper itself.

        Args:
            device: A ``torch.device``, a device string, or an integer device
                index. A string is passed to ``torch.device`` as is. An integer
                is passed to ``torch.device`` directly when CUDA is available,
                and is otherwise used as an index on the ``"mps"`` device type.

        Returns:
            This wrapper, so calls can be chained.
        """
        if not isinstance(device, torch.device):
            if isinstance(device, str) or torch.cuda.is_available():
                device = torch.device(device)
            else:
                # Integer index on a machine without CUDA.
                device = torch.device("mps", device)

        self.model = self.model.to(device)
        return self

    def encode(self, data: List[str]) -> List[torch.Tensor]:
        """Embed each SMILES string with the model's graph encoder.

        Args:
            data: SMILES strings to embed.

        Returns:
            One entry per input string: element ``0`` of whatever
            ``self.model.graph_embed`` returns for that molecule's graph
            (presumably the embedding of the single graph; confirm against
            ``GrammarGINVAE.graph_embed``).
        """
        # The graph must live on the same device as the model weights; look
        # the device up once instead of once per molecule.
        device = next(self.model.parameters()).device
        output = []
        for smiles in data:
            # Convert the SMILES string to a graph and place it on the
            # model's device (a no-op when it is already there).
            graph = from_smiles(smiles).to(device)
            ltvec = self.model.graph_embed(graph.x, graph.edge_index, graph.edge_attr, graph.batch)
            output.append(ltvec[0])
        return output

    def decode(self, data: List[torch.Tensor]) -> List[Optional[str]]:
        """Decode embeddings back into SMILES strings.

        Args:
            data: Embeddings, one tensor per molecule (for example the values
                returned by ``encode``).

        Returns:
            One entry per input tensor: the SMILES string of the reconstructed
            molecule, or ``None`` when the model flags the decoding as
            unsuccessful.
        """
        output = []
        for d in data:
            # Add a leading batch dimension of size one before calling the model.
            mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
            z = self.model.reparameterize(mu, logvar)
            flags, _, hgs = self.model.decode(z)
            if flags[0]:
                reconstructed_mol, _ = hg_to_mol(hgs[0], True)
                output.append(Chem.MolToSmiles(reconstructed_mol))
            else:
                output.append(None)
        return output


def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
    PretrainedModelWrapper]:
    """Download the pretrained checkpoint from the Hugging Face Hub and wrap it.

    Args:
        model_name: Currently unused. The checkpoint is always fetched from
            the ``ibm/materials.mhg-ged`` repository under a fixed file name;
            the parameter is kept so existing callers keep working.

    Returns:
        A ``PretrainedModelWrapper`` built from the downloaded checkpoint.
        The ``Optional`` in the signature is kept for backward compatibility;
        this implementation never returns ``None``. Download or
        deserialization failures propagate as exceptions.
    """
    repo_id = "ibm/materials.mhg-ged"
    filename = "mhggnn_pretrained_model_0724_2023.pickle"
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY: pickle.load can execute arbitrary code while deserializing.
    # This is only acceptable because the file comes from a fixed repository;
    # do not point it at untrusted sources.
    with open(file_path, "rb") as f:
        model_dict = pickle.load(f)
    # Build the model after the file handle has been released.
    return PretrainedModelWrapper(model_dict)