fixed viz function parameter error
Browse files- app.py +1 -1
- attention_viz.py +1 -1
app.py
CHANGED
@@ -40,7 +40,7 @@ def infer_bart(context, task_type, decoding_type_str):
|
|
40 |
if decoding_type_str =='default':
|
41 |
response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
|
42 |
else:
|
43 |
-
response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True)
|
44 |
elif Data_Type(task_type) == Data_Type.ELI5:
|
45 |
response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
|
46 |
else:
|
|
|
40 |
if decoding_type_str =='default':
|
41 |
response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
|
42 |
else:
|
43 |
+
response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True, max_concepts=2)
|
44 |
elif Data_Type(task_type) == Data_Type.ELI5:
|
45 |
response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
|
46 |
else:
|
attention_viz.py
CHANGED
@@ -170,7 +170,7 @@ class AttentionVisualizer:
|
|
170 |
plt.title(title)
|
171 |
plt.show()
|
172 |
|
173 |
-
def plot_attn_lines_concepts_ids(title, examples, layer, head,
|
174 |
relations_total, width=3, example_sep=3,
|
175 |
word_height=1, pad=0.1, hide_sep=False):
|
176 |
# examples -> {'words': tokens, 'attentions': [layer][head]}
|
|
|
170 |
plt.title(title)
|
171 |
plt.show()
|
172 |
|
173 |
+
def plot_attn_lines_concepts_ids(self, title, examples, layer, head,
|
174 |
relations_total, width=3, example_sep=3,
|
175 |
word_height=1, pad=0.1, hide_sep=False):
|
176 |
# examples -> {'words': tokens, 'attentions': [layer][head]}
|