ArthurChen189 committed on
Commit
30ac9ed
1 Parent(s): 428e669

project init

Browse files
Files changed (4) hide show
  1. app.py +65 -0
  2. logo.jpeg +0 -0
  3. packages.txt +1 -0
  4. requirements.txt +16 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+ from pyserini.search.lucene import LuceneImpactSearcher
4
+ import streamlit as st
5
+ from pathlib import Path
6
+ import sys
7
# Streamlit demo: pick an ONNX query encoder, type a query, and list the top
# hits (rank, document id, score) returned by a Pyserini impact searcher.

# Make the working directory importable. Streamlit re-executes this script on
# every interaction, so only append the entry once instead of growing sys.path.
# NOTE(review): this runs after the pyserini import at the top of the file, so
# it cannot influence that import -- confirm whether it is still required.
path_root = Path("./")
if str(path_root) not in sys.path:
    sys.path.append(str(path_root))

# Number of hits requested from the searcher (and therefore rendered).
TOP_K = 10

# UI label -> (query-encoder name passed to the searcher, index directory
# name under ``indexes/``).
encoder_index_map = {
    'uniCOIL': ('UniCoil', 'index-unicoil'),
    'SPLADE++ Ensemble Distil': ('SpladePlusPlusEnsembleDistil', 'index-splade-pp-ed'),
    'SPLADE++ Self Distil': ('SpladePlusPlusSelfDistil', 'index-splade-pp-sd'),
}

st.set_page_config(page_title="Pyserini with ONNX Runtime",
                   page_icon='🌸', layout="centered")

# Three columns with the logo in the middle one; the outer two stay empty.
cola, colb, colc = st.columns([5, 4, 5])
with colb:
    st.image("logo.jpeg")

colaa, colbb, colcc = st.columns([1, 8, 1])
with colbb:
    # The slider options are the map's keys (dicts keep insertion order), so
    # the widget and the lookup below can never get out of sync.
    encoder_label = st.select_slider(
        'Select a query encoder with ONNX Runtime',
        options=list(encoder_index_map))
    st.write('Now Running Encoder: ', encoder_label)

encoder, index = encoder_index_map[encoder_label]

col1, col2 = st.columns([9, 1])
with col1:
    search_query = st.text_input(label="search query", placeholder="Search")

with col2:
    # The '#' write sits above the button; presumably a vertical spacer so the
    # button lines up with the text input -- verify visually.
    st.write('#')
    button_clicked = st.button("🔎")

# NOTE(review): the searcher is rebuilt on every script rerun; consider caching
# it per (index, encoder) if the deployed Streamlit version offers a resource
# cache.
searcher = LuceneImpactSearcher(
    f'indexes/{index}', encoder, encoder_type='onnx')

if search_query.strip():
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    t_0 = time.perf_counter()
    print("search query is:\t", search_query)
    search_results = searcher.search(search_query, k=TOP_K)
    search_time = time.perf_counter() - t_0
    st.write(
        f'<p align="right" style="color:grey;">Retrieved {len(search_results):,.0f} documents in {search_time*1000:.2f} ms</p>',
        unsafe_allow_html=True)
    for rank, result in enumerate(search_results, start=1):
        output = (
            f'<div class="row"> <div class="column"> <b>Rank</b>: {rank} </div>'
            f'<div class="column"><b>Document ID</b>: {result.docid}</div>'
            f'<div class="column"><b>Score</b>:{result.score:.2f}</div></div>'
        )

        try:
            st.write(output, unsafe_allow_html=True)
        except Exception as exc:
            # Rendering stays best-effort (one bad row must not hide the
            # others), but the failure is logged instead of silently dropped.
            print("failed to render result:\t", result.docid, exc)
        st.write('---')
elif button_clicked:
    # The button was pressed with an empty query: ask for input instead of
    # sending an empty string to the encoder.
    st.warning('Please enter a search query.')
logo.jpeg ADDED
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ openjdk-11-jdk
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ faiss-cpu
2
+ torch
3
+ Cython>=0.29.21
4
+ numpy>=1.18.1
5
+ pandas>=1.4.0
6
+ pyjnius>=1.4.0
7
+ scikit-learn>=0.22.1
8
+ scipy>=1.4.1
9
+ tqdm
10
+ transformers>=4.6.0
11
+ sentencepiece>=0.1.95
12
+ nmslib>=2.1.1
13
+ onnxruntime>=1.8.1
14
+ lightgbm>=3.3.2
15
+ spacy>=3.2.1
16
+ pyyaml