MrVicente commited on
Commit
bcbb55b
1 Parent(s): f03f9e7

fixed viz function parameter error

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. attention_viz.py +1 -1
app.py CHANGED
@@ -40,7 +40,7 @@ def infer_bart(context, task_type, decoding_type_str):
40
  if decoding_type_str =='default':
41
  response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
42
  else:
43
- response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True)
44
  elif Data_Type(task_type) == Data_Type.ELI5:
45
  response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
46
  else:
 
40
  if decoding_type_str =='default':
41
  response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
42
  else:
43
+ response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True, max_concepts=2)
44
  elif Data_Type(task_type) == Data_Type.ELI5:
45
  response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
46
  else:
attention_viz.py CHANGED
@@ -170,7 +170,7 @@ class AttentionVisualizer:
170
  plt.title(title)
171
  plt.show()
172
 
173
- def plot_attn_lines_concepts_ids(title, examples, layer, head,
174
  relations_total, width=3, example_sep=3,
175
  word_height=1, pad=0.1, hide_sep=False):
176
  # examples -> {'words': tokens, 'attentions': [layer][head]}
 
170
  plt.title(title)
171
  plt.show()
172
 
173
+ def plot_attn_lines_concepts_ids(self, title, examples, layer, head,
174
  relations_total, width=3, example_sep=3,
175
  word_height=1, pad=0.1, hide_sep=False):
176
  # examples -> {'words': tokens, 'attentions': [layer][head]}