TRACES commited on
Commit
1623114
·
1 Parent(s): 7d555a4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -19
main.py CHANGED
@@ -23,9 +23,12 @@ def load_models():
23
  with open('models/SVM_model_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
24
  st.session_state.untrue_detector = pickle.load(f)
25
 
26
- st.session_state.bert = pipeline(task="text-classification",
27
  model=BertForSequenceClassification.from_pretrained("TRACES/private-bert", use_auth_token=os.environ['ACCESS_TOKEN'], num_labels=2),
28
  tokenizer=AutoTokenizer.from_pretrained("TRACES/private-bert", use_auth_token=os.environ['ACCESS_TOKEN']))
 
 
 
29
 
30
 
31
  def load_content():
@@ -49,13 +52,15 @@ if all([
49
  'untrue_detector_result' not in st.session_state,
50
  'bert_result' not in st.session_state
51
  ]):
52
- st.session_state.gpt_detector_result = ''
53
- st.session_state.gpt_detector_probability = [1, 0]
54
 
55
  st.session_state.untrue_detector_result = ''
56
  st.session_state.untrue_detector_probability = 1
57
 
58
- st.session_state.bert_result = [{'label': '', 'score': 1}]
 
 
59
 
60
  content = load_content()
61
  if 'loaded' not in st.session_state:
@@ -92,25 +97,27 @@ if st.session_state.agree:
92
  content['text_placeholder'][st.session_state.lang]).strip('\n')
93
 
94
  if st.button(content['analyze_button'][st.session_state.lang]):
95
- user_tfidf_disinformation = st.session_state.tfidf_vectorizer_disinformation.transform([user_input])
96
- st.session_state.gpt_detector_result = st.session_state.gpt_detector.predict(user_tfidf_disinformation)[0]
97
- st.session_state.gpt_detector_probability = st.session_state.gpt_detector.predict_proba(user_tfidf_disinformation)[0]
98
 
99
  user_tfidf_untrue_inf = st.session_state.tfidf_vectorizer_untrue_inf.transform([user_input])
100
  st.session_state.untrue_detector_result = st.session_state.untrue_detector.predict(user_tfidf_untrue_inf)[0]
101
  st.session_state.untrue_detector_probability = st.session_state.untrue_detector.predict_proba(user_tfidf_untrue_inf)[0]
102
  st.session_state.untrue_detector_probability = max(st.session_state.untrue_detector_probability[0], st.session_state.untrue_detector_probability[1])
103
 
104
- st.session_state.bert_result = st.session_state.bert(user_input)
 
 
105
 
106
- if st.session_state.gpt_detector_result == 1:
107
- st.warning(content['gpt_getect_yes'][st.session_state.lang] +
108
- str(round(st.session_state.gpt_detector_probability[1] * 100, 2)) +
109
- content['gpt_yes_proba'][st.session_state.lang], icon="⚠️")
110
- else:
111
- st.success(content['gpt_getect_no'][st.session_state.lang] +
112
- str(round(st.session_state.gpt_detector_probability[0] * 100, 2)) +
113
- content['gpt_no_proba'][st.session_state.lang], icon="✅")
114
 
115
  if st.session_state.untrue_detector_result == 0:
116
  st.warning(content['untrue_getect_yes'][st.session_state.lang] +
@@ -121,14 +128,23 @@ if st.session_state.agree:
121
  str(round(st.session_state.untrue_detector_probability * 100, 2)) +
122
  content['untrue_no_proba'][st.session_state.lang], icon="✅")
123
 
124
- if st.session_state.bert_result[0]['label'] == 'LABEL_1':
125
  st.warning(content['bert_yes_1'][st.session_state.lang] +
126
- str(round(st.session_state.bert_result[0]['score'] * 100, 2)) +
127
  content['bert_yes_2'][st.session_state.lang], icon = "⚠️")
128
  else:
129
  st.success(content['bert_no_1'][st.session_state.lang] +
130
- str(round(st.session_state.bert_result[0]['score'] * 100, 2)) +
131
  content['bert_no_2'][st.session_state.lang], icon="✅")
 
 
 
 
 
 
 
 
 
132
 
133
  st.info(content['disinformation_definition'][st.session_state.lang], icon="ℹ️")
134
 
 
23
  with open('models/SVM_model_untrue_inform_detection_tfidf_bg_0.96_F1_score_3Y_N_Q1_082023.pkl', 'rb') as f:
24
  st.session_state.untrue_detector = pickle.load(f)
25
 
26
+ st.session_state.bert_disinfo = pipeline(task="text-classification",
27
  model=BertForSequenceClassification.from_pretrained("TRACES/private-bert", use_auth_token=os.environ['ACCESS_TOKEN'], num_labels=2),
28
  tokenizer=AutoTokenizer.from_pretrained("TRACES/private-bert", use_auth_token=os.environ['ACCESS_TOKEN']))
29
+ st.session_state.bert_gpt = pipeline(task="text-classification",
30
+ model=BertForSequenceClassification.from_pretrained("usmiva/bert-deepfake-bg", num_labels=2),
31
+ tokenizer=AutoTokenizer.from_pretrained("usmiva/bert-deepfake-bg"))
32
 
33
 
34
  def load_content():
 
52
  'untrue_detector_result' not in st.session_state,
53
  'bert_result' not in st.session_state
54
  ]):
55
+ # st.session_state.gpt_detector_result = ''
56
+ # st.session_state.gpt_detector_probability = [1, 0]
57
 
58
  st.session_state.untrue_detector_result = ''
59
  st.session_state.untrue_detector_probability = 1
60
 
61
+ st.session_state.bert_disinfo_result = [{'label': '', 'score': 1}]
62
+
63
+ st.session_state.bert_gpt_result = [{'label': '', 'score': 1}]
64
 
65
  content = load_content()
66
  if 'loaded' not in st.session_state:
 
97
  content['text_placeholder'][st.session_state.lang]).strip('\n')
98
 
99
  if st.button(content['analyze_button'][st.session_state.lang]):
100
+ # user_tfidf_disinformation = st.session_state.tfidf_vectorizer_disinformation.transform([user_input])
101
+ # st.session_state.gpt_detector_result = st.session_state.gpt_detector.predict(user_tfidf_disinformation)[0]
102
+ # st.session_state.gpt_detector_probability = st.session_state.gpt_detector.predict_proba(user_tfidf_disinformation)[0]
103
 
104
  user_tfidf_untrue_inf = st.session_state.tfidf_vectorizer_untrue_inf.transform([user_input])
105
  st.session_state.untrue_detector_result = st.session_state.untrue_detector.predict(user_tfidf_untrue_inf)[0]
106
  st.session_state.untrue_detector_probability = st.session_state.untrue_detector.predict_proba(user_tfidf_untrue_inf)[0]
107
  st.session_state.untrue_detector_probability = max(st.session_state.untrue_detector_probability[0], st.session_state.untrue_detector_probability[1])
108
 
109
+ st.session_state.bert_disinfo_result = st.session_state.bert_disinfo(user_input)
110
+
111
+ st.session_state.bert_gpt_result = st.session_state.bert_gpt(user_input)
112
 
113
+ # if st.session_state.gpt_detector_result == 1:
114
+ # st.warning(content['gpt_getect_yes'][st.session_state.lang] +
115
+ # str(round(st.session_state.gpt_detector_probability[1] * 100, 2)) +
116
+ # content['gpt_yes_proba'][st.session_state.lang], icon="⚠️")
117
+ # else:
118
+ # st.success(content['gpt_getect_no'][st.session_state.lang] +
119
+ # str(round(st.session_state.gpt_detector_probability[0] * 100, 2)) +
120
+ # content['gpt_no_proba'][st.session_state.lang], icon="✅")
121
 
122
  if st.session_state.untrue_detector_result == 0:
123
  st.warning(content['untrue_getect_yes'][st.session_state.lang] +
 
128
  str(round(st.session_state.untrue_detector_probability * 100, 2)) +
129
  content['untrue_no_proba'][st.session_state.lang], icon="✅")
130
 
131
+ if st.session_state.bert_disinfo_resultt[0]['label'] == 'LABEL_1':
132
  st.warning(content['bert_yes_1'][st.session_state.lang] +
133
+ str(round(st.session_state.bert_disinfo_result[0]['score'] * 100, 2)) +
134
  content['bert_yes_2'][st.session_state.lang], icon = "⚠️")
135
  else:
136
  st.success(content['bert_no_1'][st.session_state.lang] +
137
+ str(round(st.session_state.bert_disinfo_result[0]['score'] * 100, 2)) +
138
  content['bert_no_2'][st.session_state.lang], icon="✅")
139
+
140
+ if st.session_state.bert_gpt_result[0]['label'] == 'LABEL_1':
141
+ st.warning(content['bert_gpt_1'][st.session_state.lang] +
142
+ str(round(st.session_state.bert_gpt_result[0]['score'] * 100, 2)) +
143
+ content['bert_gpt_2'][st.session_state.lang], icon = "⚠️")
144
+ else:
145
+ st.success(content['bert_human_1'][st.session_state.lang] +
146
+ str(round(st.session_state.bert_gpt_result[0]['score'] * 100, 2)) +
147
+ content['bert_human_2'][st.session_state.lang], icon="✅")
148
 
149
  st.info(content['disinformation_definition'][st.session_state.lang], icon="ℹ️")
150