Spaces:
Runtime error
Runtime error
from typing import List, Iterable, Tuple | |
from functools import partial | |
import numpy as np | |
import torch | |
import json | |
from utils.token_processing import fix_byte_spaces | |
from utils.gen_utils import map_nlist | |
def round_return_value(attentions, ndigits=5):
    """Round the attention matrix of an API payload in place and return the payload.

    Rounding must happen right before the payload is passed back to the
    frontend, because converting tensors back to lists introduces a little
    numerical error.

    Expected layout of ``attentions``::

        {
            'aa': {
                left
                right
                att
            }
        }

    Only ``attentions['aa']['att']`` is rounded, element-wise through
    ``map_nlist``, to ``ndigits`` decimal places. The input dict itself is
    mutated and returned (no copy is made, to save memory).
    """
    to_ndigits = partial(round, ndigits=ndigits)
    payload = attentions["aa"]
    payload["att"] = map_nlist(to_ndigits, payload["att"])
    return attentions
def flatten_batch(x: Iterable[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
    """Remove the batch dimension of every tensor inside the iterable `x`.

    Args:
        x: Any iterable of tensors (tuple, list, generator, ...).

    Returns:
        A tuple with ``t.squeeze(0)`` for each tensor ``t`` in ``x``, in order.
        ``squeeze(0)`` only drops dimension 0 when its size is 1; tensors whose
        first dimension is larger keep their shape.
    """
    return tuple(t.squeeze(0) for t in x)
def squeeze_contexts(x: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """Combine the last two dimensions of every tensor in `x` into one.

    A tensor of shape ``(..., a, b)`` becomes ``(..., a * b)``. Tensors with
    fewer than two dimensions are flattened to 1-D.

    Each tensor is reshaped from its own shape, rather than from the first
    tensor's shape. As a result, an empty tuple returns ``()`` and tensors
    with differing leading dimensions are handled correctly.

    ``reshape`` returns a view whenever one is possible, exactly as ``view``
    would. For non-contiguous tensors it falls back to a copy instead of
    raising.
    """
    return tuple(t.reshape(*t.shape[:-2], -1) for t in x)
def add_blank(xs: Iterable[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
    """Prepend an all-zeros tensor shaped like the first element of `xs`.

    The embeddings have n_layers + 1 entries, the extra one being the final
    output embedding. Presumably this blank lets a per-layer tuple line up with
    them; confirm against callers.

    Args:
        xs: Non-empty iterable of tensors. Lists and generators are accepted
            as well as tuples.

    Returns:
        ``(zeros_like(xs[0]), *xs)`` as a tuple.

    Raises:
        IndexError: if ``xs`` is empty, since there is no tensor to take the
            shape/dtype/device from.
    """
    xs = tuple(xs)  # tuple + list would raise TypeError; normalise first
    if not xs:
        raise IndexError("add_blank needs at least one tensor to copy the shape from")
    return (torch.zeros_like(xs[0]),) + xs
class TransformerOutputFormatter:
    """Hold the tokens, attentions and top-k predictions for one input sentence
    and serialize a single attention layer for the frontend.

    ``model_config`` must expose ``n_layer``, ``n_head`` and ``n_embd``.
    """

    def __init__(
        self,
        sentence: str,
        tokens: List[str],
        special_tokens_mask: List[int],
        att: Tuple[torch.Tensor],
        topk_words: List[List[str]],
        topk_probs: List[List[float]],
        model_config
    ):
        assert len(tokens) > 0, "Cannot have an empty token output!"
        # Drop the leading batch dimension of each attention tensor.
        unbatched = flatten_batch(att)

        self.sentence = sentence
        self.tokens = tokens
        self.special_tokens_mask = special_tokens_mask
        self.attentions = unbatched
        self.topk_words = topk_words
        self.topk_probs = topk_probs

        self.model_config = model_config
        cfg = self.model_config
        self.n_layer = cfg.n_layer
        self.n_head = cfg.n_head
        self.hidden_dim = cfg.n_embd

        n_tokens = len(tokens)
        self.__len = n_tokens  # number of tokens in the input
        # The attention tensors' last dimension must line up with the tokens.
        assert self.__len == self.attentions[0].shape[-1], "Attentions don't represent the passed tokens!"

    def to_json(self, layer: int, ndigits=5):
        """Build the response payload for one attention layer.

        Shape of the returned dict::

            aa: {
                att: number[][][]   # self.attentions[layer] as nested lists, rounded
                left: side_info
                right: side_info
            }

        ``side_info`` has one dict per token, with keys ``text``,
        ``topk_words`` and ``topk_probs``. The probabilities are rounded to
        ``ndigits`` places. The very same list object is used for both
        ``left`` and ``right``.
        """
        round_to = partial(round, ndigits=ndigits)

        def round_nested(values):
            return map_nlist(round_to, values)

        side_info = []
        for tok, words, probs in zip(self.tokens, self.topk_words, self.topk_probs):
            side_info.append({
                "text": tok,
                "topk_words": words,
                "topk_probs": round_nested(probs),
            })

        # Convert slice-by-slice along dim 0 into plain Python lists before rounding.
        att_as_lists = [part.tolist() for part in self.attentions[layer]]
        return {"aa": {
            "att": round_nested(att_as_lists),
            "left": side_info,
            "right": side_info,
        }}

    def display_tokens(self, tokens):
        """Return ``tokens`` passed through ``fix_byte_spaces``."""
        return fix_byte_spaces(tokens)

    def __repr__(self):
        max_chars = 50
        text = self.sentence
        if len(text) > max_chars:
            # Truncate so the text plus the ellipsis fits in max_chars characters.
            text = text[:max_chars - 3] + "..."
        return f"TransformerOutput({text})"

    def __len__(self):
        return self.__len
def to_numpy(x: Iterable[torch.Tensor]) -> np.ndarray:
    """Convert a collection of tensors into one numpy array for storage in hdf5.

    Embeddings, contexts, and attentions are stored as torch.Tensors in a
    tuple. Each tensor is detached from the autograd graph and moved to the
    CPU before conversion, because ``Tensor.numpy()`` raises for tensors that
    require grad or live on another device. ``.cpu()`` returns the tensor
    itself when it is already on the CPU.

    The per-tensor arrays are combined with ``np.array``, which adds a new
    leading axis. The tensors are therefore expected to share a shape.
    """
    return np.array([t.detach().cpu().numpy() for t in x])
def to_searchable(t: torch.Tensor) -> np.ndarray:
    """Return the single tensor `t` as a float32 numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    first, because ``Tensor.numpy()`` raises otherwise. ``astype`` always
    returns a new array, so the result does not share memory with `t`.
    Presumably the float32 output feeds a search index; confirm against
    callers.
    """
    return t.detach().cpu().numpy().astype(np.float32)