espejelomar committed
Commit 5df619b • 1 Parent(s): bbee703

Add first version
Files changed:
- README.md +4 -4
- __init__.py +0 -0
- app.py +160 -0
- backend/__init__.py +0 -0
- backend/__pycache__/__init__.cpython-36.pyc +0 -0
- backend/__pycache__/__init__.cpython-38.pyc +0 -0
- backend/__pycache__/config.cpython-36.pyc +0 -0
- backend/__pycache__/config.cpython-38.pyc +0 -0
- backend/__pycache__/inference.cpython-36.pyc +0 -0
- backend/__pycache__/inference.cpython-38.pyc +0 -0
- backend/__pycache__/utils.cpython-36.pyc +0 -0
- backend/__pycache__/utils.cpython-38.pyc +0 -0
- backend/config.py +14 -0
- backend/inference.py +199 -0
- backend/utils.py +46 -0
- data/.DS_Store +0 -0
- data/__init__.py +0 -0
- data/stackoverflow-titles-distilbert-emb.csv +3 -0
- data/stackoverflow-titles.jsonl.gz +3 -0
- requirements.txt +10 -0
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Sentence Embeddings
+emoji: 🔥
+colorFrom: yellow
+colorTo: purple
 sdk: streamlit
 app_file: app.py
 pinned: false
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,160 @@
+import streamlit as st
+import pandas as pd
+
+from backend import inference
+from backend.config import MODELS_ID, QA_MODELS_ID, SEARCH_MODELS_ID
+
+st.title('Demo using Flax-Sentence-Transformers')
+
+st.sidebar.title('Tasks')
+menu = st.sidebar.radio("", options=["Sentence Similarity", "Asymmetric QA", "Search / Cluster", 'Identifying misleading vaccine texts'], index=0)
+
+st.markdown('''
+
+Hi! This is the demo for the [flax sentence embeddings](https://huggingface.co/flax-sentence-embeddings) created for the **Flax/JAX community week 🤗**.
+We trained three general-purpose flax-sentence-embeddings models: a **distilroberta base**, an **mpnet base** and a **minilm-l6**.
+All were trained on the full 1B+ training corpus with the v3 setup.
+
+In addition, we trained 20 models focused on general-purpose, question-answering and code-search tasks.
+View our models here: https://huggingface.co/flax-sentence-embeddings
+
+''')
+
+if menu == "Sentence Similarity":
+    st.header('Sentence Similarity')
+    st.markdown('''
+**Instructions**: You can compare the similarity of a main text with other texts of your choice. In the background, we'll create an embedding for each text, and then we'll use the cosine similarity function to calculate a similarity metric between our main sentence and the others.
+
+For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
+''')
+    select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
+
+    anchor = st.text_input(
+        'Please enter the main text you want to compare:'
+    )
+
+    n_texts = st.number_input(
+        f'''How many texts do you want to compare with: '{anchor}'?''',
+        value=2,
+        min_value=2)
+
+    inputs = []
+
+    for i in range(int(n_texts)):
+        input = st.text_input(f'Text {i + 1}:')
+
+        inputs.append(input)
+
+    if st.button('Tell me the similarity.'):
+        results = {model: inference.text_similarity(anchor, inputs, model, MODELS_ID) for model in select_models}
+        df_results = {model: results[model] for model in results}
+
+        index = [f"{idx + 1}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
+        df_total = pd.DataFrame(index=index)
+        for key, value in df_results.items():
+            df_total[key] = list(value['score'].values)
+
+        st.write('Here are the results for selected models:')
+        st.write(df_total)
+        st.write('Visualize the results of each model:')
+        st.line_chart(df_total)
+elif menu == "Asymmetric QA":
+    st.header('Asymmetric QA')
+    st.markdown('''
+**Instructions**: You can compare the answer likelihood of a given query with answer candidates of your choice. In the background, we'll create an embedding for each answer, and then we'll use the cosine similarity function to calculate a similarity metric between our query sentence and the others.
+The `mpnet_asymmetric_qa` model works best for hard negative answers or for distinguishing similar queries, since separate models are used to encode questions and answers.
+
+For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
+''')
+
+    select_models = st.multiselect("Choose models", options=list(QA_MODELS_ID), default=list(QA_MODELS_ID)[0])
+
+    anchor = st.text_input(
+        'Please enter the query you want to compare with the given answers:',
+        value="What is the weather in Paris?"
+    )
+
+    n_texts = st.number_input(
+        f'''How many answers do you want to compare with: '{anchor}'?''',
+        value=10,
+        min_value=2)
+
+    inputs = []
+
+    defaults = ["It is raining in Paris right now with 70 F temperature.", "What is the weather in Berlin?", "I have 3 brothers."]
+    for i in range(int(n_texts)):
+        input = st.text_input(f'Answer {i + 1}:', value=defaults[i] if i < len(defaults) else "")
+
+        inputs.append(input)
+
+    if st.button('Tell me the answer likelihood.'):
+        results = {model: inference.text_similarity(anchor, inputs, model, QA_MODELS_ID) for model in select_models}
+        df_results = {model: results[model] for model in results}
+
+        index = [f"{idx + 1}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
+        df_total = pd.DataFrame(index=index)
+        for key, value in df_results.items():
+            df_total[key] = list(value['score'].values)
+
+        st.write('Here are the results for selected models:')
+        st.write(df_total)
+        st.write('Visualize the results of each model:')
+        st.line_chart(df_total)
+
+elif menu == "Search / Cluster":
+    st.header('Search / Cluster')
+    st.markdown('''
+**Instructions**: Make a query for anything related to "Python" and the model you choose will return similar queries.
+
+For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
+''')
+
+    select_models = st.multiselect("Choose models", options=list(SEARCH_MODELS_ID), default=list(SEARCH_MODELS_ID)[0])
+
+    anchor = st.text_input(
+        'Please enter your query about "Python"; we will look for similar ones:',
+        value="How do I sort a dataframe by column"
+    )
+
+    n_texts = st.number_input(
+        '''How many similar queries do you want?''',
+        value=3,
+        min_value=2)
+
+    if st.button('Give me my search.'):
+        results = {model: inference.text_search(anchor, n_texts, model, SEARCH_MODELS_ID) for model in select_models}
+        st.table(pd.DataFrame(results[select_models[0]]).T)
+
+    if st.button('3D Clustering of search results using T-SNE on generated embeddings'):
+        st.write("Currently this only works locally due to the Spaces / plotly integration.")
+        st.write("Demonstration: https://gyazo.com/1ff0aa438ae533de3b3c63382af7fe80")
+        # fig = inference.text_cluster(anchor, 1000, select_models[0], SEARCH_MODELS_ID)
+        # fig.show()
+elif menu == "Identifying misleading vaccine texts":
+    st.header('Identifying misleading vaccine texts')
+    st.markdown('''
+**Instructions**: You can compare the similarity of a given text and keywords that identify 'misleading' texts regarding vaccination. In the background, we'll create an embedding for each text, and then we'll use the cosine similarity function to calculate a similarity metric between our main sentence and the keywords.
+
+We use keywords identified by **Muric, Goran and Wu, Yusong and Ferrara, Emilio (2021), 'COVID-19 Vaccine Hesitancy on Social Media: Building a Public Twitter Dataset of Anti-vaccine Content, Vaccine Misinformation and Conspiracies'**
+
+For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
+''')
+    select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
+
+    anchor = st.text_input(
+        'Please enter the text/tweet you want to evaluate:'
+    )
+
+    if st.button('Tell me the similarity.'):
+        results = {model: inference.tweets_vaccine(anchor, model, MODELS_ID) for model in select_models}
+        df_results = {model: results[model] for model in results}
+
+
+        # index = [f"{idx + 1}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
+        df_total = pd.DataFrame(index=[0])
+        for key, value in df_results.items():
+            df_total[key] = list(value['score'].values)
+
+        st.write('Here are the results for selected models:')
+        st.write(df_total)
+
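
The Sentence Similarity tab above follows the flow described in its instructions: embed the anchor and each candidate text, then score the candidates by cosine similarity. Below is a minimal standalone sketch of that flow (not part of the commit), assuming the mpnet checkpoint listed in backend/config.py; the example sentences are made up.

from sentence_transformers import SentenceTransformer
import numpy as np

# Checkpoint id taken from MODELS_ID in backend/config.py.
model = SentenceTransformer('flax-sentence-embeddings/all_datasets_v3_mpnet-base')

anchor = "How do I merge two dictionaries?"       # hypothetical anchor text
candidates = ["Combining two dicts into one",     # hypothetical comparison texts
              "Best pizza toppings in town"]

anchor_emb = model.encode(anchor)      # shape (dim,)
cand_emb = model.encode(candidates)    # shape (n, dim)

# Cosine similarity between the anchor and every candidate, as the app does.
scores = cand_emb @ anchor_emb / (np.linalg.norm(cand_emb, axis=1) * np.linalg.norm(anchor_emb))
for text, score in zip(candidates, scores):
    print(f"{score:.3f}  {text}")
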
backend/__init__.py
ADDED
File without changes
backend/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (159 Bytes).
backend/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (152 Bytes).
backend/__pycache__/config.cpython-36.pyc
ADDED
Binary file (737 Bytes).
backend/__pycache__/config.cpython-38.pyc
ADDED
Binary file (738 Bytes).
backend/__pycache__/inference.cpython-36.pyc
ADDED
Binary file (2.2 kB).
backend/__pycache__/inference.cpython-38.pyc
ADDED
Binary file (5.38 kB).
backend/__pycache__/utils.cpython-36.pyc
ADDED
Binary file (1.54 kB).
backend/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (1.56 kB).
backend/config.py
ADDED
@@ -0,0 +1,14 @@
+MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
+                 mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
+                 minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
+
+QA_MODELS_ID = dict(
+    mpnet_asymmetric_qa = ['flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q',
+                           'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
+    mpnet_qa='flax-sentence-embeddings/mpnet_stackexchange_v1',
+    distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
+)
+
+SEARCH_MODELS_ID = dict(
+    distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
+)
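
These dicts are consumed by backend/utils.load_model and backend/inference.py below: a plain string id yields a single SentenceTransformer, while the mpnet_asymmetric_qa entry is a [question-encoder, answer-encoder] pair. A short illustrative sketch (not part of the commit):

from sentence_transformers import SentenceTransformer

from backend.config import MODELS_ID, QA_MODELS_ID

# Symmetric model: one encoder for both sides of the comparison.
mpnet = SentenceTransformer(MODELS_ID['mpnet'])

# Asymmetric QA pair: index 0 encodes questions, index 1 encodes answers.
question_id, answer_id = QA_MODELS_ID['mpnet_asymmetric_qa']
question_encoder = SentenceTransformer(question_id)
answer_encoder = SentenceTransformer(answer_id)

query_emb = question_encoder.encode("What is the weather in Paris?")
answer_emb = answer_encoder.encode("It is raining in Paris right now.")
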
backend/inference.py
ADDED
@@ -0,0 +1,199 @@
+import gzip
+import json
+from collections import Counter
+
+import pandas as pd
+import numpy as np
+import jax.numpy as jnp
+import tqdm
+
+from sentence_transformers import util
+from typing import List, Union
+import torch
+
+from backend.utils import load_model, filter_questions, load_embeddings
+from sklearn.manifold import TSNE
+
+def cos_sim(a, b):
+    return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
+
+
+# We get similarity between embeddings.
+def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict: dict):
+    print(model_name)
+    model = load_model(model_name, model_dict)
+
+    # Creating embeddings
+    if hasattr(model, 'encode'):
+        anchor_emb = model.encode(anchor)[None, :]
+        inputs_emb = model.encode(inputs)
+    else:
+        assert len(model) == 2
+        anchor_emb = model[0].encode(anchor)[None, :]
+        inputs_emb = model[1].encode(inputs)
+
+    # Obtaining similarity
+    similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
+
+    # Returning a Pandas' dataframe
+    d = {'inputs': inputs,
+         'score': [round(similarity[i], 3) for i in range(len(similarity))]}
+    df = pd.DataFrame(d, columns=['inputs', 'score'])
+
+    return df
+
+
+# Search
+def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
+    # Proceeding with model
+    print(model_name)
+    assert model_name == "distilbert_qa"
+    model = load_model(model_name, model_dict)
+
+    # Creating embeddings
+    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]
+
+    print("loading embeddings")
+    corpus_emb = load_embeddings()
+
+    # Getting hits
+    hits = util.semantic_search(query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers)[0]
+
+    filtered_posts = filter_questions("python")
+    print(f"{len(filtered_posts)} posts found with tag: python")
+
+    hits_titles = []
+    hits_scores = []
+    urls = []
+    for hit in hits:
+        post = filtered_posts[hit['corpus_id']]
+        hits_titles.append(post['title'])
+        hits_scores.append("{:.3f}".format(hit['score']))
+        urls.append(f"https://stackoverflow.com/q/{post['id']}")
+
+    return hits_titles, hits_scores, urls
+
+
+def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
+    # Proceeding with model
+    print(model_name)
+    assert model_name == "distilbert_qa"
+    model = load_model(model_name, model_dict)
+
+    # Creating embeddings
+    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]
+
+    print("loading embeddings")
+    corpus_emb = load_embeddings()
+
+    # Getting hits
+    hits = util.semantic_search(query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers)[0]
+
+    filtered_posts = filter_questions("python")
+
+    hits_dict = [filtered_posts[hit['corpus_id']] for hit in hits]
+    hits_dict.append(dict(id = '1', title = anchor, tags = ['']))
+
+    hits_emb = torch.stack([corpus_emb[hit['corpus_id']] for hit in hits])
+    hits_emb = torch.cat((hits_emb, query_emb))
+
+    # Dimensionality reduction with t-SNE
+    tsne = TSNE(n_components=3, verbose=1, perplexity=15, n_iter=1000)
+    tsne_results = tsne.fit_transform(hits_emb.cpu())
+    df = pd.DataFrame(hits_dict)
+    tags = list(df['tags'])
+
+    counter = Counter(tags[0])
+    for i in tags[1:]:
+        counter.update(i)
+
+    df_tags = pd.DataFrame(counter.most_common(), columns=['Tag', 'Mentions'])
+    most_common_tags = list(df_tags['Tag'])[1:5]
+
+    labels = []
+
+    for tags_list in list(df['tags']):
+        for common_tag in most_common_tags:
+            if common_tag in tags_list:
+                labels.append(common_tag)
+                break
+            elif common_tag != most_common_tags[-1]:
+                continue
+            else:
+                labels.append('others')
+
+    df['title'] = [post['title'] for post in hits_dict]
+    df['labels'] = labels
+    df['tsne_x'] = tsne_results[:, 0]
+    df['tsne_y'] = tsne_results[:, 1]
+    df['tsne_z'] = tsne_results[:, 2]
+
+    df['size'] = [2 for i in range(len(df))]
+
+    # Making the query bigger than the rest of the observations
+    df['size'][len(df) - 1] = 10
+    df['labels'][len(df) - 1] = 'QUERY'
+    import plotly.express as px
+
+    fig = px.scatter_3d(df, x='tsne_x', y='tsne_y', z='tsne_z', color='labels', size='size',
+                        color_discrete_sequence=px.colors.qualitative.D3, hover_data=[df.title])
+    return fig
+
+
+
+# We get similarity between embeddings.
+def tweets_vaccine(anchor: str, model_name: str, model_dict: dict):
+    print(model_name)
+    model = load_model(model_name, model_dict)
+
+    # Keywords common in disinformation tweets
+    keywords = '''abolish big pharma,
+no forced flu shots,
+antivaccine,
+No Forced Vaccines,
+Arrest Bill Gates,
+not mandatory vaccines,
+No Vaccine,
+big pharma mafia,
+No Vaccine For Me,
+big pharma kills,
+no vaccine mandates,
+parents over pharma,
+say no to vaccines,
+stop mandatory vaccination,
+vaccines are poison,
+learn the risk,
+vaccines cause,
+medical freedom,
+vaccines kill,
+medical freedom of choice,
+vaxxed,
+my body my choice,
+vaccines have very dangerous consequences,
+Vaccines harm your organism'''
+
+
+
+
+    # Creating embeddings
+    if hasattr(model, 'encode'):
+        anchor_emb = model.encode(anchor)[None, :]
+        inputs_emb = model.encode(keywords)
+    else:
+        assert len(model) == 2
+        anchor_emb = model[0].encode(anchor)[None, :]
+        inputs_emb = model[1].encode(keywords)
+
+
+    # Obtaining similarity
+    similarity = jnp.squeeze(jnp.matmul(anchor_emb, jnp.transpose(inputs_emb)) / (jnp.linalg.norm(anchor_emb) * jnp.linalg.norm(inputs_emb))).tolist()
+
+    # Returning a Pandas' dataframe
+    d = dict(tweet = anchor,
+             score = [round(similarity, 3)])
+    df = pd.DataFrame(d, columns=['tweet', 'score'])
+
+
+    return df
+
+
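
text_search above relies on util.semantic_search over the pre-computed Stack Overflow title embeddings. Here is a minimal sketch of that search step with a tiny in-memory corpus instead of the shipped data files (not part of the commit; the corpus sentences are made up):

from sentence_transformers import SentenceTransformer, util

# Checkpoint id taken from SEARCH_MODELS_ID in backend/config.py.
model = SentenceTransformer('flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot')

corpus = ["How do I sort a dataframe by column?",
          "Reading a CSV file with pandas",
          "What is a Python decorator?"]
corpus_emb = model.encode(corpus, convert_to_tensor=True)

query_emb = model.encode("sort pandas dataframe", convert_to_tensor=True)

# Same call as text_search: dot-product scoring, keep the top_k best hits.
hits = util.semantic_search(query_emb, corpus_emb, score_function=util.dot_score, top_k=2)[0]
for hit in hits:
    print(f"{hit['score']:.3f}  {corpus[hit['corpus_id']]}")
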
backend/utils.py
ADDED
@@ -0,0 +1,46 @@
+import gzip
+import json
+import numpy as np
+
+import streamlit as st
+import torch
+import tqdm.auto
+from sentence_transformers import SentenceTransformer
+
+
+@st.cache(allow_output_mutation=True)
+def load_model(model_name, model_dict):
+    assert model_name in model_dict.keys()
+    # Lazy downloading
+    model_ids = model_dict[model_name]
+    if type(model_ids) == str:
+        output = SentenceTransformer(model_ids)
+    elif hasattr(model_ids, '__iter__'):
+        output = [SentenceTransformer(name) for name in model_ids]
+
+    return output
+
+@st.cache(allow_output_mutation=True)
+def load_embeddings():
+    # embedding pre-generated
+    corpus_emb = torch.from_numpy(np.loadtxt('./data/stackoverflow-titles-distilbert-emb.csv', max_rows=10000))
+    return corpus_emb.float()
+
+@st.cache(allow_output_mutation=True)
+def filter_questions(tag, max_questions=10000):
+    posts = []
+    max_posts = 6e6
+    with gzip.open("./data/stackoverflow-titles.jsonl.gz", "rt") as fIn:
+        for line in tqdm.auto.tqdm(fIn, total=max_posts, desc="Load data"):
+            posts.append(json.loads(line))
+
+            if len(posts) >= max_posts:
+                break
+
+    filtered_posts = []
+    for post in posts:
+        if tag in post["tags"]:
+            filtered_posts.append(post)
+            if len(filtered_posts) >= max_questions:
+                break
+    return filtered_posts
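
filter_questions (and the inference functions that consume its output) assumes data/stackoverflow-titles.jsonl.gz holds one JSON object per line with at least id, title and tags. A small sketch of that format using made-up records and a hypothetical file name (not part of the commit):

import gzip
import json

# Hypothetical sample records in the same shape as data/stackoverflow-titles.jsonl.gz.
sample_posts = [
    {"id": "1", "title": "How do I sort a dataframe by column?", "tags": ["python", "pandas"]},
    {"id": "2", "title": "What does the yield keyword do?", "tags": ["python"]},
    {"id": "3", "title": "How to center a div?", "tags": ["css"]},
]

with gzip.open("sample-titles.jsonl.gz", "wt") as fOut:
    for post in sample_posts:
        fOut.write(json.dumps(post) + "\n")

# Same filtering rule as filter_questions: keep posts whose tags contain the query tag.
with gzip.open("sample-titles.jsonl.gz", "rt") as fIn:
    posts = [json.loads(line) for line in fIn]
python_posts = [post for post in posts if "python" in post["tags"]]
print(len(python_posts))  # 2
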
data/.DS_Store
ADDED
Binary file (6.15 kB).
data/__init__.py
ADDED
File without changes
data/stackoverflow-titles-distilbert-emb.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f54b58e7835fac510ef46b8ba38c58c9942d769cace977e42a3bb274344ee9f
+size 3916646328
data/stackoverflow-titles.jsonl.gz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de83e615c8200395a6c32c9e0fc34cc0ca72d0207a65150d27a4a168ec52dbab
+size 426579165
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+sentence_transformers
+pandas
+jax
+jaxlib
+streamlit
+numpy
+torch
+scikit-learn
+plotly
+matplotlib