#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Rhizome
# Version beta 0.0, August 2023
# Property of IBM Research, Accelerated Discovery
#
""" | |
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) | |
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. | |
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. | |
""" | |
""" Title """ | |
__author__ = "Hiroshi Kajino <[email protected]>" | |
__copyright__ = "(c) Copyright IBM Corp. 2018" | |
__version__ = "0.1" | |
__date__ = "Jan 1 2018" | |
import numpy as np
import torch
import torch.nn.functional as F
from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
from torch import nn
from torch.autograd import Variable


class MolecularProdRuleEmbedding(nn.Module):
    ''' molecular fingerprint layer
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, element_embed_dim=32,
                 num_layers=3, padding_idx=None, use_gpu=False):
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.element_embed_dim = element_embed_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu
        # ModuleList registers the per-layer linear maps as submodules
        self.layer2layer_list = nn.ModuleList()
        self.layer2out_list = nn.ModuleList()
        if self.use_gpu:
            # nn.Parameter makes the random element embeddings trainable leaf tensors
            self.atom_embed = nn.Parameter(torch.randn(self.prod_rule_corpus.num_edge_symbol,
                                                       self.element_embed_dim).cuda())
            self.bond_embed = nn.Parameter(torch.randn(self.prod_rule_corpus.num_node_symbol,
                                                       self.element_embed_dim).cuda())
            self.ext_id_embed = nn.Parameter(torch.randn(self.prod_rule_corpus.num_ext_id,
                                                         self.element_embed_dim).cuda())
            for _ in range(num_layers):
                self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
                self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
        else:
            self.atom_embed = nn.Parameter(torch.randn(self.prod_rule_corpus.num_edge_symbol,
                                                       self.element_embed_dim))
            self.bond_embed = nn.Parameter(torch.randn(self.prod_rule_corpus.num_node_symbol,
                                                       self.element_embed_dim))
            self.ext_id_embed = nn.Parameter(torch.randn(self.prod_rule_corpus.num_ext_id,
                                                         self.element_embed_dim))
            for _ in range(num_layers):
                self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
                self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)

        Returns
        -------
        Variable, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        if self.use_gpu:
            out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
        else:
            out = Variable(torch.zeros((batch_size, length, self.out_dim)))
        for each_batch_idx in range(batch_size):
            for each_idx in range(length):
                if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
                    continue
                else:
                    each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
                    layer_wise_embed_dict = {each_edge: self.atom_embed[
                        each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
                                             for each_edge in each_prod_rule.rhs.edges}
                    layer_wise_embed_dict.update({each_node: self.bond_embed[
                        each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
                                                  for each_node in each_prod_rule.rhs.nodes})
                    for each_node in each_prod_rule.rhs.nodes:
                        if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
                            layer_wise_embed_dict[each_node] \
                                = layer_wise_embed_dict[each_node] \
                                + self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]
                    for each_layer in range(self.num_layers):
                        next_layer_embed_dict = {}
                        for each_edge in each_prod_rule.rhs.edges:
                            v = layer_wise_embed_dict[each_edge]
                            for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
                                v = v + layer_wise_embed_dict[each_node]
                            next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                            out[each_batch_idx, each_idx, :] \
                                = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
                        for each_node in each_prod_rule.rhs.nodes:
                            v = layer_wise_embed_dict[each_node]
                            for each_edge in each_prod_rule.rhs.adj_edges(each_node):
                                v = v + layer_wise_embed_dict[each_edge]
                            next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                            out[each_batch_idx, each_idx, :] \
                                = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
                        layer_wise_embed_dict = next_layer_embed_dict
        return out
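

# --- Usage sketch ------------------------------------------------------------
# A minimal, hypothetical example of wiring up MolecularProdRuleEmbedding; it
# assumes a `prod_rule_corpus` (ProductionRuleCorpus) built elsewhere by the
# MHG pipeline and uses plain torch activations as placeholder choices.
#
#     embed = MolecularProdRuleEmbedding(prod_rule_corpus,
#                                        layer2layer_activation=torch.relu,
#                                        layer2out_activation=torch.tanh,
#                                        out_dim=32, element_embed_dim=32,
#                                        num_layers=3, use_gpu=False)
#     # prod_rule_idx_seq: LongTensor of rule indices, shape (batch_size, length);
#     # the index len(prod_rule_corpus.prod_rule_list) is treated as padding.
#     fingerprints = embed(prod_rule_idx_seq)  # -> (batch_size, length, out_dim)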


class MolecularProdRuleEmbeddingLastLayer(nn.Module):
    ''' molecular fingerprint layer
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, element_embed_dim=32,
                 num_layers=3, padding_idx=None, use_gpu=False):
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.element_embed_dim = element_embed_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu
        # ModuleList registers the per-layer linear maps as submodules
        self.layer2layer_list = nn.ModuleList()
        self.layer2out_list = nn.ModuleList()
        if self.use_gpu:
            self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda()
            self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda()
            for _ in range(num_layers + 1):
                self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
                self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
        else:
            self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
            self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)
            for _ in range(num_layers + 1):
                self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
                self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)

        Returns
        -------
        Variable, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        if self.use_gpu:
            out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
        else:
            out = Variable(torch.zeros((batch_size, length, self.out_dim)))
        for each_batch_idx in range(batch_size):
            for each_idx in range(length):
                if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
                    continue
                else:
                    each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
                    if self.use_gpu:
                        layer_wise_embed_dict = {each_edge: self.atom_embed(
                            Variable(torch.LongTensor(
                                [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
                            ), requires_grad=False).cuda())
                                                 for each_edge in each_prod_rule.rhs.edges}
                        layer_wise_embed_dict.update({each_node: self.bond_embed(
                            Variable(
                                torch.LongTensor([
                                    each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
                                requires_grad=False).cuda()
                        ) for each_node in each_prod_rule.rhs.nodes})
                    else:
                        layer_wise_embed_dict = {each_edge: self.atom_embed(
                            Variable(torch.LongTensor(
                                [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
                            ), requires_grad=False))
                                                 for each_edge in each_prod_rule.rhs.edges}
                        layer_wise_embed_dict.update({each_node: self.bond_embed(
                            Variable(
                                torch.LongTensor([
                                    each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
                                requires_grad=False)
                        ) for each_node in each_prod_rule.rhs.nodes})
                    for each_layer in range(self.num_layers):
                        next_layer_embed_dict = {}
                        for each_edge in each_prod_rule.rhs.edges:
                            v = layer_wise_embed_dict[each_edge]
                            for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
                                # out-of-place addition; += would mutate the tensor stored in the dict
                                v = v + layer_wise_embed_dict[each_node]
                            next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        for each_node in each_prod_rule.rhs.nodes:
                            v = layer_wise_embed_dict[each_node]
                            for each_edge in each_prod_rule.rhs.adj_edges(each_node):
                                v = v + layer_wise_embed_dict[each_edge]
                            next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        layer_wise_embed_dict = next_layer_embed_dict
                    # read out the fingerprint from the final-layer embeddings of all edges and nodes
                    for each_edge in each_prod_rule.rhs.edges:
                        out[each_batch_idx, each_idx, :] \
                            = out[each_batch_idx, each_idx, :] \
                            + self.layer2out_activation(self.layer2out_list[self.num_layers](layer_wise_embed_dict[each_edge]))
                    for each_node in each_prod_rule.rhs.nodes:
                        out[each_batch_idx, each_idx, :] \
                            = out[each_batch_idx, each_idx, :] \
                            + self.layer2out_activation(self.layer2out_list[self.num_layers](layer_wise_embed_dict[each_node]))
        return out
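

# --- Usage sketch ------------------------------------------------------------
# A minimal, hypothetical example for MolecularProdRuleEmbeddingLastLayer; unlike
# MolecularProdRuleEmbedding, it reads the fingerprint out only from the final
# message-passing layer. `prod_rule_corpus` is again assumed to come from the
# MHG pipeline, and the activations are placeholder choices.
#
#     embed = MolecularProdRuleEmbeddingLastLayer(prod_rule_corpus,
#                                                 layer2layer_activation=torch.relu,
#                                                 layer2out_activation=torch.tanh,
#                                                 out_dim=32, element_embed_dim=32,
#                                                 num_layers=3, use_gpu=False)
#     fingerprints = embed(prod_rule_idx_seq)  # -> (batch_size, length, out_dim)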


class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):
    ''' molecular fingerprint layer
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu
        # ModuleList registers the per-layer linear maps as submodules
        self.layer2layer_list = nn.ModuleList()
        self.layer2out_list = nn.ModuleList()
        if self.use_gpu:
            for each_key in self.feature_dict:
                self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda()
            for _ in range(num_layers):
                self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim).cuda())
                self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim).cuda())
        else:
            for _ in range(num_layers):
                self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim))
                self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim))

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)

        Returns
        -------
        Variable, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        if self.use_gpu:
            out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
        else:
            out = Variable(torch.zeros((batch_size, length, self.out_dim)))
        for each_batch_idx in range(batch_size):
            for each_idx in range(length):
                if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
                    continue
                else:
                    each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
                    edge_list = sorted(list(each_prod_rule.rhs.edges))
                    node_list = sorted(list(each_prod_rule.rhs.nodes))
                    adj_mat = torch.FloatTensor(
                        each_prod_rule.rhs_adj_mat(edge_list + node_list).todense()
                        + np.identity(len(edge_list) + len(node_list)))
                    if self.use_gpu:
                        adj_mat = adj_mat.cuda()
                    layer_wise_embed = [
                        self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']]
                        for each_edge in edge_list] \
                        + [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']]
                           for each_node in node_list]
                    for each_node in each_prod_rule.ext_node.values():
                        layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
                            = layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
                            + self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])]
                    layer_wise_embed = torch.stack(layer_wise_embed)
                    for each_layer in range(self.num_layers):
                        message = adj_mat @ layer_wise_embed
                        next_layer_embed = self.layer2layer_activation(self.layer2layer_list[each_layer](message))
                        out[each_batch_idx, each_idx, :] \
                            = out[each_batch_idx, each_idx, :] \
                            + self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0)
                        layer_wise_embed = next_layer_embed
        return out
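

# --- Usage sketch ------------------------------------------------------------
# A minimal, hypothetical example for MolecularProdRuleEmbeddingUsingFeatures,
# which embeds each rule via the feature vectors returned by
# prod_rule_corpus.construct_feature_vectors() rather than learned element
# embeddings. The corpus and activations are assumed to be provided by the
# surrounding MHG code.
#
#     embed = MolecularProdRuleEmbeddingUsingFeatures(prod_rule_corpus,
#                                                     layer2layer_activation=torch.relu,
#                                                     layer2out_activation=torch.tanh,
#                                                     out_dim=32, num_layers=3,
#                                                     use_gpu=False)
#     fingerprints = embed(prod_rule_idx_seq)  # -> (batch_size, length, out_dim)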