# exbert/server/data_processing/index_wrapper.py
# Hosting-page residue captured during extraction (not part of the module):
#   bhoov's picture / First commit / history blame / 3.37 kB
from functools import partial
import faiss
import numpy as np
from pathlib import Path
from typing import Iterable
from utils.f import memoize
from transformers import AutoConfig
def get_config(model_name):
    """Load the configuration object for ``model_name``.

    The lookup is delegated entirely to ``AutoConfig.from_pretrained``;
    whatever it returns is handed back unchanged.
    """
    model_config = AutoConfig.from_pretrained(model_name)
    return model_config
# Default glob pattern for locating the per-layer faiss index files in a folder.
FAISS_LAYER_PATTERN: str = "layer_*.faiss"

# Filename template for a single layer; the layer number is zero-padded to 2 digits.
LAYER_TEMPLATE: str = "layer_{:02d}.faiss"
def create_mask(head_size:int , n_heads:int, selected_heads:Iterable[int]):
    """Build a flat 0/1 vector of length ``head_size * n_heads`` that keeps only chosen heads.

    Every head owns a contiguous run of ``head_size`` entries. The run is all
    ones when the head is listed in ``selected_heads`` and all zeros otherwise.

    Args:
        head_size: Hidden dimension of the heads
        n_heads: Number of heads the model has
        selected_heads: Which heads we don't want to zero out. Each entry is
            passed through ``int()`` before being used as an index.

    Returns:
        A 1-D numpy array of length ``head_size * n_heads``.
    """
    # One flag per head first; it is stretched to per-dimension width at the end.
    head_flags = np.zeros(n_heads)
    for head in map(int, selected_heads):
        head_flags[head] = 1
    return head_flags.repeat(head_size)
class Indexes:
    """Wrapper around the faiss indices to make searching for a vector simpler and faster.

    Assumes there are files in the folder matching the pattern input. The text
    after the last underscore of each file's stem must be the integer layer
    number (e.g. ``layer_03.faiss`` is stored at position 3).

    Attributes set by ``__init__``:
        base_dir: ``Path`` of the folder holding the index files.
        pattern: Glob pattern used to find the index files.
        n_layers: Number of matching files minus one.
        indexes: List of loaded faiss indexes, positioned by layer number.
        model_name: Stem of the folder two levels above ``base_dir``.
        config: Result of ``get_config(model_name)``.
        nheads: ``config.num_attention_heads``.
        hidden_size: ``config.hidden_size``.
        head_size: ``hidden_size // nheads``.
    """

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        """Discover the index files in ``folder``, read the model config, and load the indexes.

        Args:
            folder: Directory containing the faiss index files.
            pattern: Glob pattern selecting the index files inside ``folder``.

        Raises:
            AssertionError: If the hidden size is not a multiple of the number of heads.
        """
        self.base_dir = Path(folder)
        self.pattern = pattern

        self.n_layers = len(list(self.base_dir.glob(pattern))) - 1  # Subtract final output
        self.indexes = [None] * (self.n_layers + 1)  # Initialize empty list, adding 1 for input

        # Extract model name from folder hierarchy
        self.model_name = self.base_dir.parent.parent.stem
        self.config = get_config(self.model_name)
        self.nheads = self.config.num_attention_heads
        self.hidden_size = self.config.hidden_size
        assert (self.hidden_size % self.nheads) == 0, "Number of heads does not divide cleanly into the hidden size. Aborting"
        self.head_size = self.hidden_size // self.nheads

        # Fill the placeholders. Without this call every entry stays None and
        # search() cannot work. Done last so config problems surface before
        # any index file is read.
        self.__init_indexes()

    def __getitem__(self, v):
        """Slices not allowed, but index only"""
        return self.indexes[v]

    def __init_indexes(self):
        """Read every file matching ``self.pattern`` into its slot of ``self.indexes``."""
        for fname in self.base_dir.glob(self.pattern):
            # The layer number is the last underscore-separated piece of the stem.
            layer = int(fname.stem.split('_')[-1])
            self.indexes[layer] = faiss.read_index(str(fname))

    def search(self, layer, query, k):
        """Search a given layer for the query vector. Return k results"""
        return self[layer].search(query, k)
class ContextIndexes(Indexes):
    """Special index enabling masking of particular heads before searching"""

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        """Initialise the base wrapper, then pre-bind the head geometry for mask creation.

        Args:
            folder: Directory containing the faiss index files.
            pattern: Glob pattern selecting the index files inside ``folder``.
        """
        super().__init__(folder, pattern)
        # Callers of head_mask only need to supply the heads to keep.
        self.head_mask = partial(create_mask, self.head_size, self.nheads)

    # Int -> [Int] -> np.Array -> Int -> (np.Array(), )
    def search(self, layer:int, heads:list, query:np.ndarray, k:int):
        """Search the embeddings for the context layer, masking by selected heads.

        Args:
            layer: Which layer's index to search.
            heads: Zero-based head numbers to keep; duplicates are ignored.
            query: Query vector(s). Either an array with exactly as many
                elements as the mask, or a batch whose last axis has the
                mask's length.
            k: Number of results to request from the index.

        Returns:
            Whatever the underlying index's ``search`` returns for the masked query.

        Raises:
            ValueError: If ``heads`` is empty, or ``query`` does not line up with the mask.
            AssertionError: If a head number is negative or not below ``nheads``.
        """
        unique_heads = list(set(heads))
        if not unique_heads:
            # Previously surfaced as an opaque "max() arg is an empty sequence".
            raise ValueError("heads must contain at least one head index")
        assert max(unique_heads) < self.nheads, "max of selected heads must be lest than nheads. Are you indexing by 1 instead of 0?"
        assert min(unique_heads) >= 0, "What is a negative head?"

        mask_vector = self.head_mask(unique_heads)
        if mask_vector.size == query.size:
            # Single query: give the mask the query's exact shape.
            mask_vector = mask_vector.reshape(query.shape)
        elif query.ndim == 0 or query.shape[-1] != mask_vector.size:
            raise ValueError(
                f"query of shape {query.shape} does not match a head mask of length {mask_vector.size}"
            )
        # Otherwise the 1-D mask broadcasts over the leading (batch) axes.

        new_query = (query * mask_vector).astype(np.float32)
        return self[layer].search(new_query, k)