Spaces: Runtime error
update
app.py CHANGED
@@ -1,9 +1,9 @@
+from transformers import AutoModel, AutoTokenizer
 import faiss
 import numpy as np
 import pandas as pd
 import streamlit as st
 import torch
-from transformers import AutoModel, AutoTokenizer
 
 import os
 
@@ -34,38 +34,7 @@ def load_sentence_embeddings():
     return sentence_embeddings
 
 
-@st.cache
-def build_faiss_index(sentence_emeddings):
-    D = 768
-    N = 789188
-    Xt = sentence_emeddings[:39000]
-    X = sentence_emeddings
-
-    # Param of PQ
-    M = 16  # The number of sub-vectors. Typically this is 8, 16, 32, etc.
-    nbits = 8  # Bits per sub-vector. This is typically 8, so that each sub-vector is encoded by 1 byte.
-    # Param of IVF
-    nlist = 1000  # The number of cells (space partitions). A typical value is sqrt(N).
-    # Param of HNSW
-    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32.
-
-    # Setup
-    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
-    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
-
-    # Train
-    index.train(Xt)
-
-    # Add
-    index.add(X)
-
-    # Search
-    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.
-
-    return index
-
-
-@st.cache(allow_output_mutation=True)
+@st.cache
 def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
     with torch.no_grad():
         inputs = tokenizer.encode_plus(
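For context on the block being moved: it composes two FAISS structures, an IVF-PQ index (a coarse partition of the space into nlist cells, plus product-quantized codes of M * nbits / 8 = 16 bytes per vector) whose cell-assignment step is itself accelerated by an HNSW graph over the centroids. A minimal self-contained sketch of the same pattern on random data; the sizes are shrunk so it runs in seconds, and only D = 768 matches the app:

import faiss
import numpy as np

D = 768            # embedding dimension, as in the app
N = 50_000         # synthetic corpus size (the app uses 789,188)
nlist = 100        # IVF cells; ~sqrt(N) is the usual heuristic
M, nbits = 16, 8   # PQ: 16 sub-vectors x 8 bits -> 16 bytes per vector
hnsw_m = 32        # HNSW graph degree for the coarse quantizer

rng = np.random.default_rng(0)
X = rng.standard_normal((N, D)).astype("float32")

# HNSW over the nlist centroids: used only to route vectors/queries to cells.
quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

index.train(X[:10_000])  # learns IVF centroids and PQ codebooks
index.add(X)

index.nprobe = 8  # visit 8 of the 100 cells per query
dists, ids = index.search(X[:3], 5)
print(ids)  # each query's own id usually appears among its top hits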
@@ -102,9 +71,34 @@ def main(model, tokenizer, sentence_df, index):
 if __name__ == "__main__":
     model, tokenizer = load_model_and_tokenizer()
     sentence_df = load_sentence_data()
-    sentence_embeddings = load_sentence_embeddings()
+    sentence_embeddings = load_sentence_embeddings()
+
+    faiss.normalize_L2(sentence_embeddings)
+
+    D = 768
+    N = 789188
+    Xt = sentence_embeddings[:39000]
+    X = sentence_embeddings
 
-    index = build_faiss_index(sentence_embeddings)
-
+    # Param of PQ
+    M = 16  # The number of sub-vectors. Typically this is 8, 16, 32, etc.
+    nbits = 8  # Bits per sub-vector. This is typically 8, so that each sub-vector is encoded by 1 byte.
+    # Param of IVF
+    nlist = 1000  # The number of cells (space partitions). A typical value is sqrt(N).
+    # Param of HNSW
+    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32.
+
+    # Setup
+    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
+    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
+
+    # Train
+    index.train(Xt)
+
+    # Add
+    index.add(X)
+
+    # Search
+    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.
 
     main(model, tokenizer, sentence_df, index)
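The other substantive change here: the embeddings are now L2-normalized before indexing, and the index is built at the top level of __main__, so under Streamlit's rerun-the-whole-script model it is reconstructed on every interaction rather than cached as the old build_faiss_index was. The normalization is what lets an L2-based index rank by cosine similarity: for unit vectors, ||u - v||^2 = 2 - 2 * u.v, so nearest-by-L2 equals nearest-by-cosine (queries must be normalized the same way). A quick standalone check of that identity, using synthetic vectors and nothing app-specific:

import faiss
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 8)).astype("float32")
faiss.normalize_L2(a)  # in-place, row-wise unit normalization

u, v = a
# For unit vectors: ||u - v||^2 == 2 - 2 * cos(u, v)
print(np.sum((u - v) ** 2), 2 - 2 * np.dot(u, v))  # the two values match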
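The diff cuts get_retrieval_results off at the tokenizer call, but the search side has to mirror the indexing side: encode the query with the same model, L2-normalize it, then probe the index and map the returned ids back to rows of sentence_df. A hypothetical sketch of that flow; the mean pooling and the helper names below are assumptions for illustration, not the app's actual code:

import faiss
import torch

def embed_query(text, model, tokenizer):
    # Hypothetical encoder: the pooling must match however the corpus-side
    # sentence_embeddings were produced (mean pooling is an assumption here).
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        emb = model(**inputs).last_hidden_state.mean(dim=1)
    emb = emb.numpy().astype("float32")
    faiss.normalize_L2(emb)  # same normalization as applied at index time
    return emb

def retrieve(index, sentence_df, text, model, tokenizer, top_k=5):
    query = embed_query(text, model, tokenizer)
    dists, ids = index.search(query, top_k)  # both arrays have shape (1, top_k)
    return sentence_df.iloc[ids[0]]          # hit rows, best match first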