"""Gradio demo for the Commonsense Relation-Aware BART (BART-RA) models."""

import gradio as gr
import matplotlib.pyplot as plt

from inference import RelationsInference
from attention_viz import AttentionVisualizer
from utils import KGType, Model_Type, Data_Type

import nltk

nltk.download('popular')

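# Load the two BART-RA checkpoints: one fine-tuned for CommonGen and one for
# abstractive question answering (ELI5). The model paths look like Hugging Face
# Hub repo ids, so the weights are presumably downloaded on first use.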
commongen_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32
)

qa_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_absqa',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=128
)
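
# Visualiser for the attention scores (kept on CPU).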
att_viz = AttentionVisualizer(device='cpu')

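
# Generation callback: routes the request to the model matching the selected
# task, applies the chosen decoding strategy, and returns the first generated
# sequence.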
def infer_bart(context, task_type, decoding_type_str):
    if Data_Type(task_type) == Data_Type.COMMONGEN:
        if decoding_type_str == 'default':
            response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
        else:
            response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True, max_concepts=2)
    elif Data_Type(task_type) == Data_Type.ELI5:
        response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
    else:
        raise NotImplementedError()
    return response[0]

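
# Attention callback: prepares the input with the selected model and plots the
# attention scores for the chosen layer and head.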
def plot_attention(context, task_type, layer, head):
    if Data_Type(task_type) == Data_Type.COMMONGEN:
        model = commongen_bart
    elif Data_Type(task_type) == Data_Type.ELI5:
        model = qa_bart
    else:
        raise NotImplementedError()
    response, examples, relations = model.prepare_context_for_visualization(context)
    fig = att_viz.plot_attn_lines_concepts_ids('Input text importance visualized',
                                               examples,
                                               layer, head,
                                               relations)
    return fig

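
# Gradio UI: input box, task and decoding selectors, generation output, and an
# attention plot controlled by layer/head sliders.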
app = gr.Blocks()
with app:
    gr.Markdown(
        """
        # Demo
        ### Test the Commonsense Relation-Aware BART (BART-RA) model

        Tutorial: <br>
        1) Select the task and model variation;<br>
        2) Change the inputs and click the buttons to produce results;<br>
        3) See attention visualisations by choosing a specific layer and head.
        """)
    with gr.Row():
        context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
        model_result_output = gr.Textbox(lines=2, label='Model result:')
        with gr.Column():
            task_type_choice = gr.Radio(
                ["eli5", "commongen"], value="eli5", label="What task do you want to try?"
            )
            decoding_type_choice = gr.Radio(
                ["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
            )
    with gr.Row():
        model_btn = gr.Button(value="See Model Results")
    gr.Markdown(
        """
        ---
        Observe Attention
        """
    )
    with gr.Row():
        with gr.Column():
            layer = gr.Slider(0, 11, 0, step=1, label="Layer")
            head = gr.Slider(0, 15, 0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()
    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")

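    # Wire the buttons to the generation and attention-visualisation callbacks.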
    model_btn.click(fn=infer_bart, inputs=[context_input, task_type_choice, decoding_type_choice],
                    outputs=[model_result_output])
    vis_btn.click(fn=plot_attention, inputs=[context_input, task_type_choice, layer, head], outputs=[plot_output])

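
# Launch the demo web server when the script is run directly.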
if __name__ == '__main__':
    app.launch()