import argparse
import json
import os
import random
import re
from typing import Dict, Iterable, List, Optional, Tuple

import pandas as pd
import spacy
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

from allennlp.common.util import pad_sequence_to_length
from allennlp.data import DatasetReader, Instance, TextFieldTensors, Token, TokenIndexer
from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader
from allennlp.data.fields import LabelField, TextField
from allennlp.data.fields.field import Field
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.data.token_indexers.pretrained_transformer_indexer import (
    PretrainedTransformerIndexer,
)
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import (
    PretrainedTransformerTokenizer,
)
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import BasicClassifier
from allennlp.models.model import Model
from allennlp.modules.seq2vec_encoders.bert_pooler import BertPooler
from allennlp.modules.seq2vec_encoders.cls_pooler import ClsPooler
from allennlp.modules.text_field_embedders.basic_text_field_embedder import (
    BasicTextFieldEmbedder,
)
from allennlp.modules.token_embedders.pretrained_transformer_embedder import (
    PretrainedTransformerEmbedder,
)
from allennlp.nn.util import get_text_field_mask
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer
from allennlp.training.optimizers import AdamOptimizer

random.seed(1986)
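
# Sequence-labelling pipeline for the "traffic Bechdel" headline experiments:
# trains/evaluates a Dutch BERT tagger over the labels in SEQ_LABELS, plus a
# rule-based baseline and a corpus-filtering step. Run with one of the actions
# defined in the __main__ block below: train, predict, rules, or filter
# (e.g. `python predict_bechdel.py train`; the script name is illustrative).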


SEQ_LABELS = ["humansMentioned", "vehiclesMentioned", "eventVerb", "activeEventVerb"]
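# Each token gets a (possibly empty) subset of SEQ_LABELS. The MultiLabelBinarizer
# (fit on [SEQ_LABELS] in train_and_eval) turns such a subset into a 4-dim
# multi-hot vector; note that it orders its classes alphabetically, so e.g.
# {"humansMentioned", "activeEventVerb"} becomes [1, 0, 1, 0]
# (activeEventVerb, eventVerb, humansMentioned, vehiclesMentioned).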


class SequenceMultiLabelField(Field):
    """A field holding a (possibly empty) set of labels for every token of a sequence."""

    def __init__(self,
                 labels: List[List[str]],
                 sequence_field: SequenceField,
                 binarizer: MultiLabelBinarizer,
                 label_namespace: str
                 ):
        self.labels = labels
        self._indexed_labels = None
        self._label_namespace = label_namespace
        self.sequence_field = sequence_field
        self.binarizer = binarizer

    @staticmethod
    def retokenize_tags(tags: List[List[str]],
                        offsets: List[Tuple[int, int]],
                        wp_primary_token: str = "last",
                        wp_secondary_tokens: str = "empty",
                        empty_value=lambda: []
                        ) -> List[List[str]]:
        """Expand word-level tag sets to wordpiece level, using per-word offsets.

        The primary wordpiece of each word ("first" or "last") keeps the word's
        tags; the remaining wordpieces either repeat them ("repeat") or stay
        empty ("empty"). Leading and trailing empty slots account for the
        special tokens ([CLS] and [SEP]).
        """
        tags_per_wordpiece = [
            empty_value()
        ]

        for i, (off_start, off_end) in enumerate(offsets):
            tag = tags[i]

            num_extra_tokens = off_end - off_start
            if wp_primary_token == "first":
                tags_per_wordpiece.append(tag)
            # Fill the non-primary wordpieces of this word; this must happen for
            # both "first" and "last" so the output stays aligned with the
            # wordpiece sequence.
            if wp_secondary_tokens == "repeat":
                tags_per_wordpiece.extend(num_extra_tokens * [tag])
            else:
                tags_per_wordpiece.extend(num_extra_tokens * [empty_value()])
            if wp_primary_token == "last":
                tags_per_wordpiece.append(tag)

        tags_per_wordpiece.append(empty_value())

        return tags_per_wordpiece
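
    # Example (with a made-up sentence): for ["Fietser", "aangereden"] tagged
    # [["humansMentioned"], ["eventVerb", "activeEventVerb"]] and wordpiece
    # offsets [(1, 2), (3, 3)] (i.e. "Fietser" is split into two pieces), the
    # defaults ("last"/"empty") give
    # [[], [], ["humansMentioned"], ["eventVerb", "activeEventVerb"], []]:
    # empty slots for [CLS], the leading wordpiece of "Fietser", and [SEP].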

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        for label_list in self.labels:
            for label in label_list:
                counter[self._label_namespace][label] += 1

    def get_padding_lengths(self) -> Dict[str, int]:
        return {"num_tokens": self.sequence_field.sequence_length()}

    def index(self, vocab: Vocabulary):
        indexed_labels: List[List[int]] = []
        for sentence_labels in self.labels:
            sentence_indexed_labels = []
            for label in sentence_labels:
                try:
                    sentence_indexed_labels.append(
                        vocab.get_token_index(label, self._label_namespace))
                except KeyError:
                    print(f"[WARNING] Ignoring unknown label {label}")
            indexed_labels.append(sentence_indexed_labels)
        self._indexed_labels = indexed_labels

    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        # The binarizer was fit on the label strings (SEQ_LABELS), so binarize
        # the raw string labels rather than the vocabulary indices.
        binarized_seq = self.binarizer.transform(self.labels).tolist()

        # Pad with all-zero rows (the encoding of the empty label set) up to
        # the wordpiece sequence length.
        desired_num_tokens = padding_lengths["num_tokens"]
        padded_tags = pad_sequence_to_length(binarized_seq, desired_num_tokens,
                                             default_value=lambda: list(self.binarizer.transform([[]])[0]))

        tensor = torch.tensor(padded_tags, dtype=torch.float)
        return tensor

    def empty_field(self) -> 'Field':
        field = SequenceMultiLabelField(
            [], self.sequence_field.empty_field(), self.binarizer, self._label_namespace)
        field._indexed_labels = []
        return field


class MultiSequenceLabelModel(Model):

    def __init__(self, embedder: BasicTextFieldEmbedder, decoder_output_size: int, hidden_size: int, vocab: Vocabulary, embedding_size: int = 768):
        super().__init__(vocab)
        self.embedder = embedder
        self.out_features = decoder_output_size
        self.hidden_size = hidden_size
        self.layers = nn.Sequential(
            nn.Linear(in_features=embedding_size,
                      out_features=self.hidden_size),
            nn.ReLU(),
            nn.Linear(in_features=self.hidden_size,
                      out_features=self.out_features)
        )

    def forward(self, tokens: TextFieldTensors, label: Optional[torch.FloatTensor] = None):
        # BasicTextFieldEmbedder expects the full TextFieldTensors dict.
        embeddings = self.embedder(tokens)
        mask = get_text_field_mask(tokens).float()
        tag_logits = self.layers(embeddings)
        # Broadcast the token mask over the label dimension and use it as a
        # per-element weight so padding positions do not contribute to the loss.
        mask = mask.reshape(mask.shape[0], mask.shape[1], 1).repeat(1, 1, self.out_features)
        output = {"tag_logits": tag_logits}
        if label is not None:
            loss = binary_cross_entropy_with_logits(tag_logits, label, mask)
            output["loss"] = loss
        return output
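
    # forward returns {"tag_logits": (batch, num_wordpieces, decoder_output_size)}
    # plus a masked binary-cross-entropy "loss" when gold labels are provided;
    # decoder_output_size is len(SEQ_LABELS) as constructed in train_and_eval.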

    def get_metrics(self, _) -> Dict[str, float]:
        return {}

    def make_human_readable(self,
                            prediction,
                            label_namespace,
                            threshold=0.2,
                            sigmoid=True
                            ) -> List[str]:
        if sigmoid:
            prediction = torch.sigmoid(prediction)

        predicted_labels: List[List[str]] = [[] for _ in range(len(prediction))]

        # The model's output columns follow the MultiLabelBinarizer's class order
        # (alphabetical over SEQ_LABELS), not the order of the `label_namespace`
        # vocabulary, so decode column indices through that order.
        classes = sorted(SEQ_LABELS)
        for coord in torch.nonzero(prediction > threshold):
            label = classes[int(coord[1])]
            predicted_labels[coord[0]].append(f"{label}:{prediction[coord[0], coord[1]]:.3f}")

        # One "label:prob|label:prob" string per token; "_" if nothing was predicted.
        str_predictions: List[str] = []
        for label_list in predicted_labels:
            str_predictions.append("|".join(label_list) or "_")

        return str_predictions


class TrafficBechdelReader(DatasetReader):

    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer],
                 tokenizer: PretrainedTransformerTokenizer,
                 binarizer: MultiLabelBinarizer):
        self.token_indexers = token_indexers
        self.tokenizer = tokenizer
        self.binarizer = binarizer
        self.orig_data = []
        super().__init__()

    def _read(self, file_path) -> Iterable[Instance]:
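        """Read one bracketed headline per line.

        Each line is expected to look like (illustrative example, not taken
        from the data):
        [Fietser:[humansMentioned],aangereden:[eventVerb|activeEventVerb],door:[],auto-n-1:[vehiclesMentioned]]
        i.e. comma-separated `text:[label|label|...]` parts, with `[]` for
        unlabelled tokens and optional WordNet-style `-n-<sense>` suffixes
        that are stripped before tokenization.
        """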
        self.orig_data.clear()

        with open(file_path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue

                sentence_parts = line.strip().lstrip("[").rstrip("]").split(",")
                token_txts = []
                token_mlabels = []

                for sp in sentence_parts:
                    sp_txt, sp_lbl_str = sp.split(":")
                    if sp_lbl_str == "[]":
                        sp_lbls = []
                    else:
                        sp_lbls = sp_lbl_str.lstrip("[").rstrip("]").split("|")

                    # Strip WordNet-style sense suffixes (e.g. "auto-n-1" -> "auto").
                    wn_match = re.match(r"^(.+)-n-\d+$", sp_txt)
                    if wn_match:
                        sp_txt = wn_match.group(1)

                    sp_toks = sp_txt.split()
                    for tok in sp_toks:
                        token_txts.append(tok)
                        token_mlabels.append(sp_lbls)

                self.orig_data.append({
                    "sentence": token_txts,
                    "labels": token_mlabels,
                })
                yield self.text_to_instance(token_txts, token_mlabels)

    def text_to_instance(self, sentence: List[str], labels: Optional[List[List[str]]] = None) -> Instance:
        tokens, offsets = self.tokenizer.intra_word_tokenize(sentence)

        text_field = TextField(tokens, self.token_indexers)
        fields = {"tokens": text_field}
        if labels is not None:
            labels_ = SequenceMultiLabelField.retokenize_tags(labels, offsets)
            label_field = SequenceMultiLabelField(labels_, text_field, self.binarizer, "labels")
            fields["label"] = label_field
        return Instance(fields)


def count_parties(sentence, lexical_dicts, nlp):
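    """Count person and vehicle mentions in a sentence via lemma lookup.

    `lexical_dicts` is assumed to map a category ("persons", "vehicles") to
    sub-categories to word lists; sub-categories prefixed "WN:" contain
    WordNet-style entries such as "fiets-n-1" (illustrative), whose sense
    suffix is stripped before comparing against spaCy lemmas.
    """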
    num_humans = 0
    num_vehicles = 0

    def is_in_words(lemma, category):
        for subcategory, words in lexical_dicts[category].items():
            if subcategory.startswith("WN:"):
                words = [re.match(r"^(.+)-n-\d+$", w).group(1) for w in words]
            if lemma in words:
                return True
        return False

    doc = nlp(sentence.lower())
    for token in doc:
        lemma = token.lemma_
        if is_in_words(lemma, "persons"):
            num_humans += 1
        if is_in_words(lemma, "vehicles"):
            num_vehicles += 1

    return num_humans, num_vehicles


def predict_rule_based(annotations="data/crashes/bechdel_annotations_dev_first_25.csv"):
    data_crashes = pd.read_csv(annotations)
    with open("output/crashes/predict_bechdel/lexical_dicts.json", encoding="utf-8") as f:
        lexical_dicts = json.load(f)

    nlp = spacy.load("nl_core_news_md")

    for _, row in data_crashes.iterrows():
        sentence = row["sentence"]
        num_humans, num_vehicles = count_parties(sentence, lexical_dicts, nlp)
        print(sentence)
        print(f"\thumans={num_humans}, vehicles={num_vehicles}")


def evaluate_crashes(predictor, attrib, annotations="data/crashes/bechdel_annotations_dev_first_25.csv", out_file="output/crashes/predict_bechdel/predictions_crashes25.csv"):
    data_crashes = pd.read_csv(annotations)
    labels_crashes = [
        {
            "party_mentioned": str(row["mentioned"]),
            "party_human": str(row["as_human"]),
            "active": str(str(row["active"]).lower() == "true")
        }
        for _, row in data_crashes.iterrows()
    ]
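    # In the "all" setting the gold annotation is rendered below as
    # "party_mentioned=...|party_human=...|active=..." so that it can be
    # compared against the label string the given `predictor` is assumed to
    # produce (a classifier trained elsewhere; not defined in this file).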
    predictions_crashes = [predictor.predict(
        row["sentence"]) for _, row in data_crashes.iterrows()]
    crashes_out = []
    correct = 0
    partial_2_attrs = 0
    partial_1_attr = 0
    correct_mentions = 0
    correct_humans = 0
    correct_active = 0

    for sentence, label, prediction in zip(data_crashes["sentence"], labels_crashes, predictions_crashes):
        predicted = prediction["label"]
        if attrib == "all":
            gold = "|".join([f"{k}={v}" for k, v in label.items()])
        else:
            gold = label[attrib]
        if gold == predicted:
            correct += 1
            if attrib == "all":
                partial_2_attrs += 1
                partial_1_attr += 1

        if attrib == "all":
            gold_attrs = set(gold.split("|"))
            pred_attrs = set(predicted.split("|"))
            if len(gold_attrs & pred_attrs) == 2:
                partial_2_attrs += 1
                partial_1_attr += 1
            elif len(gold_attrs & pred_attrs) == 1:
                partial_1_attr += 1

            if gold.split("|")[0] == predicted.split("|")[0]:
                correct_mentions += 1
            if gold.split("|")[1] == predicted.split("|")[1]:
                correct_humans += 1
            if gold.split("|")[2] == predicted.split("|")[2]:
                correct_active += 1

        crashes_out.append(
            {"sentence": sentence, "gold": gold, "prediction": predicted})

    print("ACC_crashes (strict) = ", correct/len(data_crashes))
    print("ACC_crashes (partial:2) = ", partial_2_attrs/len(data_crashes))
    print("ACC_crashes (partial:1) = ", partial_1_attr/len(data_crashes))
    print("ACC_crashes (mentions) = ", correct_mentions/len(data_crashes))
    print("ACC_crashes (humans) = ", correct_humans/len(data_crashes))
    print("ACC_crashes (active) = ", correct_active/len(data_crashes))

    pd.DataFrame(crashes_out).to_csv(out_file)


def filter_events_for_bechdel():
    with open("data/crashes/thecrashes_data_all_text.json", encoding="utf-8") as f:
        events = json.load(f)

    total_articles = 0
    data_out = []
    for ev in events:
        total_articles += len(ev["articles"])

        num_persons = len(ev["persons"])
        num_transport_modes = len({p["transportationmode"]
                                    for p in ev["persons"]})

        if num_transport_modes <= 2:
            for art in ev["articles"]:
                data_out.append({"event_id": ev["id"], "article_id": art["id"], "headline": art["title"],
                                 "num_persons": num_persons, "num_transport_modes": num_transport_modes})

    print("Total articles = ", total_articles)
    print("Filtered articles: ", len(data_out))
    out_df = pd.DataFrame(data_out)
    out_df.to_csv("output/crashes/predict_bechdel/filtered_headlines.csv")


def train_and_eval(train=True):
    use_gpu = True
    cuda_device = 0 if use_gpu and torch.cuda.is_available() else -1

    transformer = "GroNLP/bert-base-dutch-cased"

    token_indexers = {"tokens": PretrainedTransformerIndexer(transformer)}
    tokenizer = PretrainedTransformerTokenizer(transformer)

    binarizer = MultiLabelBinarizer()
    binarizer.fit([SEQ_LABELS])
    reader = TrafficBechdelReader(token_indexers, tokenizer, binarizer)
    instances = list(reader.read("output/prolog/bechdel_headlines.txt"))
    orig_data = reader.orig_data
    zipped = list(zip(instances, orig_data))
    random.shuffle(zipped)
    instances_ = [i[0] for i in zipped]
    orig_data_ = [i[1] for i in zipped]

    num_dev = round(0.05 * len(instances_))
    num_test = round(0.25 * len(instances_))
    num_train = len(instances_) - num_dev - num_test
    print("LEN(train/dev/test)=", num_train, num_dev, num_test)
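    # 70/5/25 train/dev/test split over the shuffled instances; the test
    # portion is only set aside here and never evaluated in this function.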
    instances_train = instances_[:num_train]
    instances_dev = instances_[num_train:num_train + num_dev]

    orig_dev = orig_data_[num_train:num_train + num_dev]

    vocab = Vocabulary.from_instances(instances_train + instances_dev)

    embedder = BasicTextFieldEmbedder(
        {"tokens": PretrainedTransformerEmbedder(transformer)})
    model = MultiSequenceLabelModel(embedder, len(SEQ_LABELS), 1000, vocab)
    if cuda_device >= 0:
        model = model.cuda(cuda_device)

    checkpoint_dir = "/scratch/p289731/predict_bechdel/model_seqlabel/"
    serialization_dir = "/scratch/p289731/predict_bechdel/serialization_seqlabel/"

    if train:
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(serialization_dir, exist_ok=True)
        tensorboard = TensorBoardCallback(
            serialization_dir, should_log_learning_rate=True)
        checkpointer = Checkpointer(serialization_dir=checkpoint_dir)
        optimizer = AdamOptimizer(
            [(n, p) for n, p in model.named_parameters() if p.requires_grad],
            lr=1e-5
        )
        train_loader = SimpleDataLoader(
            instances_train, batch_size=8, shuffle=True)
        dev_loader = SimpleDataLoader(
            instances_dev, batch_size=8, shuffle=False)
        train_loader.index_with(vocab)
        dev_loader.index_with(vocab)

        print("\t\tTraining BERT model")
        trainer = GradientDescentTrainer(
            model,
            optimizer,
            train_loader,
            validation_data_loader=dev_loader,
            patience=2,
            checkpointer=checkpointer,
            cuda_device=cuda_device,
            serialization_dir=serialization_dir,
            callbacks=[tensorboard]
        )
        trainer.train()
    else:
        state_dict = torch.load(
            "/scratch/p289731/predict_bechdel/serialization_all/best.th",
            map_location="cuda:0" if cuda_device >= 0 else "cpu")
        model.load_state_dict(state_dict)

    print("\t\tProducing predictions...")

    predictor = Predictor(model, reader)
    predictions_dev = [predictor.predict_instance(i) for i in instances_dev]

    data_out = []
    for sentence, prediction in zip(orig_dev, predictions_dev):
        # predict_instance returns JSON-serializable lists, so rebuild a tensor
        # before decoding the per-token logits.
        readable = model.make_human_readable(
            torch.tensor(prediction["tag_logits"]), "labels")
        text = sentence["sentence"]
        gold = sentence["labels"]
        predicted = readable
        data_out.append(
            {"sentence": text, "gold": gold, "predicted": predicted})
    df_out = pd.DataFrame(data_out)
    df_out.to_csv("output/crashes/predict_bechdel/predictions_dev.csv")


if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("action", choices=["train", "predict", "rules", "filter"])

    args = ap.parse_args()

    if args.action == "train":
        train_and_eval(train=True)
    elif args.action == "predict":
        train_and_eval(train=False)
    elif args.action == "rules":
        predict_rule_based()
    else:
        filter_events_for_bechdel()