Spaces:
Running
Running
updates
Browse files
app.py
CHANGED
@@ -7,13 +7,13 @@ import pyterrier as pt
|
|
7 |
pt.init()
|
8 |
import pyt_splade
|
9 |
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D
|
10 |
-
factory_max = pyt_splade.
|
11 |
-
factory_sum = pyt_splade.
|
12 |
|
13 |
COLAB_NAME = 'pyterrier_splade.ipynb'
|
14 |
COLAB_INSTALL = '''
|
15 |
!pip install -q git+https://github.com/naver/splade
|
16 |
-
!pip install -q git+https://github.com/
|
17 |
'''.strip()
|
18 |
|
19 |
def generate_vis(df, mode='Document'):
|
@@ -24,15 +24,9 @@ def generate_vis(df, mode='Document'):
|
|
24 |
max_score = max(max(t.values()) for t in df['toks'])
|
25 |
for row in df.itertuples(index=False):
|
26 |
if mode == 'Query':
|
27 |
-
tok_scores =
|
28 |
-
|
29 |
-
if key.startswith('#base64('):
|
30 |
-
b64 = re.search('#base64\(([^)]+)\)', key).group(1)
|
31 |
-
del tok_scores[key]
|
32 |
-
key = base64.b64decode(b64).decode()
|
33 |
-
tok_scores[key] = value
|
34 |
max_score = max(tok_scores.values())
|
35 |
-
orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
|
36 |
id = row.qid
|
37 |
else:
|
38 |
tok_scores = row.toks
|
@@ -55,38 +49,36 @@ def generate_vis(df, mode='Document'):
|
|
55 |
|
56 |
def predict_query(input, agg):
|
57 |
code = f'''import pandas as pd
|
58 |
-
import pyterrier as pt ; pt.init()
|
59 |
import pyt_splade
|
60 |
|
61 |
-
splade = pyt_splade.
|
62 |
|
63 |
-
query_pipeline = splade.
|
64 |
|
65 |
query_pipeline({df2code(input)})
|
66 |
'''
|
67 |
pipeline = {
|
68 |
'max': factory_max,
|
69 |
'sum': factory_sum
|
70 |
-
}[agg].
|
71 |
res = pipeline(input)
|
72 |
vis = generate_vis(res, mode='Query')
|
73 |
return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)
|
74 |
|
75 |
def predict_doc(input, agg):
|
76 |
code = f'''import pandas as pd
|
77 |
-
import pyterrier as pt ; pt.init()
|
78 |
import pyt_splade
|
79 |
|
80 |
-
splade = pyt_splade.
|
81 |
|
82 |
-
doc_pipeline = splade.
|
83 |
|
84 |
doc_pipeline({df2code(input)})
|
85 |
'''
|
86 |
pipeline = {
|
87 |
'max': factory_max,
|
88 |
'sum': factory_sum
|
89 |
-
}[agg].
|
90 |
res = pipeline(input)
|
91 |
vis = generate_vis(res, mode='Document')
|
92 |
res['toks'] = [json.dumps({k: round(v, 4) for k, v in t.items()}) for t in res['toks']]
|
|
|
7 |
pt.init()
|
8 |
import pyt_splade
|
9 |
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D
|
10 |
+
factory_max = pyt_splade.Splade(agg='max')
|
11 |
+
factory_sum = pyt_splade.Splade(agg='sum')
|
12 |
|
13 |
COLAB_NAME = 'pyterrier_splade.ipynb'
|
14 |
COLAB_INSTALL = '''
|
15 |
!pip install -q git+https://github.com/naver/splade
|
16 |
+
!pip install -q git+https://github.com/cmacdonald/pyt_splade
|
17 |
'''.strip()
|
18 |
|
19 |
def generate_vis(df, mode='Document'):
|
|
|
24 |
max_score = max(max(t.values()) for t in df['toks'])
|
25 |
for row in df.itertuples(index=False):
|
26 |
if mode == 'Query':
|
27 |
+
tok_scores = row.query_toks
|
28 |
+
orig_tokens = factory_max.tokenizer.tokenize(row.text)
|
|
|
|
|
|
|
|
|
|
|
29 |
max_score = max(tok_scores.values())
|
|
|
30 |
id = row.qid
|
31 |
else:
|
32 |
tok_scores = row.toks
|
|
|
49 |
|
50 |
def predict_query(input, agg):
|
51 |
code = f'''import pandas as pd
|
|
|
52 |
import pyt_splade
|
53 |
|
54 |
+
splade = pyt_splade.Splade(agg={agg!r})
|
55 |
|
56 |
+
query_pipeline = splade.query_encoder()
|
57 |
|
58 |
query_pipeline({df2code(input)})
|
59 |
'''
|
60 |
pipeline = {
|
61 |
'max': factory_max,
|
62 |
'sum': factory_sum
|
63 |
+
}[agg].query_encoder()
|
64 |
res = pipeline(input)
|
65 |
vis = generate_vis(res, mode='Query')
|
66 |
return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)
|
67 |
|
68 |
def predict_doc(input, agg):
|
69 |
code = f'''import pandas as pd
|
|
|
70 |
import pyt_splade
|
71 |
|
72 |
+
splade = pyt_splade.Splade(agg={repr(agg)})
|
73 |
|
74 |
+
doc_pipeline = splade.doc_encoder()
|
75 |
|
76 |
doc_pipeline({df2code(input)})
|
77 |
'''
|
78 |
pipeline = {
|
79 |
'max': factory_max,
|
80 |
'sum': factory_sum
|
81 |
+
}[agg].doc_encoder()
|
82 |
res = pipeline(input)
|
83 |
vis = generate_vis(res, mode='Document')
|
84 |
res['toks'] = [json.dumps({k: round(v, 4) for k, v in t.items()}) for t in res['toks']]
|