File size: 3,388 Bytes
841db35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

import time
import gradio as gr
from datasets import load_dataset
import pandas as pd
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import quantize_embeddings
import faiss
from usearch.index import Index

# Wikipedia passages: only the "title" and "text" columns are kept.
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-2023-11-embed-en-pre-1",
    split="train",
)
title_text_dataset = title_text_dataset.select_columns(["title", "text"])

# Quantized indices. The int8 index is restored as a view to save memory, since
# it is never searched directly; the binary index is read with faiss.
int8_view = Index.restore("wikipedia_int8_usearch_1m.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "wikipedia_ubinary_faiss_1m.index"
)

# Query encoder; the "retrieval" prompt is registered as the default prompt.
model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={"retrieval": "Represent this sentence for searching relevant passages: "},
    default_prompt_name="retrieval",
)


def search(query: str, top_k: int = 10, rerank_multiplier: int = 4):
    """Retrieve the ``top_k`` best-matching passages for ``query``.

    The query is embedded, quantized to ubinary and used to fetch
    ``top_k * rerank_multiplier`` candidates from the binary index. Those
    candidates are then rescored with the float query embedding against their
    int8 document embeddings, and the ``top_k`` highest-scoring ones are kept.

    Args:
        query: Text to search for.
        top_k: Number of results to return.
        rerank_multiplier: Oversampling factor for the binary search stage.

    Returns:
        A ``(df, timings)`` tuple. ``df`` is a DataFrame with "Score", "Title"
        and "Text" columns, ordered from highest to lowest score. ``timings``
        maps each stage name to its duration formatted as ``"<seconds> s"``;
        "Total Retrieval Time" sums every stage except embedding.
    """
    # perf_counter is monotonic, so the measured intervals cannot be skewed by
    # wall-clock adjustments.

    # 1. Embed the query as float32
    start_time = time.perf_counter()
    query_embedding = model.encode(query)
    embed_time = time.perf_counter() - start_time

    # 2. Quantize the query to ubinary. The embedding is reshaped to a batch of
    #    one, because the search result below is indexed as a batch ([0]).
    start_time = time.perf_counter()
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    quantize_time = time.perf_counter() - start_time

    # 3. Search the binary index
    start_time = time.perf_counter()
    _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rerank_multiplier)
    binary_ids = binary_ids[0]
    search_time = time.perf_counter() - start_time

    # 4. Load the corresponding int8 embeddings
    start_time = time.perf_counter()
    int8_embeddings = int8_view[binary_ids].astype(int)
    load_time = time.perf_counter() - start_time

    # 5. Rerank the top_k * rerank_multiplier using the float32 query embedding and the int8 document embeddings
    start_time = time.perf_counter()
    scores = query_embedding @ int8_embeddings.T
    rerank_time = time.perf_counter() - start_time

    # 6. Sort the scores and return the top_k
    start_time = time.perf_counter()
    # Negating the scores makes argsort order them best-first, so the winners
    # are the FIRST top_k positions (slicing from the end returned the worst).
    top_k_indices = (-scores).argsort()[:top_k]
    top_k_scores = scores[top_k_indices]
    # Fetch each dataset row once and read both fields from it.
    rows = [title_text_dataset[idx] for idx in binary_ids[top_k_indices].tolist()]
    df = pd.DataFrame(
        {
            "Score": [round(value, 2) for value in top_k_scores],
            "Title": [row["title"] for row in rows],
            "Text": [row["text"] for row in rows],
        }
    )
    sort_time = time.perf_counter() - start_time

    return df, {
        "Embed Time": f"{embed_time:.4f} s",
        "Quantize Time": f"{quantize_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Load Time": f"{load_time:.4f} s",
        "Rerank Time": f"{rerank_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total Retrieval Time": f"{quantize_time + search_time + load_time + rerank_time + sort_time:.4f} s"
    }

# UI layout: a query box and a search button on top; below them a wide results
# table next to a narrower JSON panel that receives the timing dictionary.
with gr.Blocks(title="Quantized Retrieval") as demo:
    query = gr.Textbox(label="Query")
    search_button = gr.Button("Search")

    with gr.Row():
        # Results table takes 4/5 of the row width.
        with gr.Column(scale=4):
            output = gr.Dataframe(
                headers=["Score", "Title", "Text"],
                column_widths=["10%", "20%", "80%"],
            )
        # Timing panel takes the remaining 1/5.
        with gr.Column(scale=1):
            json = gr.JSON()

    # search() returns (DataFrame, dict), matching the two output components.
    search_button.click(fn=search, inputs=query, outputs=[output, json])

# Enable request queuing, then start the server in debug mode.
demo.queue().launch(debug=True)