dlsmallw committed on
Commit
9a0de7d
·
1 Parent(s): cfad95e

Task-314 Fix UI ghosting bug

Browse files
Files changed (1) hide show
  1. app.py +124 -37
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import streamlit as st
 
2
  import pandas as pd
3
- from annotated_text import annotated_text
4
- import time
5
  from scripts.predict import InferenceHandler
 
 
 
6
 
7
- history_df = pd.DataFrame(data=[], columns=['Text', 'Classification', 'Gender', 'Race', 'Sexuality', 'Disability', 'Religion', 'Unspecified'])
8
  rc = None
9
 
10
  @st.cache_data
@@ -14,10 +16,11 @@ def load_inference_handler(api_token):
14
  except:
15
  return None
16
 
 
17
  def extract_data(json_obj):
18
  row_data = []
19
 
20
- row_data.append(json_obj['raw_text'])
21
  row_data.append(json_obj['text_sentiment'])
22
  cat_dict = json_obj['category_sentiments']
23
  for cat in cat_dict.keys():
@@ -27,11 +30,44 @@ def extract_data(json_obj):
27
 
28
  return row_data
29
 
30
- def load_history():
31
- for result in st.session_state.results:
32
- history_df.loc[len(history_df)] = extract_data(result)
33
-
34
- def output_results(res):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  label_dict = {
36
  'Gender': '#4A90E2',
37
  'Race': '#E67E22',
@@ -41,38 +77,91 @@ def output_results(res):
41
  'Unspecified': '#A0A0A0'
42
  }
43
 
44
- with rc:
45
- st.markdown('### Results')
46
- with st.container(border=True):
47
- at_list = []
48
- if res['numerical_sentiment'] == 1:
49
- for entry in res['category_sentiments'].keys():
50
- val = res['category_sentiments'][entry]
51
- if val > 0.0:
52
- perc = val * 100
53
- at_list.append((entry, f'{perc:.2f}%', label_dict[entry]))
54
 
55
- st.markdown(f"#### Text - *\"{res['raw_text']}\"*")
56
- st.markdown(f"#### Classification - {':red' if res['numerical_sentiment'] == 1 else ':green'}[{res['text_sentiment']}]")
 
 
 
57
 
58
- if len(at_list) > 0:
59
- annotated_text(at_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  @st.cache_data
62
  def analyze_text(text):
63
- st.write(f'Text: {text}')
64
  if ih:
65
  res = None
66
  with rc:
67
  with st.spinner("Processing...", show_time=True) as spnr:
68
- time.sleep(5)
69
  res = ih.classify_text(text)
70
  del spnr
71
 
72
  if res is not None:
73
  st.session_state.results.append(res)
74
- history_df.loc[-1] = extract_data(res)
75
- output_results(res)
76
 
77
  st.title('NLPinitiative Text Classifier')
78
 
@@ -84,19 +173,18 @@ API_KEY = st.sidebar.text_input(
84
  )
85
  ih = load_inference_handler(API_KEY)
86
 
87
- tab1, tab2 = st.tabs(['Classifier', 'About This App'])
 
 
 
 
88
 
89
  if "results" not in st.session_state:
90
  st.session_state.results = []
91
-
92
- load_history()
93
 
94
  with tab1:
95
  "Text Classifier for determining if entered text is discriminatory (and the categories of discrimination) or Non-Discriminatory."
96
 
97
- hist_container = st.container()
98
- hist_expander = hist_container.expander('History')
99
-
100
  rc = st.container()
101
  text_form = st.form(key='classifier', clear_on_submit=True, enter_to_submit=True)
102
  with text_form:
@@ -106,11 +194,10 @@ with tab1:
106
  if form_btn and text_area is not None and len(text_area) > 0:
107
  analyze_text(text_area)
108
 
109
- with hist_expander:
110
- st.dataframe(history_df, hide_index=True)
111
-
112
-
113
  with tab2:
 
 
 
114
  st.markdown(
115
  """The NLPinitiative Discriminatory Text Classifier is an advanced
116
  natural language processing tool designed to detect and flag potentially
 
1
  import streamlit as st
2
+ import nest_asyncio
3
  import pandas as pd
4
+ from annotated_text import annotation
 
5
  from scripts.predict import InferenceHandler
6
+ from htbuilder import span, div
7
+
8
+ nest_asyncio.apply()
9
 
 
10
  rc = None
11
 
12
  @st.cache_data
 
16
  except:
17
  return None
18
 
19
+ @st.cache_data
20
  def extract_data(json_obj):
21
  row_data = []
22
 
23
+ row_data.append(json_obj['text_input'])
24
  row_data.append(json_obj['text_sentiment'])
25
  cat_dict = json_obj['category_sentiments']
26
  for cat in cat_dict.keys():
 
30
 
31
  return row_data
32
 
33
def load_history(parent_elem):
    """Render one collapsible history entry per stored classification result.

    Iterates over ``st.session_state.results`` and, inside ``parent_elem``,
    draws an expander per entry showing the submitted text followed by a
    per-sentence table (binary classification plus one percentage column per
    category). The expander icon is red when at least one sentence in the
    entry has ``prediction_class == 1`` and green otherwise.

    Args:
        parent_elem: Streamlit element usable as a context manager; all
            history output is rendered inside it.
    """
    import html  # local import: only needed to escape user text for HTML output

    columns = ['Sentence', 'Binary Classification', 'Gender', 'Race', 'Sexuality', 'Disability', 'Religion', 'Unspecified']
    # Every column after the first two holds a category score.
    n_categories = len(columns) - 2

    with parent_elem:
        for idx, result in enumerate(st.session_state.results):
            discriminatory = False

            data = []
            for sent_item in result['results']:
                bin_info = sent_item['binary_classification']
                row_data = [sent_item['sentence'], bin_info['classification']]

                if bin_info['prediction_class'] == 1:
                    discriminatory = True
                    # NOTE(review): relies on the regression dict iterating in
                    # the same order as the category columns -- confirm
                    # against InferenceHandler.
                    row_data.extend(
                        f'{score * 100:.2f}%'
                        for score in sent_item['multilabel_regression'].values()
                    )
                else:
                    # Non-discriminatory sentences show empty category cells.
                    row_data.extend([None] * n_categories)

                data.append(row_data)

            df = pd.DataFrame(data=data, columns=columns)

            # The text comes straight from the user and is rendered with
            # unsafe_allow_html=True, so it must be escaped to keep markup in
            # the input from being injected into the page.
            safe_text = html.escape(str(result['text_input']))

            with st.expander(label=f'Entry #{idx+1}', icon='🔴' if discriminatory else '🟢'):
                st.markdown('<hr style="margin: 0.5em 0 0 0;">', unsafe_allow_html=True)
                st.markdown(
                    f"<p style='text-align: center; font-weight: bold; font-style: italic; font-size: medium;'>\"{safe_text}\"</p>",
                    unsafe_allow_html=True
                )
                st.markdown('<hr style="margin: 0 0 0.5em 0;">', unsafe_allow_html=True)
                st.markdown('##### Sentence Breakdown:')
                st.dataframe(df)
68
+
69
+
70
+ def build_result_tree(parent_elem, results):
71
  label_dict = {
72
  'Gender': '#4A90E2',
73
  'Race': '#E67E22',
 
77
  'Unspecified': '#A0A0A0'
78
  }
79
 
80
+ discriminatory_sentiment = False
81
+
82
+ sent_details = []
83
+ for result in results['results']:
84
+ sentence = result['sentence']
85
+ bin_class = result['binary_classification']['classification']
86
+ pred_class = result['binary_classification']['prediction_class']
87
+ ml_regr = result['multilabel_regression']
 
 
88
 
89
+ sent_res = {
90
+ 'sentence': sentence,
91
+ 'classification': f'{':red' if pred_class else ':green'}[{bin_class}]',
92
+ 'annotated_categories': []
93
+ }
94
 
95
+ if pred_class == 1:
96
+ discriminatory_sentiment = True
97
+ at_list = []
98
+ for entry in ml_regr.keys():
99
+ val = ml_regr[entry]
100
+ if val > 0.0:
101
+ perc = val * 100
102
+ at_list.append(annotation(body=entry, label=f'{perc:.2f}%', background=label_dict[entry]))
103
+ sent_res['annotated_categories'] = at_list
104
+ sent_details.append(sent_res)
105
+
106
+ with parent_elem:
107
+ st.markdown(f'### Results - {':red[Detected Discriminatory Sentiment]' if discriminatory_sentiment else ':green[No Discriminatory Sentiment Detected]'}')
108
+ with st.container(border=True):
109
+ st.markdown('<hr style="margin: 0.5em 0 0 0;">', unsafe_allow_html=True)
110
+ st.markdown(
111
+ f"<p style='text-align: center; font-weight: bold; font-style: italic; font-size: large;'>\"{results['text_input']}\"</p>",
112
+ unsafe_allow_html=True
113
+ )
114
+ st.markdown('<hr style="margin: 0 0 0.5em 0;">', unsafe_allow_html=True)
115
+
116
+ if discriminatory_sentiment:
117
+ if (len(results['results']) > 1):
118
+ st.markdown('##### Sentence Breakdown:')
119
+ for idx, sent in enumerate(sent_details):
120
+ with st.expander(label=f'Sentence #{idx+1}', icon='🔴' if len(sent['annotated_categories']) > 0 else '🟢', expanded=True):
121
+ st.markdown('<hr style="margin: 0.5em 0 0 0;">', unsafe_allow_html=True)
122
+ st.markdown(
123
+ f"<p style='text-align: center; font-weight: bold; font-style: italic; font-size: large;'>\"{sent['sentence']}\"</p>",
124
+ unsafe_allow_html=True
125
+ )
126
+ st.markdown('<hr style="margin: 0 0 0.5em 0;">', unsafe_allow_html=True)
127
+ st.markdown(f'##### Classification - {sent['classification']}')
128
+
129
+ if len(sent['annotated_categories']) > 0:
130
+ st.markdown(
131
+ div(
132
+ span(' ' if idx != 0 else '')[
133
+ item
134
+ ] for idx, item in enumerate(sent['annotated_categories'])
135
+ ),
136
+ unsafe_allow_html=True
137
+ )
138
+ st.markdown('\n')
139
+ else:
140
+ st.markdown(f"#### Classification - {sent['classification']}")
141
+ if len(sent['annotated_categories']) > 0:
142
+ st.markdown(
143
+ div(
144
+ span(' ' if idx != 0 else '')[
145
+ item
146
+ ] for idx, item in enumerate(sent['annotated_categories'])
147
+ ),
148
+ unsafe_allow_html=True
149
+ )
150
+ st.markdown('\n')
151
 
152
@st.cache_data
def _classify_text_cached(text):
    """Return the inference handler's classification of *text*.

    Only the inference call is cached, so identical inputs skip
    re-inference while the caller's side effects still run every time.
    """
    return ih.classify_text(text)


def analyze_text(text):
    """Classify *text*, record the result in the session history, and render it.

    Does nothing when no inference handler is loaded (``ih`` is falsy).
    Otherwise shows a spinner inside the results container ``rc`` while the
    text is classified; a non-``None`` result is appended to
    ``st.session_state.results`` and drawn with ``build_result_tree``.

    Args:
        text: The raw text submitted through the classifier form.
    """
    if not ih:
        return

    with rc:
        with st.spinner("Processing...", show_time=True):
            res = _classify_text_cached(text)

    if res is not None:
        # Kept outside the cached helper: st.cache_data skips the function
        # body on a cache hit, which would otherwise leave repeated inputs
        # out of the session history.
        st.session_state.results.append(res)
        build_result_tree(rc, res)
 
165
 
166
  st.title('NLPinitiative Text Classifier')
167
 
 
173
  )
174
  ih = load_inference_handler(API_KEY)
175
 
176
+ tab1 = st.empty()
177
+ tab2 = st.empty()
178
+ tab3 = st.empty()
179
+
180
+ tab1, tab2, tab3 = st.tabs(['Classifier', 'Input History', 'About This App'])
181
 
182
  if "results" not in st.session_state:
183
  st.session_state.results = []
 
 
184
 
185
  with tab1:
186
  "Text Classifier for determining if entered text is discriminatory (and the categories of discrimination) or Non-Discriminatory."
187
 
 
 
 
188
  rc = st.container()
189
  text_form = st.form(key='classifier', clear_on_submit=True, enter_to_submit=True)
190
  with text_form:
 
194
  if form_btn and text_area is not None and len(text_area) > 0:
195
  analyze_text(text_area)
196
 
 
 
 
 
197
  with tab2:
198
+ load_history(tab2)
199
+
200
+ with tab3:
201
  st.markdown(
202
  """The NLPinitiative Discriminatory Text Classifier is an advanced
203
  natural language processing tool designed to detect and flag potentially