File size: 3,904 Bytes
6cf191b
 
 
 
4ae80b2
 
6cf191b
f633022
 
 
 
6cf191b
 
 
 
f03f9e7
 
 
6cf191b
4ae80b2
6cf191b
 
 
 
 
 
4ae80b2
 
 
 
 
 
 
6cf191b
 
 
 
 
4ae80b2
 
 
 
bcbb55b
4ae80b2
 
 
 
6cf191b
 
 
4ae80b2
 
 
 
 
 
 
 
c56dde4
4ae80b2
 
 
6cf191b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f03f9e7
6cf191b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import matplotlib.pyplot as plt

from inference import RelationsInference
from attention_viz import AttentionVisualizer
from utils import KGType, Model_Type, Data_Type

#############################
#   Setup
#############################

# Fetch NLTK's "popular" data collection when the module is imported.
# nltk is not referenced elsewhere in this file; presumably the project's
# inference/KG utilities rely on these resources — TODO confirm.
import nltk
nltk.download('popular')

#############################
#   Constants
#############################

# Knowledge-graph and model-variant settings shared by both checkpoints below.
_SHARED_MODEL_KWARGS = {
    'kg_type': KGType.CONCEPTNET,
    'model_type': Model_Type.RELATIONS,
}

# Model used for the "commongen" task (see infer_bart / plot_attention).
# NOTE(review): max_length presumably caps the generated sequence length — confirm.
commongen_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    max_length=32,
    **_SHARED_MODEL_KWARGS,
)

# Model used for the "eli5" task; allows longer outputs than the commongen one.
qa_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_absqa',
    max_length=128,
    **_SHARED_MODEL_KWARGS,
)

# Shared attention plotter, configured to run on CPU.
att_viz = AttentionVisualizer(device='cpu')
#############################
#   Helper
#############################

def infer_bart(context, task_type, decoding_type_str):
    """Generate a response for *context* with the model that matches *task_type*.

    Args:
        context: Input text taken from the UI textbox.
        task_type: Task identifier; must be a valid ``Data_Type`` value
            (the UI offers "eli5" and "commongen").
        decoding_type_str: Decoding strategy from the UI. Only the commongen
            model uses it: "default" selects plain generation, and any other
            value selects constrained generation. For eli5 it is ignored.

    Returns:
        The first element of the model's generated responses.

    Raises:
        Whatever ``Data_Type(task_type)`` raises for an unknown value
        (ValueError if ``Data_Type`` is a standard Enum — confirm).
        NotImplementedError: if the task is a valid ``Data_Type`` that has no
            model wired up here.
    """
    # Convert once rather than once per branch.
    task = Data_Type(task_type)
    if task == Data_Type.COMMONGEN:
        if decoding_type_str == 'default':
            response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
        else:
            # Constrained path: uses the KG, and the (sic) method name comes from
            # RelationsInference. Note that it receives a list of contexts.
            response, _, _ = commongen_bart.generate_contrained_based_on_context(
                [context], use_kg=True, max_concepts=2
            )
    elif task == Data_Type.ELI5:
        response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
    else:
        raise NotImplementedError(f"No model is configured for task type {task_type!r}")
    return response[0]


def plot_attention(context, task_type, layer, head):
    """Plot the attention of one layer/head for *context* under the task's model.

    Args:
        context: Input text taken from the UI textbox.
        task_type: Task identifier; must be a valid ``Data_Type`` value
            (the UI offers "eli5" and "commongen").
        layer: Attention layer index chosen with the "Layer" slider.
        head: Attention head index chosen with the "Head" slider.

    Returns:
        Whatever ``att_viz.plot_attn_lines_concepts_ids`` returns; it is fed
        to a ``gr.Plot`` output.

    Raises:
        Whatever ``Data_Type(task_type)`` raises for an unknown value
        (ValueError if ``Data_Type`` is a standard Enum — confirm).
        NotImplementedError: if the task is a valid ``Data_Type`` that has no
            model wired up here.
    """
    # Convert once rather than once per branch.
    task = Data_Type(task_type)
    if task == Data_Type.COMMONGEN:
        model = commongen_bart
    elif task == Data_Type.ELI5:
        model = qa_bart
    else:
        raise NotImplementedError(f"No model is configured for task type {task_type!r}")
    # The first value returned is not needed for the plot.
    _, examples, relations = model.prepare_context_for_visualization(context)
    # gr.Slider may pass numeric values as floats. The sliders use step=1, so
    # cast to int before using them as layer/head selectors.
    # NOTE(review): if the visualizer builds figures through pyplot, nothing
    # here closes them; confirm they do not pile up in the long-running server.
    return att_viz.plot_attn_lines_concepts_ids(
        'Input text importance visualized',
        examples,
        int(layer), int(head),
        relations,
    )


#############################
#   Interface
#############################

# Gradio Blocks UI. The layout is defined by the order and nesting of the
# `with gr.Row()` / `with gr.Column()` contexts below.
app = gr.Blocks()
with app:
    # Header and short usage tutorial.
    gr.Markdown(
        """
        # Demo
        ### Test Commonsense Relation-Aware BART (BART-RA) model
    
        Tutorial: <br>
            1) Select the possible model variations and tasks;<br>
            2) Change the inputs and Click the buttons to produce results;<br>
            3) See attention visualisations, by choosing a specific layer and head;<br>
        """)
    # Free-text input shown next to the model's generated result.
    with gr.Row():
        context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
        model_result_output = gr.Textbox(lines=2, label='Model result:')
    # Task and decoding-strategy selectors. infer_bart only consults the
    # decoding choice for the "commongen" task.
    with gr.Column():
        task_type_choice = gr.Radio(
            ["eli5", "commongen"], value="eli5", label="What task do you want to try?"
        )
        decoding_type_choice = gr.Radio(
            ["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
        )
    with gr.Row():
        model_btn = gr.Button(value="See Model Results")
    # Divider introducing the attention-visualisation section.
    gr.Markdown(
        """
        ---
        Observe Attention
        """
    )
    with gr.Row():
        with gr.Column():
            # Step-1 sliders choosing which layer/head to plot. The ranges (0-11,
            # 0-15) are hard-coded.
            # NOTE(review): confirm they match the checkpoints' layer/head counts.
            layer = gr.Slider(0, 11, 0, step=1, label="Layer")
            head = gr.Slider(0, 15, 0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()
    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")
    # Event wiring: text generation -> infer_bart; attention plot -> plot_attention.
    model_btn.click(fn=infer_bart, inputs=[context_input, task_type_choice, decoding_type_choice],
                    outputs=[model_result_output])
    vis_btn.click(fn=plot_attention, inputs=[context_input, task_type_choice, layer, head], outputs=[plot_output])

# Start the Gradio server only when this file is run as a script.
if __name__ == '__main__':
    app.launch()