exbert / server /data_processing /index_wrapper.py
bhoov's picture
First commit
63858e7
raw
history blame
3.37 kB
from functools import partial
import faiss
import numpy as np
from pathlib import Path
from typing import Iterable
from utils.f import memoize
from transformers import AutoConfig
@memoize
def get_config(model_name):
    """Load the transformers config for `model_name`.

    Memoized so repeated lookups for the same model reuse the first
    result instead of calling `AutoConfig.from_pretrained` again
    (which may hit the network/disk — TODO confirm caching semantics
    of the project `memoize` helper).
    """
    return AutoConfig.from_pretrained(model_name)
# Glob pattern matching one saved faiss index file per layer (e.g. `layer_03.faiss`).
FAISS_LAYER_PATTERN = 'layer_*.faiss'
# Zero-padded filename template producing names that match the pattern above.
LAYER_TEMPLATE = 'layer_{:02d}.faiss'
def create_mask(head_size: int, n_heads: int, selected_heads: Iterable[int]):
    """Create a masked vector of size (head_size * n_heads), where 0 indicates we don't care about the contribution of that head 1 indicates that we do care

    Parameters:
    -----------
    head_size: Hidden dimension of the heads
    n_heads: Number of heads the model has
    selected_heads: Which heads we don't want to zero out
    """
    # One flag per head: 1 = keep this head's contribution, 0 = zero it out.
    per_head = np.zeros(n_heads)
    per_head[[int(h) for h in selected_heads]] = 1
    # Expand each head flag across that head's hidden dimensions.
    return np.repeat(per_head, head_size)
class Indexes:
    """Wrapper around the faiss indices to make searching for a vector simpler and faster.

    Assumes there are files in the folder matching the pattern input
    """

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        """Load every faiss index in `folder` whose filename matches `pattern`.

        Parameters:
        -----------
        folder: Directory containing the per-layer `.faiss` files
        pattern: Glob pattern selecting the index files
        """
        self.base_dir = Path(folder)
        # One file per layer plus one for the raw input embeddings; subtract
        # the input file so `n_layers` counts transformer layers only.
        self.n_layers = len(list(self.base_dir.glob(pattern))) - 1
        # Slot 0 holds the input index; slots 1..n_layers hold layer outputs.
        self.indexes = [None] * (self.n_layers + 1)
        self.pattern = pattern
        self.__init_indexes()

        # Extract model name from folder hierarchy
        # (assumes layout `<model_name>/<...>/<folder>` — TODO confirm against callers).
        self.model_name = self.base_dir.parent.parent.stem
        self.config = get_config(self.model_name)
        self.nheads = self.config.num_attention_heads
        self.hidden_size = self.config.hidden_size
        assert (self.hidden_size % self.nheads) == 0, "Number of heads does not divide cleanly into the hidden size. Aborting"
        # Integer division: divisibility was just asserted, so `//` is exact
        # (the original `int(a / b)` used lossy float division).
        self.head_size = self.hidden_size // self.nheads

    def __getitem__(self, v):
        """Slices not allowed, but index only"""
        return self.indexes[v]

    def __init_indexes(self):
        """Read each matching faiss file into the slot named by its numeric suffix."""
        for fname in self.base_dir.glob(self.pattern):
            # Filenames look like `layer_07.faiss`; the part after the last
            # underscore is the layer slot. (Removed a leftover debug print.)
            idx = fname.stem.split('_')[-1]
            self.indexes[int(idx)] = faiss.read_index(str(fname))

    def search(self, layer, query, k):
        """Search a given layer for the query vector. Return k results"""
        return self[layer].search(query, k)
class ContextIndexes(Indexes):
    """Special index enabling masking of particular heads before searching"""

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        super().__init__(folder, pattern)
        # Partially applied mask builder: only `selected_heads` is left to
        # supply, yielding a 0/1 vector of length head_size * nheads.
        self.head_mask = partial(create_mask, self.head_size, self.nheads)

    # Int -> [Int] -> np.Array -> Int -> (np.Array(), )
    def search(self, layer: int, heads: list, query: np.ndarray, k: int):
        """Search the embeddings for the context layer, masking by selected heads

        Parameters:
        -----------
        layer: Which layer's index to search
        heads: Head indices (0-based) whose contributions should be kept
        query: Query embedding; must be reshapeable to the mask's length
        k: Number of nearest neighbors to return
        """
        # Guard first so an empty selection gives a clear message instead of
        # the cryptic ValueError `max()` raises on an empty sequence.
        assert len(heads) > 0, "Must select at least one head to search by"
        assert max(heads) < self.nheads, "max of selected heads must be less than nheads. Are you indexing by 1 instead of 0?"
        assert min(heads) >= 0, "What is a negative head?"

        # Deduplicate so a head selected twice is not double-counted.
        unique_heads = list(set(heads))
        mask_vector = self.head_mask(unique_heads)
        mask_vector = mask_vector.reshape(query.shape)

        # faiss requires float32 queries.
        new_query = (query * mask_vector).astype(np.float32)
        return self[layer].search(new_query, k)