Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,141 +1,43 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import
|
3 |
-
import
|
4 |
-
import numpy as np
|
5 |
-
import contextlib
|
6 |
-
import plotly.express as px
|
7 |
import pandas as pd
|
8 |
-
from PIL import Image
|
9 |
-
import datetime
|
10 |
-
import os
|
11 |
-
import psutil
|
12 |
|
13 |
-
|
14 |
-
|
15 |
|
16 |
-
|
17 |
-
'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
18 |
-
'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
19 |
-
'XLM Roberta XNLI (cross-lingual)': """XLM Roberta, a cross-lingual model, with a classification head trained on XNLI. Supported languages include: _English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu_.
|
20 |
|
21 |
-
Note that this model seems to be less reliable than the English-only models when classifying longer sequences.
|
22 |
|
23 |
-
|
|
|
|
|
24 |
|
25 |
-
|
26 |
-
|
27 |
|
28 |
-
ZSL_DESC = """Recently, the NLP science community has begun to pay increasing attention to zero-shot and few-shot applications, such as in the [paper from OpenAI](https://arxiv.org/abs/2005.14165) introducing GPT-3. This demo shows how 🤗 Transformers can be used for zero-shot topic classification, the task of predicting a topic that the model has not been trained on."""
|
29 |
-
|
30 |
-
CODE_DESC = """```python
|
31 |
-
from transformers import pipeline
|
32 |
-
classifier = pipeline('zero-shot-classification',
|
33 |
-
model='{}')
|
34 |
-
hypothesis_template = 'This text is about {{}}.' # the template used in this demo
|
35 |
-
|
36 |
-
classifier(sequence, labels,
|
37 |
-
hypothesis_template=hypothesis_template,
|
38 |
-
multi_class=multi_class)
|
39 |
-
# {{'sequence' ..., 'labels': ..., 'scores': ...}}
|
40 |
-
```"""
|
41 |
-
|
42 |
-
model_ids = {
|
43 |
-
'Bart MNLI': 'facebook/bart-large-mnli',
|
44 |
-
'Bart MNLI + Yahoo Answers': 'joeddav/bart-large-mnli-yahoo-answers',
|
45 |
-
'XLM Roberta XNLI (cross-lingual)': 'joeddav/xlm-roberta-large-xnli'
|
46 |
-
}
|
47 |
-
|
48 |
-
device = 0 if torch.cuda.is_available() else -1
|
49 |
-
|
50 |
-
@st.cache(allow_output_mutation=True)
def load_models():
    """Load every sequence-classification model listed in ``model_ids``.

    Returns:
        dict: maps each model identifier (the values of ``model_ids``) to the
        model object returned by
        ``AutoModelForSequenceClassification.from_pretrained``.
    """
    loaded = {}
    for hub_id in model_ids.values():
        loaded[hub_id] = AutoModelForSequenceClassification.from_pretrained(hub_id)
    return loaded
|
53 |
-
|
54 |
-
models = load_models()
|
55 |
-
|
56 |
-
|
57 |
-
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_tokenizer(tok_id):
    """Return the tokenizer that ``AutoTokenizer`` loads for ``tok_id``.

    Args:
        tok_id: identifier passed straight to ``AutoTokenizer.from_pretrained``.
    """
    tok = AutoTokenizer.from_pretrained(tok_id)
    return tok
|
60 |
-
|
61 |
-
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class, do_print_code):
    """Run zero-shot classification of ``sequence`` against ``labels``.

    Args:
        nli_model_id: key into the module-level ``models`` dict; also used to
            load the matching tokenizer via ``load_tokenizer``.
        sequence: text to classify.
        labels: candidate topic labels.
        hypothesis_template: template the pipeline fills with each label.
        multi_class: forwarded to the pipeline's ``multi_class`` option.
        do_print_code: not used by the body; kept so the signature stays
            compatible with existing callers.

    Returns:
        tuple: ``(outputs['labels'], outputs['scores'])`` from the pipeline.
    """
    classifier = pipeline(
        'zero-shot-classification',
        model=models[nli_model_id],
        tokenizer=load_tokenizer(nli_model_id),
        device=device,
    )
    # Pass the options by keyword, exactly as the CODE_DESC snippet shown to
    # users does, so the call does not depend on the positional order of the
    # pipeline's optional parameters.
    outputs = classifier(
        sequence,
        labels,
        hypothesis_template=hypothesis_template,
        multi_class=multi_class,
    )
    return outputs['labels'], outputs['scores']
|
66 |
-
|
67 |
-
def load_examples(model_id):
    """Load the bundled example texts for a model.

    Reads ``texts-<model name>.json`` (relative to the current working
    directory), where ``<model name>`` is the part of ``model_id`` after the
    last ``/``. The file must provide ``name``, ``text`` and ``labels`` columns.

    Args:
        model_id: model identifier such as ``'facebook/bart-large-mnli'``.

    Returns:
        tuple: ``(names, mapping)`` where ``names`` is the list of example
        names followed by ``'Custom'``, and ``mapping`` maps each name to a
        ``(text, labels)`` tuple; ``'Custom'`` maps to ``('', '')``.
    """
    model_id_stripped = model_id.split('/')[-1]
    df = pd.read_json(f'texts-{model_id_stripped}.json')

    names = df['name'].values.tolist()
    # Walk the three columns in lockstep instead of indexing row by row.
    mapping = {
        name: (text, example_labels)
        for name, text, example_labels in zip(df['name'], df['text'], df['labels'])
    }

    # Extra entry that lets the user type their own text and labels.
    names.append('Custom')
    mapping['Custom'] = ('', '')
    return names, mapping
|
75 |
-
|
76 |
-
def plot_result(top_topics, scores):
    """Draw a horizontal bar chart of topic scores in the Streamlit page.

    Args:
        top_topics: sequence of topic labels (bar names on the y axis).
        scores: sequence of scores, one per topic; multiplied by 100 before
            plotting and rendered with a trailing ``%``.
    """
    topic_names = np.array(top_topics)
    percentages = np.array(scores)
    percentages *= 100

    bar_options = {
        'x': percentages,
        'y': topic_names,
        'orientation': 'h',
        'labels': {'x': 'Confidence', 'y': 'Label'},
        'text': percentages,
        'range_x': (0, 115),
        'title': 'Top Predictions',
        'color': np.linspace(0, 1, len(percentages)),
        'color_continuous_scale': 'GnBu',
    }
    fig = px.bar(**bar_options)

    # Hide the colour scale legend and print each value next to its bar.
    fig.update(layout_coloraxis_showscale=False)
    fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    st.plotly_chart(fig)
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
model_id = model_ids[model_desc]
|
107 |
-
ex_names, ex_map = load_examples(model_id)
|
108 |
-
|
109 |
-
st.title('Zero Shot Topic Classification')
|
110 |
-
example = st.selectbox('Choose an example', ex_names)
|
111 |
-
height = min((len(ex_map[example][0].split()) + 1) * 2, 200)
|
112 |
-
sequence = st.text_area('Text', ex_map[example][0], key='sequence', height=height)
|
113 |
-
labels = st.text_input('Possible topics (separated by `,`)', ex_map[example][1], max_chars=1000)
|
114 |
-
multi_class = st.checkbox('Allow multiple correct topics', value=True)
|
115 |
-
|
116 |
-
hypothesis_template = "This text is about {}."
|
117 |
-
|
118 |
-
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
119 |
-
if len(labels) == 0 or len(sequence) == 0:
|
120 |
-
st.write('Enter some text and at least one possible topic to see predictions.')
|
121 |
-
return
|
122 |
-
|
123 |
-
if do_print_code:
|
124 |
-
st.markdown(CODE_DESC.format(model_id))
|
125 |
-
|
126 |
-
with st.spinner('Classifying...'):
|
127 |
-
top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class, do_print_code)
|
128 |
-
|
129 |
-
plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
|
130 |
-
|
131 |
-
if "socat" not in [p.name() for p in psutil.process_iter()]:
|
132 |
-
os.system('socat tcp-listen:8000,reuseaddr,fork tcp:localhost:8001 &')
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
if __name__ == '__main__':
|
140 |
-
main()
|
141 |
-
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import AutoTokenizer, AutoModel
|
3 |
+
from torch.nn import functional as F
|
|
|
|
|
|
|
4 |
import pandas as pd
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
MODEL_NAME = 'deepset/sentence_bert'

# Streamlit re-runs the script on each widget interaction, so load the
# tokenizer and model through a cached helper instead of reloading them every
# time. Prefer ``st.cache_resource`` when this Streamlit version has it and
# fall back to the older ``st.cache`` API otherwise.
_cache_resource = getattr(st, 'cache_resource', None) or st.cache(allow_output_mutation=True)


@_cache_resource
def _load_encoder(model_name):
    """Return the ``(tokenizer, model)`` pair loaded for ``model_name``."""
    return (AutoTokenizer.from_pretrained(model_name),
            AutoModel.from_pretrained(model_name))


tokenizer, model = _load_encoder(MODEL_NAME)
|
8 |
|
9 |
+
st.title('Semantic Similarity Checker')

# Free-text inputs: the sentence to compare and a comma-separated label list.
sentence = st.text_area('Enter')
user_labels = st.text_area('Enter label')

# ''.split(',') yields [''], so drop blank entries; otherwise an empty label
# box still produces a non-empty list and stray commas create empty labels.
labels = [label.strip() for label in user_labels.split(',') if label.strip()]
|
15 |
|
16 |
+
# run inputs through model and mean-pool over the sequence
|
17 |
+
# dimension to get sequence-level representations
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
+
if st.button("Calculate Similarities"):
    # Blank labels carry no meaning; ignore them so an empty label box is
    # reported as missing input instead of being scored.
    candidate_labels = [label for label in labels if label]

    if sentence and candidate_labels:
        # padding=True pads only up to the longest text in this batch and
        # truncation=True keeps over-long inputs within the model's limit.
        encoded = tokenizer([sentence] + candidate_labels,
                            return_tensors='pt',
                            padding=True,
                            truncation=True)
        attention_mask = encoded['attention_mask']
        hidden = model(encoded['input_ids'], attention_mask=attention_mask)[0]

        # Mean-pool over the sequence dimension using the attention mask as
        # weights, so padding positions do not dilute the representations.
        weights = attention_mask.unsqueeze(-1).type_as(hidden)
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        sentence_rep = pooled[:1]   # first row is the input sentence
        label_reps = pooled[1:]     # remaining rows are the labels

        similarities = F.cosine_similarity(sentence_rep, label_reps)
        similarities = similarities.cpu().detach().numpy()  # plain array for sorting/plotting

        # Order labels from most to least similar.
        sorted_indices = similarities.argsort()[::-1]
        sorted_labels = [candidate_labels[idx] for idx in sorted_indices]
        sorted_similarities = similarities[sorted_indices]

        # Display results in a bar chart keyed by label.
        df = pd.DataFrame({'Label': sorted_labels, 'Similarity': sorted_similarities})
        st.bar_chart(df.set_index('Label'))
    else:
        st.error("Please enter both a sentence and some labels.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|