from typing import List, Iterable, Tuple
from functools import partial
import json

import numpy as np
import torch

from utils.token_processing import fix_byte_spaces
from utils.gen_utils import map_nlist


def round_return_value(attentions, ndigits=5):
    """Round the attention values of a response dict in place.

    Rounding must happen right before the response is passed back to the
    frontend because converting back to lists introduces a little numerical
    error.

    Args:
        attentions: Response of the form
            ``{"aa": {"left": ..., "right": ..., "att": ...}}``.
        ndigits: Number of decimal digits to keep.

    Returns:
        The very same ``attentions`` object, whose ``["aa"]["att"]`` entry has
        been replaced by the result of mapping ``round`` over it with
        ``map_nlist``.
    """
    rounder = partial(round, ndigits=ndigits)
    nested_rounder = partial(map_nlist, rounder)
    # Deliberately overwrite the entry on the input instead of copying the
    # whole structure, to save memory.
    attentions["aa"]["att"] = nested_rounder(attentions["aa"]["att"])
    return attentions


def flatten_batch(x: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """Remove the batch dimension of every tensor inside the container `x`.

    Uses ``squeeze(0)``, so dimension 0 is only dropped when its size is 1.
    """
    return tuple(x_.squeeze(0) for x_ in x)


def squeeze_contexts(x: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """Combine the last two dimensions of every tensor in `x`.

    The target shape is derived from the first tensor only, so all tensors are
    reshaped to the leading dimensions of ``x[0]`` followed by one merged
    dimension.
    """
    new_shape = tuple(x[0].shape[:-2]) + (-1,)
    # `reshape` returns a view whenever `view` would have worked, and only
    # falls back to a copy for non-contiguous inputs (where `view` raises).
    return tuple(x_.reshape(new_shape) for x_ in x)


def add_blank(xs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """Prepend an all-zero tensor shaped like ``xs[0]`` to the tuple `xs`.

    The embeddings have n_layers + 1 entries (the extra one being the final
    output embedding); the blank entry pads `xs` by one element at the front.
    """
    return (torch.zeros_like(xs[0]),) + xs


class TransformerOutputFormatter:
    """Hold the outputs of one transformer forward pass and format them as the
    JSON-like response expected by the frontend (see `to_json`)."""

    def __init__(
        self,
        sentence: str,
        tokens: List[str],
        special_tokens_mask: List[int],
        att: Tuple[torch.Tensor, ...],
        topk_words: List[List[str]],
        topk_probs: List[List[float]],
        model_config
    ):
        """Store the model outputs after stripping the batch dimension of `att`.

        Args:
            sentence: The raw input sentence (used by `__repr__`).
            tokens: Tokens of the sentence; must be non-empty.
            special_tokens_mask: Stored as-is on the instance.
            att: Attention tensors, indexed by layer in `to_json`. Each is
                passed through `flatten_batch`; the last dimension of the
                first one must equal ``len(tokens)``.
            topk_words: Per-token top-k words, zipped with `tokens` in
                `to_json`.
            topk_probs: Per-token top-k probabilities, zipped with `tokens`
                in `to_json`.
            model_config: Object exposing ``n_layer``, ``n_head`` and
                ``n_embd`` attributes.

        Raises:
            AssertionError: If `tokens` is empty or its length does not match
                the last dimension of the first attention tensor.
        """
        assert len(tokens) > 0, "Cannot have an empty token output!"

        modified_att = flatten_batch(att)

        self.sentence = sentence
        self.tokens = tokens
        self.special_tokens_mask = special_tokens_mask
        self.attentions = modified_att
        self.topk_words = topk_words
        self.topk_probs = topk_probs
        self.model_config = model_config

        self.n_layer = self.model_config.n_layer
        self.n_head = self.model_config.n_head
        self.hidden_dim = self.model_config.n_embd

        self.__len = len(tokens)  # Number of tokens in the input
        assert self.__len == self.attentions[0].shape[-1], "Attentions don't represent the passed tokens!"

    def to_json(self, layer: int, ndigits: int = 5) -> dict:
        """Build the response for a single layer.

        The returned structure is::

            {"aa": {
                "att": <attentions of `layer` as rounded nested lists>,
                "left": [{"text", "topk_words", "topk_probs"}, ...],
                "right": <the same list object as "left">
            }}

        Args:
            layer: Index into the stored attention tuple.
            ndigits: Number of decimal digits kept for attentions and
                top-k probabilities.
        """
        # Convert the attentions into lists and perform the rounding
        rounder = partial(round, ndigits=ndigits)
        nested_rounder = partial(map_nlist, rounder)

        def tolist(tens):
            return [t.tolist() for t in tens]

        def to_resp(tok: str, topk_words, topk_probs):
            return {
                "text": tok,
                "topk_words": topk_words,
                "topk_probs": nested_rounder(topk_probs)
            }

        # One entry per token; `zip` stops at the shortest of the three lists.
        side_info = [
            to_resp(t, w, p)
            for t, w, p in zip(self.tokens, self.topk_words, self.topk_probs)
        ]

        out = {"aa": {
            "att": nested_rounder(tolist(self.attentions[layer])),
            "left": side_info,
            "right": side_info
        }}

        return out

    def display_tokens(self, tokens):
        """Return `tokens` after passing them through `fix_byte_spaces`."""
        return fix_byte_spaces(tokens)

    def __repr__(self):
        lim = 50
        if len(self.sentence) > lim:
            # Truncate so that the text plus the ellipsis is `lim` chars long
            s = self.sentence[:lim - 3] + "..."
        else:
            s = self.sentence
        return f"TransformerOutput({s})"

    def __len__(self):
        """Number of tokens passed at construction time."""
        return self.__len


def to_numpy(x):
    """Embeddings, contexts, and attentions are stored as torch.Tensors in a
    tuple. Convert this to a numpy array for storage in hdf5.

    Tensors are detached and moved to the CPU first, so tensors that require
    grad or live on another device are accepted as well.
    """
    return np.array([x_.detach().cpu().numpy() for x_ in x])


def to_searchable(t: torch.Tensor):
    """Convert a single tensor into a float32 numpy array.

    The tensor is detached and moved to the CPU before the conversion.
    """
    return t.detach().cpu().numpy().astype(np.float32)