#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Rhizome
# Version beta 0.0, August 2023
# Property of IBM Research, Accelerated Discovery
#
"""
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
"""
""" Title """

__author__ = "Hiroshi Kajino <[email protected]>"
__copyright__ = "(c) Copyright IBM Corp. 2018"
__version__ = "0.1"
__date__ = "Apr 18 2018"
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
    ''' pad each sequence on the left

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its left part is padded.
    pad_idx : int
        integer used for padding
    inverse : bool
        if True, each sequence is reversed before padding.

    Returns
    -------
    List of torch.LongTensor
        each sentence is left-padded to length `max_len`.
    '''
    max_in_list = max(len(each_sen) for each_sen in sentence_list)
    if max_in_list > max_len:
        raise ValueError('`max_len` must be at least the maximum length of the input sequences, {}.'.format(max_in_list))
    if inverse:
        return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen[::-1])
                for each_sen in sentence_list]
    else:
        return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen)
                for each_sen in sentence_list]
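
# Illustrative usage sketch (not part of the original module; the toy sequences
# and `max_len` below are assumptions chosen only for demonstration).
def _left_padding_example():
    # Two variable-length sequences are left-padded to a common length of 5.
    padded = left_padding([[1, 2, 3], [4, 5]], max_len=5, pad_idx=-1)
    # padded[0] -> tensor([-1, -1, 1, 2, 3]); padded[1] -> tensor([-1, -1, -1, 4, 5])
    # With inverse=True, each sequence is reversed before the padding is applied.
    reversed_padded = left_padding([[1, 2, 3]], max_len=5, pad_idx=-1, inverse=True)
    # reversed_padded[0] -> tensor([-1, -1, 3, 2, 1])
    return padded, reversed_padded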

def right_padding(sentence_list, max_len, pad_idx=-1):
    ''' pad each sequence on the right

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its right part is padded.
    pad_idx : int
        integer used for padding

    Returns
    -------
    List of torch.LongTensor
        each sentence is right-padded to length `max_len`.
    '''
    max_in_list = max(len(each_sen) for each_sen in sentence_list)
    if max_in_list > max_len:
        raise ValueError('`max_len` must be at least the maximum length of the input sequences, {}.'.format(max_in_list))
    return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen)))
            for each_sen in sentence_list]
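
# Illustrative usage sketch (not part of the original module; the toy sequences
# below are assumptions chosen only for demonstration).
def _right_padding_example():
    # The same kind of sequences, padded on the right instead of the left.
    padded = right_padding([[1, 2, 3], [4, 5]], max_len=5, pad_idx=-1)
    # padded[0] -> tensor([1, 2, 3, -1, -1]); padded[1] -> tensor([4, 5, -1, -1, -1])
    return padded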

class HRGDataset(Dataset):
    '''
    A Dataset of HRG production-rule sequences, stored both left- and right-padded.
    '''

    def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
        self.hrg = hrg
        self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
                                                    max_len,
                                                    inverse=inversed_input)
        self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
        self.inversed_input = inversed_input
        self.target_val_list = target_val_list
        if target_val_list is not None:
            if len(prod_rule_seq_list) != len(target_val_list):
                raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: '
                                 f'{len(prod_rule_seq_list)}, {len(target_val_list)}')

    def __len__(self):
        return len(self.left_prod_rule_seq_list)

    def __getitem__(self, idx):
        if self.target_val_list is not None:
            return (self.left_prod_rule_seq_list[idx],
                    self.right_prod_rule_seq_list[idx],
                    np.float32(self.target_val_list[idx]))
        else:
            return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]

    def vocab_size(self):
        return self.hrg.num_prod_rule
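
# Illustrative usage sketch (not part of the original module). `hrg` stands in for
# an HRG object exposing `num_prod_rule`; the rule indices, `max_len`, target values,
# and batch size are assumptions chosen only for demonstration.
def _hrg_dataset_example(hrg):
    prod_rule_seq_list = [[0, 3, 1], [2, 4], [1, 0, 2, 3]]
    target_val_list = [0.5, 1.2, -0.3]
    dataset = HRGDataset(hrg, prod_rule_seq_list, max_len=6,
                         target_val_list=target_val_list)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for left_seq, right_seq, target in loader:
        # left_seq and right_seq are LongTensors of shape (batch_size, max_len);
        # target is a float32 tensor of shape (batch_size,).
        pass
    return dataset.vocab_size()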

def batch_padding(each_batch, batch_size, padding_idx):
    # Pad an incomplete batch (e.g., the last batch of an epoch) with dummy rows
    # filled with `padding_idx` so that it contains exactly `batch_size` examples.
    # Returns the padded batch together with the number of dummy rows added.
    num_pad = batch_size - len(each_batch[0])
    if num_pad:
        each_batch[0] = torch.cat([each_batch[0],
                                   padding_idx * torch.ones((num_pad, len(each_batch[0][0])),
                                                            dtype=torch.int64)], dim=0)
        each_batch[1] = torch.cat([each_batch[1],
                                   padding_idx * torch.ones((num_pad, len(each_batch[1][0])),
                                                            dtype=torch.int64)], dim=0)
    return each_batch, num_pad
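
# Illustrative usage sketch (not part of the original module; the batch shapes and
# `padding_idx` below are assumptions chosen only for demonstration).
def _batch_padding_example():
    # A "short" final batch of 3 padded sequences when the model expects 4.
    left = torch.zeros((3, 6), dtype=torch.int64)
    right = torch.zeros((3, 6), dtype=torch.int64)
    padded_batch, num_pad = batch_padding([left, right], batch_size=4, padding_idx=-1)
    # padded_batch[0].shape == (4, 6); the appended row is filled with -1; num_pad == 1.
    return padded_batch, num_pad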