File size: 1,736 Bytes
6cf191b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

#############################
#   Imports
#############################

# Python modules
from collections import deque
from ast import literal_eval

# Remote modules
import torch

# Local modules

#############################
#   Constants
#############################

##########################################################
#   Helper functions for Relations in dict format
##########################################################

def clean_relations(word_relations):
    """Parse the stringified keys of two-level relation dicts back into Python objects.

    Each element of ``word_relations`` is a dict whose outer keys, and the keys
    of each inner dict, are strings holding Python literals (e.g. ``"(0, 1)"``).
    Presumably these come from serializing tuple-keyed dicts to a string-keyed
    format such as JSON; confirm with the caller. Every such key is parsed with
    ``ast.literal_eval``. The innermost values are kept unchanged, and key
    order is preserved.

    ``literal_eval`` only accepts Python literals, so parsing the keys cannot
    execute arbitrary code.

    Args:
        word_relations: Iterable of ``{str: {str: value}}`` dicts.

    Returns:
        A new list, in the same order as the input, of
        ``{parsed_key: {parsed_key: value}}`` dicts.

    Raises:
        ValueError / SyntaxError: propagated from ``ast.literal_eval`` when a
        key is not a valid Python literal.
    """
    return [
        {
            literal_eval(outer_key): {
                literal_eval(inner_key): inner_value
                for inner_key, inner_value in inner.items()
            }
            for outer_key, inner in relation.items()
        }
        for relation in word_relations
    ]

##########################################################
#   Helper functions for Relations in Matrix format
##########################################################

def relation_binary_2d_to_1d(relations_binary_mask, dim=1):
    """Collapse a relation mask along ``dim`` into a 0/1 indicator.

    Entries along ``dim`` are summed, and any count above 1 is capped at 1.
    The input tensor itself is not modified; the capping is applied to the
    freshly computed sum.

    Args:
        relations_binary_mask: Tensor to reduce (assumed to hold 0/1 or bool
            values; TODO confirm with callers).
        dim: Dimension to sum over (default 1).

    Returns:
        The reduced tensor, with the dtype ``sum`` produces and values > 1
        replaced by 1.
    """
    summed = relations_binary_mask.sum(dim=dim)
    return summed.masked_fill(summed > 1, 1)

def tokens_with_relations(relations_binary_mask):
    """Return a bool mask of indices that appear in any relation.

    Index ``i`` is marked True when its row (sum over dim 1) or its column
    (sum over dim 0) has a nonzero total in ``relations_binary_mask``.

    The two axis sums are added elementwise, so the input is expected to be a
    square 2-D matrix (presumably token x token; confirm with callers). For
    other shapes the sums must be broadcast-compatible.

    Args:
        relations_binary_mask: 2-D tensor of relation indicators (assumed
            0/1 or bool values).

    Returns:
        1-D ``torch.bool`` tensor on the same device as the input.
    """
    per_column = relations_binary_mask.sum(dim=0)
    per_row = relations_binary_mask.sum(dim=1)
    # Converting straight to bool maps every nonzero count to True, so no
    # separate "cap at 1" step is needed. Using ``.bool()`` instead of
    # ``torch.tensor(existing_tensor)`` avoids PyTorch's copy-construct
    # UserWarning and an extra copy.
    return (per_column + per_row).bool()