| """ | |
| Parsing formulated as span classification (https://arxiv.org/abs/1705.03919) | |
| """ | |
| import nltk | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch_struct | |
| from .parse_base import CompressedParserOutput | |


def pad_charts(charts, padding_value=-100):
    """Pad a list of variable-length charts with `padding_value`."""
    batch_size = len(charts)
    max_len = max(chart.shape[0] for chart in charts)
    padded_charts = torch.full(
        (batch_size, max_len, max_len),
        padding_value,
        dtype=charts[0].dtype,
        device=charts[0].device,
    )
    for i, chart in enumerate(charts):
        chart_size = chart.shape[0]
        padded_charts[i, :chart_size, :chart_size] = chart
    return padded_charts
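
# Illustrative sketch (hypothetical example, not from the original source):
# padding two chart tensors of sizes 2x2 and 3x3 with pad_charts would yield a
# tensor of shape (2, 3, 3), where the extra rows and columns of the smaller
# chart are filled with -100, the same ignore value used for gold charts below.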


def collapse_unary_strip_pos(tree, strip_top=True):
    """Collapse unary chains and strip part-of-speech tags."""

    def strip_pos(tree):
        if len(tree) == 1 and isinstance(tree[0], str):
            return tree[0]
        else:
            return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree])

    collapsed_tree = strip_pos(tree)
    collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::")
    if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"):
        if strip_top:
            if len(collapsed_tree) == 1:
                collapsed_tree = collapsed_tree[0]
            else:
                collapsed_tree.set_label("")
        elif len(collapsed_tree) == 1:
            collapsed_tree[0].set_label(
                collapsed_tree.label() + "::" + collapsed_tree[0].label())
            collapsed_tree = collapsed_tree[0]
    return collapsed_tree
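
# Illustrative sketch (hypothetical example): applied to the tree
#   (TOP (S (VP (VB Go))))
# collapse_unary_strip_pos is expected to strip the VB tag, collapse the
# S -> VP unary chain using the "::" join character, and drop the TOP root,
# producing the single-node tree (S::VP Go).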


def _get_labeled_spans(tree, spans_out, start):
    if isinstance(tree, str):
        return start + 1

    assert len(tree) > 1 or isinstance(
        tree[0], str
    ), "Must call collapse_unary_strip_pos first"
    end = start
    for child in tree:
        end = _get_labeled_spans(child, spans_out, end)
    # Spans are returned as closed intervals on both ends
    spans_out.append((start, end - 1, tree.label()))
    return end


def get_labeled_spans(tree):
    """Converts a tree into a list of labeled spans.

    Args:
        tree: an nltk.tree.Tree object

    Returns:
        A list of (span_start, span_end, span_label) tuples. The start and end
        indices indicate the first and last words of the span (a closed
        interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will
        result in a single span labeled "S::VP".
    """
    tree = collapse_unary_strip_pos(tree)
    spans_out = []
    _get_labeled_spans(tree, spans_out, start=0)
    return spans_out
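
# Illustrative sketch (hypothetical example): for the tree
#   (TOP (S (NP (DT the) (NN cat)) (VP (VBD slept))))
# get_labeled_spans is expected to return the post-order list
#   [(0, 1, 'NP'), (2, 2, 'VP'), (0, 2, 'S')]
# where each tuple gives the first word, last word (closed interval), and label.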


def uncollapse_unary(tree, ensure_top=False):
    """Un-collapse unary chains."""
    if isinstance(tree, str):
        return tree
    else:
        labels = tree.label().split("::")
        if ensure_top and labels[0] != "TOP":
            labels = ["TOP"] + labels
        children = []
        for child in tree:
            child = uncollapse_unary(child)
            children.append(child)
        for label in labels[::-1]:
            children = [nltk.tree.Tree(label, children)]
        return children[0]
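
# Illustrative sketch (hypothetical example): uncollapse_unary applied to
# nltk.tree.Tree("S::VP", ["Go"]) is expected to rebuild the nested structure
# (S (VP Go)); with ensure_top=True, a TOP node is added above when the
# outermost label is not already "TOP".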


class ChartDecoder:
    """A chart decoder for parsing formulated as span classification."""

    def __init__(self, label_vocab, force_root_constituent=True):
        """Constructs a new ChartDecoder object.

        Args:
            label_vocab: A mapping from span labels to integer indices.
            force_root_constituent: If True, the decoder always assigns a
                non-empty label to the span covering the whole sentence.
        """
        self.label_vocab = label_vocab
        self.label_from_index = {i: label for label, i in label_vocab.items()}
        self.force_root_constituent = force_root_constituent

    @staticmethod
    def build_vocab(trees):
        label_set = set()
        for tree in trees:
            for _, _, label in get_labeled_spans(tree):
                if label:
                    label_set.add(label)
        label_set = [""] + sorted(label_set)
        return {label: i for i, label in enumerate(label_set)}

    @staticmethod
    def infer_force_root_constituent(trees):
        for tree in trees:
            for _, _, label in get_labeled_spans(tree):
                if not label:
                    return False
        return True
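
    # Illustrative usage sketch (hypothetical variable names): given a list of
    # nltk training trees `train_trees`, a decoder could be constructed as
    #   label_vocab = ChartDecoder.build_vocab(train_trees)
    #   decoder = ChartDecoder(
    #       label_vocab,
    #       force_root_constituent=ChartDecoder.infer_force_root_constituent(
    #           train_trees),
    #   )
    # Index 0 of the vocabulary is reserved for the empty label "".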

    def chart_from_tree(self, tree):
        spans = get_labeled_spans(tree)
        num_words = len(tree.leaves())
        chart = np.full((num_words, num_words), -100, dtype=int)
        chart = np.tril(chart, -1)
        # Now all invalid entries are filled with -100, and valid entries with 0
        for start, end, label in spans:
            # Previously unseen unary chains can occur in the dev/test sets.
            # For now, we ignore them and don't mark the corresponding chart
            # entry as a constituent.
            if label in self.label_vocab:
                chart[start, end] = self.label_vocab[label]
        return chart
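
    # Illustrative sketch (hypothetical example): for a 3-word sentence whose
    # labeled spans are an NP over words 0-1, a VP over word 2, and an S over
    # words 0-2, chart_from_tree returns a 3x3 array with chart[0, 1],
    # chart[2, 2], and chart[0, 2] set to the corresponding label indices,
    # all other on- or above-diagonal entries set to 0 (no constituent), and
    # all below-diagonal entries set to -100 (invalid/ignored).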

    def charts_from_pytorch_scores_batched(self, scores, lengths):
        """Runs CKY to recover span labels from scores (e.g. logits).

        This method uses pytorch-struct to speed up decoding compared to the
        pure-Python implementation of CKY used by tree_from_scores().

        Args:
            scores: a pytorch tensor of shape (batch size, max length,
                max length, label vocab size).
            lengths: a pytorch tensor of shape (batch size,)

        Returns:
            A list of numpy arrays, each of shape (sentence length, sentence
            length).
        """
        scores = scores.detach()
        scores = scores - scores[..., :1]
        if self.force_root_constituent:
            scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9
        dist = torch_struct.TreeCRF(scores, lengths=lengths)
        amax = dist.argmax
        amax[..., 0] += 1e-9
        padded_charts = amax.argmax(-1)
        padded_charts = padded_charts.detach().cpu().numpy()
        return [
            chart[:length, :length] for chart, length in zip(padded_charts, lengths)
        ]
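
    # Illustrative usage sketch (hypothetical names and shapes): given span
    # logits of shape (batch, max_len, max_len, num_labels) and a lengths
    # tensor, batched decoding could look like
    #   charts = decoder.charts_from_pytorch_scores_batched(logits, lengths)
    #   trees = [decoder.tree_from_chart(chart, leaves)
    #            for chart, leaves in zip(charts, batch_leaves)]
    # Each chart stores the predicted label index for every span included in
    # the highest-scoring tree (0 marks the empty label or no constituent).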

    def compressed_output_from_chart(self, chart):
        """Converts a chart of label indices into a CompressedParserOutput.

        Spans are listed in pre-order (by start position, outermost spans
        first), and end indices are converted from inclusive to exclusive.
        """
        chart_with_filled_diagonal = chart.copy()
        np.fill_diagonal(chart_with_filled_diagonal, 1)
        chart_with_filled_diagonal[0, -1] = 1
        starts, inclusive_ends = np.where(chart_with_filled_diagonal)
        preorder_sort = np.lexsort((-inclusive_ends, starts))
        starts = starts[preorder_sort]
        inclusive_ends = inclusive_ends[preorder_sort]
        labels = chart[starts, inclusive_ends]
        ends = inclusive_ends + 1
        return CompressedParserOutput(starts=starts, ends=ends, labels=labels)

    def tree_from_chart(self, chart, leaves):
        compressed_output = self.compressed_output_from_chart(chart)
        return compressed_output.to_tree(leaves, self.label_from_index)

    def tree_from_scores(self, scores, leaves):
        """Runs CKY to decode a tree from scores (e.g. logits).

        If speed is important, consider using charts_from_pytorch_scores_batched
        followed by compressed_output_from_chart or tree_from_chart instead.

        Args:
            scores: a chart of scores (or logits) of shape
                (sentence length, sentence length, label vocab size). The first
                two dimensions may be padded to a longer length, but all padded
                values will be ignored.
            leaves: the leaf nodes to use in the constructed tree. These
                may be of type str or nltk.Tree, or (word, tag) tuples that
                will be used to construct the leaf node objects.

        Returns:
            An nltk.Tree object.
        """
        leaves = [
            nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node
            for node in leaves
        ]

        chart = {}
        scores = scores - scores[:, :, 0, None]
        for length in range(1, len(leaves) + 1):
            for left in range(0, len(leaves) + 1 - length):
                right = left + length

                label_scores = scores[left, right - 1]
                label_scores = label_scores - label_scores[0]

                argmax_label_index = int(
                    label_scores.argmax()
                    if length < len(leaves) or not self.force_root_constituent
                    else label_scores[1:].argmax() + 1
                )
                argmax_label = self.label_from_index[argmax_label_index]
                label = argmax_label
                label_score = label_scores[argmax_label_index]

                if length == 1:
                    tree = leaves[left]
                    if label:
                        tree = nltk.tree.Tree(label, [tree])
                    chart[left, right] = [tree], label_score
                    continue

                best_split = max(
                    range(left + 1, right),
                    key=lambda split: (chart[left, split][1] + chart[split, right][1]),
                )

                left_trees, left_score = chart[left, best_split]
                right_trees, right_score = chart[best_split, right]
                children = left_trees + right_trees
                if label:
                    children = [nltk.tree.Tree(label, children)]
                chart[left, right] = (children, label_score + left_score + right_score)

        children, score = chart[0, len(leaves)]
        tree = nltk.tree.Tree("TOP", children)
        tree = uncollapse_unary(tree)
        return tree
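
    # Illustrative usage sketch (hypothetical names): for a single tagged
    # sentence, decoding without pytorch-struct could look like
    #   tree = decoder.tree_from_scores(
    #       scores,  # shape: (num_words, num_words, num_labels)
    #       leaves=[("She", "PRP"), ("slept", "VBD")],
    #   )
    # The result is an nltk.tree.Tree rooted at "TOP", with collapsed labels
    # such as "S::VP" expanded back into nested nodes by uncollapse_unary.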


class SpanClassificationMarginLoss(nn.Module):
    def __init__(self, force_root_constituent=True, reduction="mean"):
        super().__init__()
        self.force_root_constituent = force_root_constituent
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"Invalid value for reduction: {reduction}")
        self.reduction = reduction

    def forward(self, logits, labels):
        gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1])

        logits = logits - logits[..., :1]
        lengths = (labels[:, 0, :] != -100).sum(-1)
        augment = (1 - gold_event).to(torch.float)
        if self.force_root_constituent:
            augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9
        dist = torch_struct.TreeCRF(logits + augment, lengths=lengths)

        pred_score = dist.max
        gold_score = (logits * gold_event).sum((1, 2, 3))
        margin_losses = F.relu(pred_score - gold_score)

        if self.reduction == "none":
            return margin_losses
        elif self.reduction == "mean":
            return margin_losses.mean()
        elif self.reduction == "sum":
            return margin_losses.sum()
        else:
            assert False, f"Unexpected reduction: {self.reduction}"
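

# Illustrative usage sketch (hypothetical names and shapes, not from the
# original source): during training, `logits` has shape
# (batch, max_len, max_len, num_labels) and `labels` is the batch of gold
# charts (e.g. from chart_from_tree, converted to tensors and padded to a
# common size with pad_charts, so padding entries are -100):
#   criterion = SpanClassificationMarginLoss(force_root_constituent=True)
#   loss = criterion(logits, labels)
#   loss.backward()
# The loss is a structured margin objective: the score of the best
# cost-augmented tree under a TreeCRF is compared against the score of the
# gold chart, and only the positive part of the difference is penalized.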