Spaces:
Build error
Build error
# -*- coding:utf-8 -*- | |
""" | |
@Last modified date : 2020/12/23 | |
""" | |
import re | |
import nltk | |
from nltk.stem import WordNetLemmatizer | |
from allennlp.predictors.predictor import Predictor | |
# Fetch the NLTK data used below: WordNet for WordNetLemmatizer and the
# English stopword list read in SentenceParser.__init__.
for _nltk_resource in ('wordnet', 'stopwords'):
    nltk.download(_nltk_resource)
del _nltk_resource
def deal_bracket(text, restore, leading_ent=None):
    """Optionally prefix *text* with an entity header and restore brackets.

    Args:
        text: the sentence to process.
        restore: when truthy, ``-LRB-``/``-RRB-`` and then bare ``LRB``/``RRB``
            occurrences are replaced by ``(`` and ``)``.
        leading_ent: optional entity name whose underscores act as spaces;
            when non-empty, ``"Things about <entity>: "`` is prepended.

    Returns:
        The processed string.
    """
    if leading_ent:
        title = leading_ent.replace('_', ' ')
        text = f'Things about {title}: ' + text
    if not restore:
        return text
    # Dashed forms go first; replacing bare "LRB" first would turn "-LRB-"
    # into "-(-".
    replacements = (('-LRB-', '('), ('-RRB-', ')'), ('LRB', '('), ('RRB', ')'))
    for token, bracket in replacements:
        text = text.replace(token, bracket)
    return text
def refine_entity(entity):
    """Normalise an entity title into space-separated words.

    Underscores become spaces, whitespace runs collapse to one space, and a
    trailing bracketed suffix written with ``-LRB-``/``-RRB-`` (or bare
    ``LRB ... RRB``) is removed, e.g. ``'Savages_-LRB-2012_film-RRB-'`` ->
    ``'Savages'`` and ``'Africa_Cup_of_Nations'`` -> ``'Africa Cup of Nations'``.

    Args:
        entity: the raw entity string.

    Returns:
        The cleaned entity, stripped of surrounding whitespace (may be ``''``).
    """
    # Normalise separators *before* matching the suffix so underscore-joined
    # titles are handled too.
    entity = re.sub(r'\s+', ' ', entity.replace('_', ' ')).strip()
    # The dashed tokens are self-delimiting and may sit directly against
    # their contents ("-LRB-film-RRB-"), so inner spaces are optional.
    entity = re.sub(r'-LRB-.+?-RRB-$', '', entity)
    # Bare tokens could occur inside words, so keep requiring spaces.
    entity = re.sub(r'LRB .+ RRB$', '', entity)
    entity = re.sub(r'\s+', ' ', entity)
    return entity.strip()
def find_sub_seq(seq_a, seq_b, shift=0, uncased=False, lemmatizer=None):
    """Locate the first occurrence of token list *seq_b* inside *seq_a*.

    Args:
        seq_a: tokens to search in.
        seq_b: tokens to search for.
        shift: index in *seq_a* at which the search starts (negative values
            are treated as 0).
        uncased: compare tokens case-insensitively.
        lemmatizer: optional object with a ``lemmatize(token)`` method; when
            given, tokens are compared after lemmatisation (applied after
            lower-casing when *uncased* is set).

    Returns:
        ``(start, end)`` such that ``seq_a[start:end]`` equals *seq_b* under
        the chosen normalisation, or ``(-1, -1)`` if there is no match or
        *seq_b* is empty.
    """
    # An empty pattern would "match" everywhere with start == end; callers
    # that resume searching at ``end`` would never make progress.
    if not seq_b:
        return -1, -1
    if uncased:
        seq_a = [token.lower() for token in seq_a]
        seq_b = [token.lower() for token in seq_b]
    if lemmatizer is not None:
        seq_a = [lemmatizer.lemmatize(token) for token in seq_a]
        seq_b = [lemmatizer.lemmatize(token) for token in seq_b]
    width = len(seq_b)
    # Windows starting after len(seq_a) - width are too short to match.
    for i in range(max(shift, 0), len(seq_a) - width + 1):
        if seq_a[i:i + width] == seq_b:
            return i, i + width
    return -1, -1
def is_sub_seq(seq_start, seq_end, all_seqs):
    """Return the first entry of *all_seqs* that contains ``[seq_start, seq_end)``.

    Each entry is a ``(start, end, is_candidate)`` triple. The query span must
    be non-empty (``seq_start < seq_end``) and satisfy ``start <= seq_start``
    and ``seq_end <= end``.

    Returns:
        The containing entry as a ``(start, end, is_candidate)`` tuple, or
        ``None`` if no entry contains the query span.
    """
    containing = (
        (start, end, flag)
        for start, end, flag in all_seqs
        if start <= seq_start < seq_end <= end
    )
    return next(containing, None)
def extract_named_entity(tags):
    """Extract named-entity spans from a B-I-L-U-O tag sequence.

    Each tag is ``'O'`` or ``'<prefix>-<type>'`` where prefix is ``B`` (begin),
    ``I`` (inside), ``L`` (last) or ``U`` (single-token unit). A ``B`` ... ``L``
    run is emitted only if every ``I``/``L`` tag carries the same type as the
    opening ``B``; broken runs are dropped.

    Args:
        tags: sequence of tag strings, one per token.

    Returns:
        A list of ``(start, end, False)`` triples with exclusive ``end``; the
        ``False`` marks the span as not user-supplied.
    """
    entities = []
    open_type, open_start = '', -1
    for idx, tag in enumerate(tags):
        if tag == 'O':
            open_type, open_start = '', -1
            continue
        prefix, label = tag.split('-')
        if prefix == 'B':
            open_type, open_start = label, idx
        elif prefix == 'I':
            # an I tag of another type breaks the open entity
            if label != open_type:
                open_type, open_start = '', -1
        elif prefix == 'L':
            if label == open_type:
                entities.append((open_start, idx + 1, False))
            open_type, open_start = '', -1
        elif prefix == 'U':
            entities.append((idx, idx + 1, False))
            open_type, open_start = '', -1
    return entities
def refine_results(tokens, spans, stopwords):
    """Turn token-index spans into text spans with character offsets.

    Args:
        tokens: list of token strings.
        spans: iterable of ``(start, end, is_candidate)`` token-index triples
            with exclusive ``end``.
        stopwords: set of lower-case words. Leading stopwords are trimmed from
            spans whose ``is_candidate`` is false.

    Returns:
        A list of ``(text, char_start, char_end)`` tuples. Offsets index into
        ``' '.join(tokens)``. A directly preceding article ("a", "an", "the"
        and their capitalised forms) is pulled into each span. Results are
        ordered by start, longer first, and any span lying inside an
        already-kept span is dropped.
    """
    articles = ('a', 'an', 'A', 'An', 'the', 'The')
    collected = []
    for begin, end, is_candidate in spans:
        if not is_candidate:
            while begin < end and tokens[begin].lower() in stopwords:
                begin += 1
            if begin >= end:
                continue
        if begin > 0 and tokens[begin - 1] in articles:
            begin -= 1
        text = ' '.join(tokens[begin:end])
        # prefix length plus the single blank separating it from the span
        char_begin = len(' '.join(tokens[:begin])) + 1 * min(1, begin)
        collected.append((text, char_begin, char_begin + len(text)))

    collected.sort(key=lambda item: (item[1], item[1] - item[2]))

    kept = []
    for text, char_begin, char_end in collected:
        nested = any(lo <= char_begin < char_end <= hi for _, lo, hi in kept)
        if not nested:
            kept.append((text, char_begin, char_end))
    return kept
class SentenceParser:
    """Extract noun phrases, verbs and modifiers from a sentence.

    Combines an AllenNLP named-entity predictor and an AllenNLP constituency
    parser with optional caller-supplied candidate phrases.
    """

    def __init__(self, device='cuda:0',
                 ner_path="https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz",
                 cp_path="https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"):
        """Load both predictors and build the stopword list.

        Args:
            device: ``'cpu'``, ``'cuda'``/``'cuda:N'``, or an integer device
                index; converted by :meth:`parse_device`.
            ner_path: archive passed to ``Predictor.from_path`` for NER.
            cp_path: archive passed to ``Predictor.from_path`` for
                constituency parsing.
        """
        self.device = self.parse_device(device)
        self.ner = Predictor.from_path(ner_path, cuda_device=self.device)
        print('* ner loaded')
        self.cp = Predictor.from_path(cp_path, cuda_device=self.device)
        print('* constituency parser loaded')
        self.lemmatizer = WordNetLemmatizer()
        # some heuristic rules can be added here
        self.stopwords = set(nltk.corpus.stopwords.words('english'))
        self.stopwords.update({'-', '\'s', 'try', 'tries', 'tried', 'trying',
                               'become', 'becomes', 'became', 'becoming',
                               'make', 'makes', 'made', 'making', 'call', 'called', 'calling',
                               'put', 'ever', 'something', 'someone', 'sometime'})
        # These are emitted as modifiers by identify_NPs, so they must not be
        # trimmed as stopwords.
        self.special_tokens = ['only', 'most', 'before', 'after', 'behind']
        for token in self.special_tokens:
            self.stopwords.discard(token)
        # kept as content words even though NLTK lists them as stopwords
        self.stopwords.discard('won')
        self.stopwords.discard('own')

    def parse_device(self, device):
        """Convert a device spec into the integer ``cuda_device`` index.

        Any string containing ``'cpu'`` maps to ``-1``. Other strings yield
        their first integer, or ``0`` if none (``'cuda'`` -> 0,
        ``'cuda:1'`` -> 1). An ``int`` is returned unchanged.
        """
        if isinstance(device, int):
            return device
        if 'cpu' in device:
            return -1
        dev = re.findall(r'\d+', device)
        return int(dev[0]) if dev else 0

    @staticmethod
    def _collect_bottom_NPs(node, bottom_NPs):
        """Append to *bottom_NPs*, left to right, the word lists of the lowest
        NP/OP/XP/QP constituents under *node* (those with only leaf children)."""
        if 'children' not in node:
            return
        if {'NP', 'OP', 'XP', 'QP'} & set(node['attributes']):
            if all('children' not in child for child in node['children']):
                bottom_NPs.append(node['word'].split())
                return
        for child in node['children']:
            SentenceParser._collect_bottom_NPs(child, bottom_NPs)

    def identify_NPs(self, text, candidate_NPs=None):
        """Find noun phrases, verbs and modifiers in *text*.

        Args:
            text: input sentence; whitespace runs are collapsed before parsing.
            candidate_NPs: optional iterable of phrases (e.g.
                ``'Africa_Cup_of_Nations'``). After ``refine_entity`` they are
                matched case-insensitively on lemmatised tokens and always
                kept whole.

        Returns:
            A dict with ``'text'`` (the parse tree's root word string) and
            ``'NPs'``, ``'verbs'``, ``'adjs'``. The last three are lists of
            ``(span_text, char_start, char_end)`` tuples from
            ``refine_results``.
        """
        text = re.sub(r'\s+', ' ', text).strip()
        if not text:
            return {'text': '', 'NPs': [], 'verbs': [], 'adjs': []}
        cp_outputs = self.cp.predict(text)
        ner_outputs = self.ner.predict(text)
        tokens = cp_outputs['tokens']
        pos_tags = cp_outputs['pos_tags']
        ner_tags = ner_outputs['tags']
        tree = cp_outputs['hierplane_tree']['root']

        # Candidate phrases from the caller, longest first, so an occurrence
        # inside an already-matched longer candidate is not added again.
        all_NPs = []
        candidates = [refine_entity(np).split() for np in candidate_NPs] if candidate_NPs else []
        # A candidate that refines to nothing would "match" as an empty span
        # at the same index forever.
        candidates = [np for np in candidates if np]
        for np in sorted(candidates, key=len, reverse=True):
            np_start, np_end = find_sub_seq(tokens, np, 0, uncased=True, lemmatizer=self.lemmatizer)
            while np_start != -1 and np_end != -1:
                if not is_sub_seq(np_start, np_end, all_NPs):
                    all_NPs.append((np_start, np_end, True))
                np_start, np_end = find_sub_seq(tokens, np, np_end, uncased=True, lemmatizer=self.lemmatizer)

        # Lowest-level noun phrases from the constituency tree.
        bottom_NPs = []
        self._collect_bottom_NPs(tree, bottom_NPs)
        # They are disjoint and in sentence order. Resume exactly at the
        # previous end (exclusive) so directly adjacent phrases are found.
        cursor = 0
        for np in bottom_NPs:
            np_start, np_end = find_sub_seq(tokens, np, cursor)
            if np_start == -1:
                # not found: record nothing and keep the cursor
                continue
            if not is_sub_seq(np_start, np_end, all_NPs):
                all_NPs.append((np_start, np_end, False))
            cursor = np_end

        all_NEs = extract_named_entity(ner_tags)

        def _uncovered(i):
            # True when token i lies in neither a noun phrase nor an entity.
            return not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs)

        # Verbs: POS tags starting with 'V'.
        all_verbs = [(i, i + 1, False) for i, pos in enumerate(pos_tags)
                     if pos[0] == 'V' and _uncovered(i)]
        # Modifiers: adjectives/adverbs (JJ, RB) and the special tokens.
        all_modifiers = [(i, i + 1, False)
                         for i, (token, pos) in enumerate(zip(tokens, pos_tags))
                         if (pos in ['JJ', 'RB'] or token in self.special_tokens) and _uncovered(i)]

        # Reconcile noun phrases with named entities.
        all_spans = []
        for np_start, np_end, np_is_candidate in all_NPs:
            if np_is_candidate:
                # candidate noun phrases are preserved as-is
                all_spans.append((np_start, np_end, np_is_candidate))
                continue
            match = is_sub_seq(np_start, np_end, all_NEs)
            if match:
                # noun phrase inside a named entity: keep the entity
                all_spans.append(match)
                continue
            # Entities inside the noun phrase split it. The text before each
            # entity becomes a modifier, each entity a span, and the remainder
            # after the last entity a span.
            index = np_start
            for ne_start, ne_end, ne_is_candidate in all_NEs:
                if np_start <= ne_start < ne_end <= np_end:
                    all_modifiers.append((index, ne_start, False))
                    all_spans.append((ne_start, ne_end, ne_is_candidate))
                    index = ne_end
            all_spans.append((index, np_end, False))

        # named entities not covered by any span so far
        for ne_start, ne_end, is_candidate in all_NEs:
            if not is_sub_seq(ne_start, ne_end, all_spans):
                all_spans.append((ne_start, ne_end, is_candidate))

        all_spans = refine_results(tokens, all_spans, self.stopwords)
        all_verbs = refine_results(tokens, all_verbs, self.stopwords)
        all_modifiers = refine_results(tokens, all_modifiers, self.stopwords)
        return {'text': tree['word'], 'NPs': all_spans, 'verbs': all_verbs, 'adjs': all_modifiers}
if __name__ == '__main__':
    # Demo: load the models on CPU, parse one sentence with a few candidate
    # entity spellings, and pretty-print the extracted spans.
    import json

    print('Initializing sentence parser.')
    parser = SentenceParser(device='cpu')

    print('Parsing sentence.')
    demo_sentence = "The Africa Cup of Nations is held in odd - numbered years due to conflict with the World Cup . "
    demo_entities = ['Africa Cup of Nations', 'Africa_Cup_of_Nations', 'Africa Cup', 'Africa_Cup']
    parsed = parser.identify_NPs(demo_sentence, demo_entities)
    print(json.dumps(parsed, ensure_ascii=False, indent=4))

    # Batch variant. It needs a local ``utils`` module providing
    # read_json_lines / save_json, plus random and tqdm:
    #
    # print('Parsing file.')
    # results = []
    # data = list(read_json_lines('data/train.jsonl'))
    # random.shuffle(data)
    # for entry in tqdm(data[:100]):
    #     results.append(parser.identify_NPs(entry['claim']))
    # save_json(results, 'data/results.json')