import json
from functools import lru_cache


def convert_sentence_to_json(sentence):
    # Parse a WSC-style marked sentence: the query span is wrapped in
    # underscores ("_query_") and the pronoun in brackets ("[pronoun]").
    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = len(prefix.rstrip().split(" "))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = len(prefix.rstrip().split(" "))

    # strip the markers to recover the plain sentence text
    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }
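
# Illustrative example (not from the original file): for the marked sentence
#   "The city councilmen refused _the demonstrators_ a permit because [they] feared violence."
# the returned dict has span1_text="the demonstrators" (whitespace-token
# index 4) and span2_text="they" (index 9), with all markers stripped
# from "text".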


def extended_noun_chunks(sentence):
    # Union of spaCy's noun chunks with every maximal run of consecutive
    # NOUN/PROPN tokens, returned as sorted spans.
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, "NONE"
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if np_type != cur_np:
            if cur_np != "NONE":
                noun_chunks.add((np_start, i))
            if np_type != "NONE":
                np_start = i
            cur_np = np_type
    if cur_np != "NONE":
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]


def find_token(sentence, start_pos):
    found_tok = None
    for tok in sentence:
        if tok.idx == start_pos:
            found_tok = tok
            break
    return found_tok


def find_span(sentence, search_text, start=0):
    # Case-insensitive search for search_text starting at token index `start`;
    # returns the matching token span, or None if no match aligns exactly with
    # token boundaries.
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i :].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            # extend the span token by token until it covers exactly
            # len(search_text) characters
            for next_tok in sentence[tok.i :]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    return sentence[tok.i : next_tok.i + 1]
    return None
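
# Illustrative example (assumes a parsed spaCy Doc): with
#   doc = nlp("The trophy would not fit in the suitcase.")
# find_span(doc, "the suitcase") returns the span covering "the suitcase";
# the match is case-insensitive and must align with token boundaries.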


@lru_cache(maxsize=1)
def get_detokenizer():
    from sacremoses import MosesDetokenizer

    detok = MosesDetokenizer(lang="en")
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    import en_core_web_lg

    nlp = en_core_web_lg.load()
    return nlp


def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())

            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue

            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens on whitespace
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned by one token
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # keep track of whether whitespace is needed around the pronoun
            # when reassembling the detokenized sentence
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize each piece
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period or comma, move the
            # punctuation to the "after" part
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a period or comma,
            # drop the trailing space
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse the reconstructed sentence with spaCy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # locate the pronoun span in the parsed sentence
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to the format where the pronoun is surrounded by
                # "[]" and the query by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)
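
# Sketch of intended usage (the file name is hypothetical):
#   for marked_sentence, label in jsonl_iterator("val.jsonl", eval=True):
#       ...
# With eval=True each yielded sentence has the query re-wrapped in "_" and
# the pronoun in "[]", i.e. the inverse of convert_sentence_to_json.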


def winogrande_jsonl_iterator(input_fname, eval=False):
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            pronoun_span = (sentence.index("_"), sentence.index("_") + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample["answer"] == "1" else option2
                cand = option2 if sample["answer"] == "1" else option1
            yield sentence, pronoun_span, query, cand
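
# Sketch of intended usage (the file name is hypothetical):
#   for sentence, pronoun_span, query, cand in winogrande_jsonl_iterator("train.jsonl"):
#       ...
# pronoun_span is the (start, end) character offset of the "_" placeholder.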


def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    if exclude_pronouns:
        # drop chunks that are (or consist entirely of) pronouns
        chunks = [
            np
            for np in chunks
            if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
        ]

    if exclude_query is not None:
        # drop chunks matching the query (substring match unless exact_match)
        excl_txt = [exclude_query.lower()]
        filtered_chunks = []
        for chunk in chunks:
            lower_chunk = chunk.text.lower()
            found = False
            for excl in excl_txt:
                if (
                    not exact_match and (lower_chunk in excl or excl in lower_chunk)
                ) or lower_chunk == excl:
                    found = True
                    break
            if not found:
                filtered_chunks.append(chunk)
        chunks = filtered_chunks

    return chunks
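
# Illustrative combination of the helpers above (assumes the spaCy pipeline
# loaded by get_spacy_nlp; sentence and query are made up):
#   doc = get_spacy_nlp()("Paul tried to call George but he wasn't successful.")
#   cands = filter_noun_chunks(
#       extended_noun_chunks(doc), exclude_pronouns=True, exclude_query="George"
#   )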