# Spaces:
# Runtime error
# Runtime error
from functools import partial
from pathlib import Path
from typing import Iterable

import faiss
import numpy as np
from transformers import AutoConfig

from utils.f import memoize
def get_config(model_name):
    """Load the ``transformers`` configuration object for a model.

    Parameters:
    -----------
    model_name: Identifier passed unchanged to ``AutoConfig.from_pretrained``

    Returns:
    --------
    The object returned by ``AutoConfig.from_pretrained(model_name)``
    """
    return AutoConfig.from_pretrained(model_name)
# Glob pattern used to discover the per-layer faiss index files in a folder.
FAISS_LAYER_PATTERN = 'layer_*.faiss'
# Format template for a single layer's file name (layer number zero-padded to 2 digits).
LAYER_TEMPLATE = 'layer_{:02d}.faiss'
def create_mask(head_size: int, n_heads: int, selected_heads: Iterable[int]):
    """Create a mask vector of size (head_size * n_heads).

    A 0 indicates we don't care about the contribution of that head; a 1
    indicates that we do care.

    Parameters:
    -----------
    head_size: Hidden dimension of the heads
    n_heads: Number of heads the model has
    selected_heads: Which heads we don't want to zero out. Each entry is
        coerced with ``int()`` and used as a NumPy index, so a negative value
        counts from the last head and an out-of-range value raises
        ``IndexError``. Duplicates are harmless.

    Returns:
    --------
    1-D float array of length ``head_size * n_heads`` in which each head
    occupies a contiguous run of ``head_size`` entries.
    """
    # One flag per head first; expanded to per-dimension flags below.
    mask = np.zeros(n_heads)
    for h in selected_heads:
        mask[int(h)] = 1
    # Repeat each head's flag head_size times so the heads stay contiguous.
    return np.repeat(mask, head_size)
class Indexes:
    """Wrapper around the faiss indices to make searching for a vector simpler and faster.

    Assumes there are files in the folder matching the pattern input. The part
    of each file stem after the last ``_`` is parsed as the integer layer
    number and used as that index's position in ``self.indexes``.
    """

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        """Load every matching index file and read the model's head geometry.

        Parameters:
        -----------
        folder: Directory containing the index files
        pattern: Glob pattern selecting the index files inside ``folder``
        """
        self.base_dir = Path(folder)
        self.pattern = pattern
        self.n_layers = len(list(self.base_dir.glob(pattern))) - 1  # Subtract final output
        self.indexes = [None] * (self.n_layers + 1)  # Initialize empty list, adding 1 for input
        self.__init_indexes()

        # Extract model name from folder hierarchy (two directories above `folder`)
        self.model_name = self.base_dir.parent.parent.stem
        self.config = get_config(self.model_name)
        self.nheads = self.config.num_attention_heads
        self.hidden_size = self.config.hidden_size

        assert (self.hidden_size % self.nheads) == 0, "Number of heads does not divide cleanly into the hidden size. Aborting"

        # Exact integer division; divisibility was checked just above.
        self.head_size = self.hidden_size // self.nheads

    def __getitem__(self, v):
        """Slices not allowed, but index only"""
        return self.indexes[v]

    def __init_indexes(self):
        """Read each matching file from disk into its slot of ``self.indexes``."""
        for fname in self.base_dir.glob(self.pattern):
            print(fname)
            # e.g. a stem of "layer_03" yields "03", which becomes slot 3
            idx = fname.stem.split('_')[-1]
            self.indexes[int(idx)] = faiss.read_index(str(fname))

    def search(self, layer, query, k):
        """Search a given layer for the query vector. Return k results"""
        return self[layer].search(query, k)
class ContextIndexes(Indexes):
    """Special index enabling masking of particular heads before searching"""

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        """Load the indexes, then pre-bind the head geometry into the mask builder.

        Parameters:
        -----------
        folder: Directory containing the index files
        pattern: Glob pattern selecting the index files inside ``folder``
        """
        super().__init__(folder, pattern)
        # After binding head_size and nheads, head_mask only needs the selected heads.
        self.head_mask = partial(create_mask, self.head_size, self.nheads)

    # Int -> [Int] -> np.Array -> Int -> (np.Array(), )
    def search(self, layer: int, heads: list, query: np.ndarray, k: int):
        """Search the embeddings for the context layer, masking by selected heads

        Parameters:
        -----------
        layer: Position of the index to search
        heads: Zero-based head indices whose dimensions are kept; all other
            heads are zeroed out of the query. Duplicates are ignored.
        query: Query embedding(s). Either a single vector with as many
            elements as the mask (any shape), or an array whose last axis
            has the mask's length so the mask broadcasts over each row.
        k: Number of results to return

        Returns:
        --------
        Whatever the underlying index's ``search`` returns for the masked query

        Raises:
        -------
        ValueError: If ``heads`` is empty, or the query cannot be combined
            with the mask
        AssertionError: If a head index is negative or not below ``nheads``
        """
        # Materialise once so one-shot iterables are not exhausted by the checks.
        unique_heads = sorted(set(heads))
        if not unique_heads:
            raise ValueError("heads must contain at least one head index")

        assert unique_heads[-1] < self.nheads, "max of selected heads must be lest than nheads. Are you indexing by 1 instead of 0?"
        assert unique_heads[0] >= 0, "What is a negative head?"

        mask_vector = self.head_mask(unique_heads)
        if mask_vector.size == query.size:
            # Single query: give the mask exactly the query's shape.
            mask_vector = mask_vector.reshape(query.shape)
        # Otherwise leave the mask 1-D so it broadcasts across the last axis
        # of a batched query; a mismatched last axis raises ValueError.

        new_query = (query * mask_vector).astype(np.float32)
        return self[layer].search(new_query, k)