HarryLee committed on
Commit
dfed715
·
0 Parent(s):

Duplicate from HarryLee/QueryExpansion

Browse files
Files changed (6) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +100 -0
  4. etsy-embeddings-cpu.pkl +3 -0
  5. requirements.txt +7 -0
  6. top.png +0 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ 000000000001.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: QueryExpansion
3
+ emoji: 👁
4
+ colorFrom: pink
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.17.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: HarryLee/QueryExpansion
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_tags import st_tags, st_tags_sidebar
3
+ from keytotext import pipeline
4
+ from PIL import Image
5
+
6
+ import json
7
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
8
+ import gzip
9
+ import os
10
+ import torch
11
+ import pickle
12
+
13
+ ############
14
+ ## Main page
15
+ ############
16
+
17
+ st.write("# Code for Query Expansion")
18
+
19
+ st.markdown("***Idea is to build a model which will take query as inputs and generate expansion information as outputs.***")
20
+ image = Image.open('top.png')
21
+ st.image(image)
22
+
23
+ st.sidebar.write("# Parameter Selection")
24
+ maxtags_sidebar = st.sidebar.slider('Number of query allowed?', 1, 10, 1, key='ehikwegrjifbwreuk')
25
+ user_query = st_tags(
26
+ label='# Enter Query:',
27
+ text='Press enter to add more',
28
+ value=['Mother'],
29
+ suggestions=['five', 'six', 'seven', 'eight', 'nine', 'three', 'eleven', 'ten', 'four'],
30
+ maxtags=maxtags_sidebar,
31
+ key="aljnf")
32
+
33
+ # Add selectbox in streamlit
34
+ option1 = st.sidebar.selectbox(
35
+ 'Which transformers model would you like to be selected?',
36
+ ('multi-qa-MiniLM-L6-cos-v1','null','null'))
37
+
38
+ option2 = st.sidebar.selectbox(
39
+ 'Which corss-encoder model would you like to be selected?',
40
+ ('cross-encoder/ms-marco-MiniLM-L-6-v2','null','null'))
41
+
42
+ st.sidebar.success("Load Successfully!")
43
+
44
+ #if not torch.cuda.is_available():
45
+ # print("Warning: No GPU found. Please add GPU to your notebook")
46
+
47
+ #We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
48
+ bi_encoder = SentenceTransformer(option1)
49
+ bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens
50
+ top_k = 32 #Number of passages we want to retrieve with the bi-encoder
51
+
52
+ #The bi-encoder will retrieve 100 documents. We use a cross-encoder, to re-rank the results list to improve the quality
53
+ cross_encoder = CrossEncoder(option2)
54
+
55
+ # load pre-train embeedings files
56
+ embedding_cache_path = 'etsy-embeddings-cpu.pkl'
57
+ print("Load pre-computed embeddings from disc")
58
+ with open(embedding_cache_path, "rb") as fIn:
59
+ cache_data = pickle.load(fIn)
60
+ #corpus_sentences = cache_data['sentences']
61
+ corpus_embeddings = cache_data['embeddings']
62
+
63
+ # This function will search all wikipedia articles for passages that
64
+ # answer the query
65
+ def search(query):
66
+ print("Input question:", query)
67
+ ##### Sematic Search #####
68
+ # Encode the query using the bi-encoder and find potentially relevant passages
69
+ query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
70
+ #query_embedding = query_embedding.cuda()
71
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
72
+ hits = hits[0] # Get the hits for the first query
73
+
74
+ ##### Re-Ranking #####
75
+ # Now, score all retrieved passages with the cross_encoder
76
+ cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
77
+ cross_scores = cross_encoder.predict(cross_inp)
78
+
79
+ # Sort results by the cross-encoder scores
80
+ for idx in range(len(cross_scores)):
81
+ hits[idx]['cross-score'] = cross_scores[idx]
82
+
83
+ # Output of top-10 hits from bi-encoder
84
+ print("\n-------------------------\n")
85
+ print("Top-10 Bi-Encoder Retrieval hits")
86
+ hits = sorted(hits, key=lambda x: x['score'], reverse=True)
87
+ for hit in hits[0:10]:
88
+ print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
89
+
90
+ # Output of top-10 hits from re-ranker
91
+ print("\n-------------------------\n")
92
+ print("Top-10 Cross-Encoder Re-ranker hits")
93
+ hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
94
+ for hit in hits[0:10]:
95
+ print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
96
+
97
+ st.write("## Results:")
98
+ if st.button('Generate Sentence'):
99
+ out = search(query = user_query)
100
+ st.success(out)
etsy-embeddings-cpu.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a8eb36f4ec40a7d1cb382376afc38cac7caed6104bbaf5a8b28f8a98ba18cb5
3
+ size 456491627
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit==0.82.0
2
+ streamlit_tags
3
+ pyarrow
4
+ keytotext
5
+ opencv-python-headless
6
+ sentence-transformers
7
+ rank_bm25
top.png ADDED