Spaces:
Runtime error
Runtime error
File size: 2,021 Bytes
68be317 cba50c7 871af30 cba50c7 adbdb15 cba50c7 506b8cf cba50c7 871af30 cba50c7 68be317 cba50c7 0f58367 cba50c7 0f58367 871af30 cba50c7 adbdb15 cba50c7 0f58367 cba50c7 adbdb15 506b8cf eb688f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import gradio as gr
from pyterrier_doc2query import Doc2Query
# Load the doc2query transformer once at import time and reuse it for every
# request; predict() overwrites its ``append`` / ``num_samples`` attributes on
# each call. The checkpoint is named explicitly so the loaded model is the same
# one the UI dropdown offers, predict() accepts, and the generated code snippet
# shows, instead of depending on the library's default.
doc2query = Doc2Query('macavaney/doc2query-t5-base-msmarco')
def df2code(df):
    """Render a DataFrame as Python source that rebuilds it with ``pd.DataFrame``.

    Each row becomes one dict literal (column name -> value) on its own line,
    so the resulting snippet can be pasted into Python to reproduce ``df``.

    Args:
        df: the DataFrame to render (in this app, the pipeline input table).

    Returns:
        A string ``pd.DataFrame([ ... ])`` with one `` {row},`` line per row.
    """
    # to_dict('records') keeps the real column names; itertuples() would rename
    # columns that are not valid Python identifiers (e.g. 'doc id' -> '_0') and
    # so emit code that rebuilds a DataFrame with the wrong columns.
    rows = '\n'.join(f' {record},' for record in df.to_dict(orient='records'))
    return f'''pd.DataFrame([
{rows}
])'''
def predict(input, model, append, num_samples):
    """Run doc2query on the input table and build an equivalent code snippet.

    Args:
        input: the pipeline input DataFrame from the UI (``docno``/``text``).
        model: checkpoint name from the dropdown; only
            ``'macavaney/doc2query-t5-base-msmarco'`` is supported.
        append: value assigned to ``doc2query.append`` for this request.
        num_samples: number of queries to generate; assigned to
            ``doc2query.num_samples`` after conversion to ``int``.

    Returns:
        A tuple ``(output, code)``: the result of ``doc2query(input)`` and a
        Markdown string with PyTerrier code reproducing the same call.

    Raises:
        ValueError: if ``model`` is not the supported checkpoint.
    """
    if model != 'macavaney/doc2query-t5-base-msmarco':
        # An explicit raise, unlike ``assert``, still fires under ``python -O``.
        raise ValueError(f'unsupported model: {model!r}')
    # The UI slider (step=1.) may deliver a float such as 3.0. A sample count is
    # an integer (and presumably must be one for the generation call; confirm),
    # and converting also makes the snippet read ``num_samples=3``.
    num_samples = int(num_samples)
    # NOTE: this mutates the shared module-level transformer, so these settings
    # persist until the next request overwrites them.
    doc2query.append = append
    doc2query.num_samples = num_samples
    code = f'''
**Code:**
```python
import pandas as pd
from pyterrier_doc2query import Doc2Query
doc2query = Doc2Query({repr(model)}, append={append}, num_samples={num_samples})
doc2query({df2code(input)})
```
'''
    return (doc2query(input), code)
# The description is the README text after the last '---' line (presumably
# stripping the Space's YAML front matter). Read it with a context manager so
# the file handle is closed, and with an explicit encoding so the result does
# not depend on the host locale.
with open('README.md', 'rt', encoding='utf-8') as readme:
    description = readme.read().split('---\n')[-1]

# Bind the interface to a name before launching so it can be referenced
# (e.g. by tooling that looks up the app object).
demo = gr.Interface(
    predict,
    inputs=[gr.Dataframe(
        headers=["docno", "text"],
        datatype=["str", "str"],
        col_count=(2, "fixed"),
        row_count=1,
        wrap=True,
        label='Pipeline Input',
        value=[['0', 'The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.']],
    ), gr.Dropdown(
        choices=['macavaney/doc2query-t5-base-msmarco'],
        value='macavaney/doc2query-t5-base-msmarco',
        label='Model',
        interactive=False,
    ), gr.Checkbox(
        value=False,
        label="Append",
    ), gr.Slider(
        minimum=1,
        maximum=10,
        value=3,
        step=1.,
        label='# Queries'
    )],
    outputs=[gr.Dataframe(
        headers=["docno", "text", "querygen"],
        datatype=["str", "str", "str"],
        col_count=(3, "fixed"),
        row_count=1,
        wrap=True,
        label='Pipeline Output',
        value=[["[docno]", "[text]", "[querygen]"]],
    ), gr.Markdown()],
    title="PyTerrier: Doc2Query",
    description=description,
    allow_flagging='never',
)
demo.launch()
|