Tom Aarsen commited on
Commit
841db35
·
1 Parent(s): 9e25f87

Initial commit; minus indices

Browse files
Files changed (5) hide show
  1. .gitignore +2 -0
  2. app.py +86 -0
  3. requirements.txt +6 -0
  4. save_binary_index.py +13 -0
  5. save_int8_index.py +13 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ wikipedia_int8_10k_usearch.index
2
+ wikipedia_ubinary_10k_faiss.index
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import time
3
+ import gradio as gr
4
+ from datasets import load_dataset
5
+ import pandas as pd
6
+ from sentence_transformers import SentenceTransformer
7
+ from sentence_transformers.util import quantize_embeddings
8
+ import faiss
9
+ from usearch.index import Index
10
+
11
+ # Load titles and texts
12
+ title_text_dataset = load_dataset("mixedbread-ai/wikipedia-2023-11-embed-en-pre-1", split="train").select_columns(["title", "text"])
13
+
14
+ # Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
15
+ int8_view = Index.restore("wikipedia_int8_usearch_1m.index", view=True)
16
+ binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_1m.index")
17
+
18
+ # Load the SentenceTransformer model for embedding the queries
19
+ model = SentenceTransformer(
20
+ "mixedbread-ai/mxbai-embed-large-v1",
21
+ prompts={
22
+ "retrieval": "Represent this sentence for searching relevant passages: ",
23
+ },
24
+ default_prompt_name="retrieval",
25
+ )
26
+
27
+
28
+ def search(query, top_k: int = 10, rerank_multiplier: int = 4):
29
+ # 1. Embed the query as float32
30
+ start_time = time.time()
31
+ query_embedding = model.encode(query)
32
+ embed_time = time.time() - start_time
33
+
34
+ # 2. Quantize the query to ubinary
35
+ start_time = time.time()
36
+ query_embedding_ubinary = quantize_embeddings(query_embedding, "ubinary")
37
+ quantize_time = time.time() - start_time
38
+
39
+ # 3. Search the binary index
40
+ start_time = time.time()
41
+ _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rerank_multiplier)
42
+ binary_ids = binary_ids[0]
43
+ search_time = time.time() - start_time
44
+
45
+ # 4. Load the corresponding int8 embeddings
46
+ start_time = time.time()
47
+ int8_embeddings = int8_view[binary_ids].astype(int)
48
+ load_time = time.time() - start_time
49
+
50
+ # 5. Rerank the top_k * rerank_multiplier using the float32 query embedding and the int8 document embeddings
51
+ start_time = time.time()
52
+ scores = query_embedding @ int8_embeddings.T
53
+ rerank_time = time.time() - start_time
54
+
55
+ # 6. Sort the scores and return the top_k
56
+ start_time = time.time()
57
+ top_k_indices = (-scores).argsort()[-top_k:]
58
+ top_k_scores = scores[top_k_indices]
59
+ top_k_titles, top_k_texts = zip(*[(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"]) for idx in binary_ids[top_k_indices].tolist()])
60
+ df = pd.DataFrame({"Score": [round(value, 2) for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts})
61
+ sort_time = time.time() - start_time
62
+
63
+ return df, {
64
+ "Embed Time": f"{embed_time:.4f} s",
65
+ "Quantize Time": f"{quantize_time:.4f} s",
66
+ "Search Time": f"{search_time:.4f} s",
67
+ "Load Time": f"{load_time:.4f} s",
68
+ "Rerank Time": f"{rerank_time:.4f} s",
69
+ "Sort Time": f"{sort_time:.4f} s",
70
+ "Total Retrieval Time": f"{quantize_time + search_time + load_time + rerank_time + sort_time:.4f} s"
71
+ }
72
+
73
+ with gr.Blocks(title="Quantized Retrieval") as demo:
74
+ query = gr.Textbox(label="Query")
75
+ search_button = gr.Button(value="Search")
76
+
77
+ with gr.Row():
78
+ with gr.Column(scale=4):
79
+ output = gr.Dataframe(column_widths=["10%", "20%", "80%"], headers=["Score", "Title", "Text"])
80
+ with gr.Column(scale=1):
81
+ json = gr.JSON()
82
+
83
+ search_button.click(search, inputs=[query], outputs=[output, json])
84
+
85
+ demo.queue()
86
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ sentence_transformers
2
+ datasets
3
+ pandas
4
+
5
+ usearch
6
+ faiss
save_binary_index.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datasets import load_dataset
3
+ import numpy as np
4
+ from faiss import IndexBinaryFlat, write_index_binary
5
+ from sentence_transformers.util import quantize_embeddings
6
+
7
+ dataset = load_dataset("mixedbread-ai/wikipedia-2023-11-embed-en-pre-1", split="train")
8
+ embeddings = np.array(dataset["emb"], dtype=np.float32)
9
+
10
+ ubinary_embeddings = quantize_embeddings(embeddings, "ubinary")
11
+ index = IndexBinaryFlat(1024)
12
+ index.add(ubinary_embeddings)
13
+ write_index_binary(index, "wikipedia_ubinary_faiss_1m.index")
save_int8_index.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datasets import load_dataset
3
+ import numpy as np
4
+ from usearch.index import Index
5
+ from sentence_transformers.util import quantize_embeddings
6
+
7
+ dataset = load_dataset("mixedbread-ai/wikipedia-2023-11-embed-en-pre-1", split="train")
8
+ embeddings = np.array(dataset["emb"], dtype=np.float32)
9
+
10
+ int8_embeddings = quantize_embeddings(embeddings, "int8")
11
+ index = Index(ndim=1024, metric="ip", dtype="i8")
12
+ index.add(np.arange(len(int8_embeddings)), int8_embeddings)
13
+ index.save("wikipedia_int8_usearch_1m.index")