File size: 1,736 Bytes
6cf191b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
#############################
# Imports
#############################
# Python modules
from collections import deque
from ast import literal_eval
# Remote modules
import torch
# Local modules
#############################
# Constants
#############################
##########################################################
# Helper functions for Relations in dict format
##########################################################
def clean_relations(word_relations):
    """Rebuild relation dicts with their string keys parsed into Python literals.

    Every element of ``word_relations`` is treated as a two-level mapping
    ``{outer_key: {inner_key: value}}`` whose outer and inner keys are
    strings holding a Python literal. Both key levels are parsed with
    ``ast.literal_eval``; the innermost values are passed through untouched.

    NOTE(review): the keys are presumably stringified tuples/ints coming from
    a serialised (e.g. JSON) form of the relations -- confirm with the caller.

    Args:
        word_relations: Iterable of ``{str: {str: value}}`` mappings.

    Returns:
        A new ``list`` containing one converted dict per input element, in
        the same order. The input mappings are not modified.

    Raises:
        ValueError, SyntaxError: Propagated from ``literal_eval`` when a key
            does not contain a valid Python literal.
    """
    # Built directly as a list: no intermediate deque is needed because
    # elements are only ever appended at the end and a list is returned.
    return [
        {
            literal_eval(outer_key): {
                literal_eval(inner_key): value
                for inner_key, value in inner.items()
            }
            for outer_key, inner in relation.items()
        }
        for relation in word_relations
    ]
##########################################################
# Helper functions for Relations in Matrix format
##########################################################
def relation_binary_2d_to_1d(relations_binary_mask, dim=1):
    """Collapse a relation mask along ``dim`` and cap the result at 1.

    The mask is summed over ``dim`` and every entry of the sum that is
    greater than 1 is replaced by 1. The tensor passed in is not modified,
    because ``sum`` produces a new tensor.

    NOTE(review): presumably the input is a 0/1 (tokens x tokens) relation
    matrix, so the output marks which rows/columns hold any relation --
    confirm with the caller.

    Args:
        relations_binary_mask: Tensor to reduce.
        dim: Dimension to sum over (defaults to 1).

    Returns:
        The summed tensor with all values above 1 set to 1.
    """
    collapsed = relations_binary_mask.sum(dim=dim)
    # Upper-bound the counts at 1; values of 1 or less are left untouched.
    return collapsed.clamp(max=1)
def tokens_with_relations(relations_binary_mask):
    """Return a boolean mask flagging positions whose row or column is non-empty.

    The mask is summed over dim 0 and over dim 1, the two sums are added
    element-wise, and every non-zero total becomes ``True``.

    NOTE(review): the addition of the two sums assumes a square 2-D
    (tokens x tokens) mask so both have the same shape -- confirm with the
    caller.

    Args:
        relations_binary_mask: Relation mask tensor.

    Returns:
        A new ``torch.bool`` tensor that is ``True`` wherever the combined
        dim-0 and dim-1 sum is non-zero.
    """
    per_column = relations_binary_mask.sum(dim=0)
    per_row = relations_binary_mask.sum(dim=1)
    # .bool() maps every non-zero total to True, so no explicit capping at 1
    # is required. It also avoids torch.tensor(<tensor>), which
    # copy-constructs from a tensor and emits a UserWarning on every call.
    return (per_column + per_row).bool()
|