kaisugi committed
Commit b6363d9 · 1 Parent(s): 05f1914
Files changed (1):
  1. app.py +30 -36
app.py CHANGED
@@ -1,9 +1,9 @@
+from transformers import AutoModel, AutoTokenizer
 import faiss
 import numpy as np
 import pandas as pd
 import streamlit as st
 import torch
-from transformers import AutoModel, AutoTokenizer
 
 import os
 
@@ -34,38 +34,7 @@ def load_sentence_embeddings():
     return sentence_embeddings
 
 
-@st.cache(allow_output_mutation=True)
-def build_faiss_index(sentence_emeddings):
-    D = 768
-    N = 789188
-    Xt = sentence_emeddings[:39000]
-    X = sentence_emeddings
-
-    # Param of PQ
-    M = 16  # The number of sub-vector. Typically this is 8, 16, 32, etc.
-    nbits = 8  # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
-    # Param of IVF
-    nlist = 1000  # The number of cells (space partition). Typical value is sqrt(N)
-    # Param of HNSW
-    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32
-
-    # Setup
-    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
-    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
-
-    # Train
-    index.train(Xt)
-
-    # Add
-    index.add(X)
-
-    # Search
-    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.
-
-    return index
-
-
-@st.cache(allow_output_mutation=True)
+@st.cache
 def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df):
     with torch.no_grad():
         inputs = tokenizer.encode_plus(
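Editor's note on the hunk above: the commit deletes the cached build_faiss_index helper (its body moves into __main__, see the last hunk) and drops allow_output_mutation=True from the decorator on get_retrieval_results. Under legacy st.cache, Streamlit content-hashes a cached function's arguments on every rerun, which is costly for a roughly 789,188 x 768 float32 embeddings matrix, and the SWIG-wrapped FAISS index is not reliably hashable at all. A minimal sketch of one alternative that keeps the build cached, assuming legacy Streamlit's hash_funcs parameter; this is an illustration, not what the commit does:

import faiss
import numpy as np
import streamlit as st

# Hypothetical alternative to inlining the build in __main__: cache it, but
# key the large embeddings array by object identity (hash_funcs={np.ndarray: id})
# so Streamlit skips content-hashing it, and set allow_output_mutation=True so
# the returned SWIG-wrapped index is not hashed on every rerun.
@st.cache(allow_output_mutation=True, hash_funcs={np.ndarray: id})
def build_faiss_index(sentence_embeddings):
    d = sentence_embeddings.shape[1]          # 768 for this model
    quantizer = faiss.IndexHNSWFlat(d, 32)    # HNSW graph as the coarse quantizer
    index = faiss.IndexIVFPQ(quantizer, d, 1000, 16, 8)
    index.train(sentence_embeddings[:39000])  # train IVF/PQ on a subsample
    index.add(sentence_embeddings)
    index.nprobe = 8                          # cells visited per query
    return index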
@@ -102,9 +71,34 @@ def main(model, tokenizer, sentence_df, index):
 if __name__ == "__main__":
     model, tokenizer = load_model_and_tokenizer()
     sentence_df = load_sentence_data()
-    sentence_emeddings = load_sentence_embeddings()
+    sentence_embeddings = load_sentence_embeddings()
+
+    faiss.normalize_L2(sentence_embeddings)
+
+    D = 768
+    N = 789188
+    Xt = sentence_embeddings[:39000]
+    X = sentence_embeddings
 
-    faiss.normalize_L2(sentence_emeddings)
-    index = build_faiss_index(sentence_emeddings)
+    # Param of PQ
+    M = 16  # The number of sub-vector. Typically this is 8, 16, 32, etc.
+    nbits = 8  # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
+    # Param of IVF
+    nlist = 1000  # The number of cells (space partition). Typical value is sqrt(N)
+    # Param of HNSW
+    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32
+
+    # Setup
+    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
+    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
+
+    # Train
+    index.train(Xt)
+
+    # Add
+    index.add(X)
+
+    # Search
+    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.
 
     main(model, tokenizer, sentence_df, index)
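For reference, the IVF-PQ-over-HNSW structure assembled in the last hunk can also be expressed with FAISS's index_factory string (not used by the commit). The preceding faiss.normalize_L2 matters because IndexIVFPQ here ranks by squared L2 distance, and on unit-norm vectors ||a - b||^2 = 2 - 2 cos(a, b), so L2 ranking coincides with cosine-similarity ranking. A sketch with random stand-in data:

import faiss
import numpy as np

d = 768
xb = np.random.rand(10000, d).astype("float32")  # stand-in for the real embeddings
faiss.normalize_L2(xb)  # in-place; unit-norm rows make L2 ranking match cosine ranking

# "IVF1000_HNSW32,PQ16" is the factory spelling of
# IndexIVFPQ(IndexHNSWFlat(d, 32), d, nlist=1000, M=16, nbits=8).
index = faiss.index_factory(d, "IVF1000_HNSW32,PQ16")
index.train(xb)  # the app trains on a 39,000-vector subsample instead
index.add(xb)
index.nprobe = 8

With PQ at M=16 and nbits=8, each vector is stored as 16 one-byte codes, so the 789,188 corpus vectors take roughly 12 MB of codes instead of about 2.3 GB of raw float32, at the cost of approximate distances.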
 
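The diff truncates get_retrieval_results at the tokenizer call. A hedged sketch of the query path such a function typically implements, assuming mean pooling over the last hidden state (app.py's actual pooling is not shown in this diff) and that sentence_df rows align with the order vectors were added to the index:

import faiss
import torch

def embed_query(input_text, model, tokenizer):
    # Assumed pooling; the real app.py may pool differently (e.g. the CLS token).
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text, truncation=True, max_length=512, return_tensors="pt"
        )
        outputs = model(**inputs)
        emb = outputs.last_hidden_state.mean(dim=1).cpu().numpy().astype("float32")
    faiss.normalize_L2(emb)  # queries must be normalized the same way as the corpus
    return emb

query_vec = embed_query("some input text", model, tokenizer)
distances, ids = index.search(query_vec, top_k)  # both arrays have shape (1, top_k)
results = sentence_df.iloc[ids[0]]               # assumes row order matches add() order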