|
"""
|
|
Adapted from comm2multilabel.py from the Bert-for-FrameNet project (https://gitlab.com/gosseminnema/bert-for-framenet)
|
|
"""
|
|
|
|
import dataclasses
|
|
import json
|
|
import os
|
|
import glob
|
|
import sys
|
|
from collections import defaultdict
|
|
from typing import List, Optional
|
|
|
|
import nltk
|
|
from concrete import Communication
|
|
from concrete.util import read_communication_from_file, lun, get_tokens
|
|
|
|
|
|
@dataclasses.dataclass
class FrameAnnotation:
    """A single sentence: one surface token and one POS tag per position."""

    # Surface tokens of the sentence, in order.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # POS tag for each token (parallel to `tokens`).
    pos: List[str] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """A sentence whose tokens each carry zero or more frame labels plus an
    optional lexical unit.

    Textual format (one line per token, see `to_txt`/`from_txt`):
        <token> <pos> <frame labels joined by '|', or '_' if none> <LU or '_'>
    """

    # Frame labels per token; an empty inner list means "no labels".
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # Lexical unit per token, or None where no LU applies.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one whitespace-separated 4-column line per token.

        Empty label lists and missing LUs are rendered as '_'.
        """
        for i, tok in enumerate(self.tokens):
            yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}"

    @staticmethod
    def from_txt(sentence_lines):
        """Parse lines produced by `to_txt` back into a MultiLabelAnnotation.

        Lines starting with a space are skipped (treated as non-token lines),
        as are blank lines — `to_txt` consumers typically write a blank line
        between sentences, and the original code crashed on those with an
        IndexError.
        """
        tokens = []
        pos = []
        frame_list = []
        lu_list = []
        for line in sentence_lines:
            if line.startswith(" "):
                continue
            # BUGFIX: skip blank/whitespace-only separator lines; previously
            # "".split() yielded [] and columns[0] raised IndexError.
            if not line.strip():
                continue

            columns = line.split()
            tokens.append(columns[0])
            pos.append(columns[1])

            # '_' is the sentinel for "no frame labels" in column 3.
            if columns[2] == "_":
                frame_list.append([])
            else:
                frame_list.append(columns[2].split("|"))

            # '_' is the sentinel for "no lexical unit" in column 4.
            if columns[3] == "_":
                lu_list.append(None)
            else:
                lu_list.append(columns[3])
        return MultiLabelAnnotation(tokens, pos, frame_list, lu_list)

    def get_label_set(self):
        """Return the set of all frame labels used anywhere in this sentence."""
        label_set = set()
        for tok_labels in self.frame_list:
            for label in tok_labels:
                label_set.add(label)
        return label_set
|
|
|
|
|
|
def convert_file(file, language="english", confidence_filter=0.0):
    """Convert a Concrete communication file into multi-label annotations.

    Reads the communication at `file`, maps each sentence's tokenization to
    its situation mentions, and yields one MultiLabelAnnotation per sentence.
    Arguments/frames with confidence below `confidence_filter` are dropped
    (see map_tokens_to_annotations). POS tags are not available here, so the
    POS column is filled with "_" and the LU column with None.
    """
    print("Reading input file...")
    communication = read_communication_from_file(file)

    print("Mapping sentences to situations...")
    sent_to_situations = map_sent_to_situation(communication)

    print("# sentences with situations:", len(sent_to_situations))

    for section in lun(communication.sectionList):
        for sentence in lun(section.sentenceList):
            concrete_tokens = get_tokens(sentence.tokenization)
            situation_uuids = sent_to_situations[sentence.tokenization.uuid.uuidString]
            annos_by_token = map_tokens_to_annotations(communication, situation_uuids, confidence_filter)

            frames, toks = prepare_ml_lists(language, annos_by_token, concrete_tokens)

            yield MultiLabelAnnotation(
                tokens=toks,
                pos=["_"] * len(toks),
                frame_list=frames,
                lu_list=[None] * len(toks),
            )
|
|
|
|
|
|
def prepare_ml_lists(language, tok_to_annos, tokens):
    """Build parallel (frame_list, tok_list) for one sentence.

    Re-tokenizes each Concrete token with NLTK (one Concrete token may split
    into several word pieces), copies that token's annotations onto every
    resulting piece, then runs a cleanup pass so punctuation pieces do not
    carry labels of spans they merely border.

    :param language: language name passed to nltk.word_tokenize
    :param tok_to_annos: dict-like mapping original token index -> list of
        annotation strings ("T:...", "B:...", "I:..." — see
        map_tokens_to_annotations for the format)
    :param tokens: Concrete token objects; only `.text` is read
    :return: (frame_list, tok_list) — note the order: labels first, tokens second
    """
    tok_list = []
    frame_list = []
    for tok_idx, tok in enumerate(tokens):
        # One Concrete token may become several NLTK tokens; each piece
        # inherits a *copy* of the original token's annotation list.
        split_tok = nltk.word_tokenize(tok.text, language=language)
        tok_list.extend(split_tok)
        tok_anno = []
        for anno in tok_to_annos.get(tok_idx, []):
            tok_anno.append(anno)
        frame_list.extend([list(tok_anno) for _ in split_tok])

    # Cleanup pass over the flattened piece-level lists.
    for idx, (tok, frame_annos) in enumerate(zip(tok_list, frame_list)):
        # Punctuation should keep a label only if the same span continues on
        # the next piece; otherwise the label was inherited spuriously from
        # the re-tokenization above.
        if tok in ",.:;\"'`«»":
            to_delete = []
            for fa in frame_annos:
                if fa.startswith("T:"):
                    # Targets compare as-is.
                    compare_fa = fa
                else:
                    # For B:/I: labels, continuation on the next piece would
                    # appear with an "I" prefix.
                    compare_fa = "I" + fa[1:]

                if idx == len(tok_list) - 1:
                    # Sentence-final punctuation never continues a span.
                    to_delete.append(fa)
                elif compare_fa not in frame_list[idx + 1]:
                    # Span does not continue past this punctuation: drop it.
                    to_delete.append(fa)

            for fa in to_delete:
                frame_annos.remove(fa)

        # If a "B:" label also appears on the previous piece, this piece is a
        # continuation, not a new span start — demote it to "I:".
        for fa_idx, fa in enumerate(frame_annos):

            if fa.startswith("B:"):

                if idx > 0 and fa in frame_list[idx - 1]:
                    frame_annos[fa_idx] = "I" + fa[1:]

    return frame_list, tok_list
|
|
|
|
|
|
def map_tokens_to_annotations(comm: Communication, situations: List[str], confidence_filter: float):
    """Collect annotation strings per token index for the given situations.

    For each situation mention: its target tokens receive a
    "T:<frame>@<idx>@@<conf>" label, and each argument's tokens receive
    "B:"/"I:"-prefixed "<frame>:<role>@<idx>@@<conf>" labels (BIO scheme,
    "B" on the first argument token). Situations or arguments whose
    confidence is below `confidence_filter` are skipped, as is the
    "@@VIRTUAL_ROOT@@" pseudo-frame.

    :return: defaultdict mapping token index -> list of annotation strings
    """
    tok_to_annos = defaultdict(list)
    for sit_idx, sit_uuid in enumerate(situations):
        situation = comm.situationMentionForUUID[sit_uuid]
        if situation.confidence < confidence_filter:
            continue

        frame_type = situation.situationKind
        if frame_type == "@@VIRTUAL_ROOT@@":
            continue

        # Target span of the frame itself.
        for tok_id in situation.tokens.tokenIndexList:
            tok_to_annos[tok_id].append(f"T:{frame_type}@{sit_idx:02}@@{situation.confidence}")

        # Frame elements (arguments), BIO-encoded over their token spans.
        for arg in situation.argumentList:
            if arg.confidence < confidence_filter:
                continue

            fe_type = arg.role
            for tok_n, tok_id in enumerate(arg.entityMention.tokens.tokenIndexList):
                bio = "B" if tok_n == 0 else "I"
                tok_to_annos[tok_id].append(f"{bio}:{frame_type}:{fe_type}@{sit_idx:02}@@{arg.confidence}")
    return tok_to_annos
|
|
|
|
|
|
def map_sent_to_situation(comm):
    """Group situation-mention UUIDs by the tokenization they annotate.

    :param comm: Concrete Communication; only `situationMentionSetList` and
        each mention's `tokens.tokenizationId` / `uuid` are read
    :return: defaultdict mapping tokenization UUID string -> list of
        situation-mention UUID strings, in encounter order
    """
    mapping = defaultdict(list)
    for mention_set in comm.situationMentionSetList:
        for mention in mention_set.mentionList:
            tokenization_key = mention.tokens.tokenizationId.uuidString
            mapping[tokenization_key].append(mention.uuid.uuidString)
    return mapping
|
|
|
|
|
|
def main():
    """CLI entry point: convert one Concrete file to multi-label JSON + text.

    Usage: <script> FILE_IN LANGUAGE OUTPUT_DIR CONFIDENCE_FILTER

    Writes `<OUTPUT_DIR>/lome_<basename>.json` and `.txt` (one 4-column line
    per token, blank line between sentences). The `split_by_migration_files`
    switch is a hard-coded legacy mode that instead writes one .txt per
    original migration split file.
    """
    file_in = sys.argv[1]
    language = sys.argv[2]
    output_directory = sys.argv[3]
    confidence_filter = float(sys.argv[4])
    split_by_migration_files = False

    def strip_orig_suffix(path):
        # BUGFIX: the original used str.rstrip(".orig.txt"), which strips a
        # *character set* (any trailing '.', 'o', 'r', 'i', 'g', 't', 'x'),
        # not the suffix — corrupting basenames that end in those characters.
        name = path.split("/")[-1]
        suffix = ".orig.txt"
        return name[: -len(suffix)] if name.endswith(suffix) else name

    file_in_base = os.path.basename(file_in)
    file_out = f"{output_directory}/lome_{file_in_base}"
    multi_label_annos = list(convert_file(file_in, language=language, confidence_filter=confidence_filter))
    multi_label_json = [dataclasses.asdict(anno) for anno in multi_label_annos]

    if split_by_migration_files:
        files = glob.glob("output/migration/split_data/split_dev10_sep_txt_files/*.orig.txt")
        # Sort numerically by the basename's numeric prefix.
        files.sort(key=lambda f: int(strip_orig_suffix(f)))

        # Assumes annotations and split files line up one-to-one, in order.
        for anno, file in zip(multi_label_annos, files):
            basename = strip_orig_suffix(file)
            spl_file_out = f"{output_directory}/{basename}"
            with open(f"{spl_file_out}.txt", "w", encoding="utf-8") as f_txt:
                for line in anno.to_txt():
                    # BUGFIX: write "\n", not os.linesep — in text mode Python
                    # translates "\n" to the platform terminator, so os.linesep
                    # would produce "\r\r\n" on Windows.
                    f_txt.write(line + "\n")
                f_txt.write("\n")

    else:
        print(file_out)
        with open(f"{file_out}.json", "w", encoding="utf-8") as f_json:
            json.dump(multi_label_json, f_json, indent=4)

        with open(f"{file_out}.txt", "w", encoding="utf-8") as f_txt:
            for anno in multi_label_annos:
                for line in anno.to_txt():
                    f_txt.write(line + "\n")
                # Blank separator line between sentences.
                f_txt.write("\n")
|
|
|
|
|
|
# Only run the converter when executed as a script, never on import.
if __name__ == "__main__":
    main()
|
|
|