# doc2query / app.py
# Sean MacAvaney
# update
# d40a755
# raw
# history blame
# 1.32 kB
import pandas as pd
import gradio as gr
from pyterrier_doc2query import Doc2Query
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_D
# The single doc2query checkpoint this demo serves (the UI dropdown offers only this one).
MODEL = 'macavaney/doc2query-t5-base-msmarco'

# One shared transformer instance. predict() overwrites its ``append`` and
# ``num_samples`` attributes on each call, and the UI controls read them for
# their initial values.
doc2query = Doc2Query(
    MODEL,
    append=True,
    num_samples=5,
)

# Notebook file name and setup cells handed to code2md together with the
# generated snippet.
COLAB_NAME = 'pyterrier_doc2query.ipynb'
COLAB_INSTALL = '\n'.join([
    '!pip install -q git+https://github.com/terrier-org/pyterrier',
    '!pip install -q git+https://github.com/terrierteam/pyterrier_doc2query',
])
def predict(input, model, append, num_samples):
    """Expand ``input`` with doc2query and build an equivalent standalone code snippet.

    Invoked by the Gradio ``Demo`` with the demo's input data followed by the
    values of the three controls (model dropdown, "Append" checkbox,
    "# Queries" slider).

    Args:
        input: the documents to expand, passed straight to ``doc2query`` and to
            ``df2code`` (presumably a pandas DataFrame shaped like ``EX_D`` --
            TODO confirm the columns Doc2Query expects).
        model: name of the selected model; only ``MODEL`` is loaded.
        append: value assigned to ``doc2query.append`` for this call.
        num_samples: value assigned to ``doc2query.num_samples`` for this call.

    Returns:
        A 2-tuple ``(expanded, markdown)``: the result of ``doc2query(input)``
        and the markdown produced by ``code2md`` for the reproducible snippet.

    Raises:
        ValueError: if ``model`` is not ``MODEL``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently run MODEL while the generated
    # snippet advertises a different model.
    if model != MODEL:
        raise ValueError(f'unsupported model {model!r}; only {MODEL!r} is loaded')

    # Slider values can arrive as floats (e.g. 5.0). Normalise them so the
    # model gets an integer sample count and the snippet reads
    # ``num_samples=5`` rather than ``num_samples=5.0``.
    num_samples = int(num_samples)
    append = bool(append)

    # NOTE(review): this reconfigures the single shared transformer. If the app
    # serves requests concurrently, two calls could interleave these settings.
    # Reloading the model per request is too costly, so confirm the demo runs
    # with a concurrency limit of 1.
    doc2query.append = append
    doc2query.num_samples = num_samples

    code = f'''import pandas as pd
from pyterrier_doc2query import Doc2Query
doc2query = Doc2Query({repr(model)}, append={append}, num_samples={num_samples})
doc2query({df2code(input)})
'''
    return (doc2query(input), code2md(code, COLAB_INSTALL, COLAB_NAME))
def _build_app():
    """Assemble the page: README intro, the interactive doc2query demo, then wrap-up notes.

    The three controls below supply predict()'s ``model``, ``append`` and
    ``num_samples`` arguments, in that order. The documents argument is
    presumably provided by ``Demo`` from ``EX_D`` -- confirm in
    pyterrier_gradio.
    """
    intro = MarkdownFile('README.md')

    controls = [
        # Fixed to the one loaded checkpoint; shown for transparency only.
        gr.Dropdown(
            choices=[MODEL],
            value=MODEL,
            label='Model',
            interactive=False,
        ),
        # Initial value mirrors the shared transformer's current setting.
        gr.Checkbox(
            value=doc2query.append,
            label="Append",
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=doc2query.num_samples,
            step=1.0,
            label='# Queries',
        ),
    ]
    playground = Demo(predict, EX_D, controls)

    outro = MarkdownFile('wrapup.md')
    return interface(intro, playground, outro)


_build_app().launch(share=False)