Fralet committed
Commit be40e08 · verified · 1 Parent(s): ec15243

Update app.py

Files changed (1)
  1. app.py +33 -131
app.py CHANGED
@@ -1,141 +1,43 @@
  import streamlit as st
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
- import torch
- import numpy as np
- import contextlib
- import plotly.express as px
  import pandas as pd
- from PIL import Image
- import datetime
- import os
- import psutil

- with open("hit_log.txt", mode='a') as file:
-     file.write(str(datetime.datetime.now()) + '\n')

- MODEL_DESC = {
-     'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
-     'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
-     'XLM Roberta XNLI (cross-lingual)': """XLM Roberta, a cross-lingual model, with a classification head trained on XNLI. Supported languages include: _English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu_.

- Note that this model seems to be less reliable than the English-only models when classifying longer sequences.

- Examples were automatically translated and may contain grammatical mistakes.

- Sequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
- }

- ZSL_DESC = """Recently, the NLP science community has begun to pay increasing attention to zero-shot and few-shot applications, such as in the [paper from OpenAI](https://arxiv.org/abs/2005.14165) introducing GPT-3. This demo shows how 🤗 Transformers can be used for zero-shot topic classification, the task of predicting a topic that the model has not been trained on."""
-
- CODE_DESC = """```python
- from transformers import pipeline
- classifier = pipeline('zero-shot-classification',
-                       model='{}')
- hypothesis_template = 'This text is about {{}}.'  # the template used in this demo
-
- classifier(sequence, labels,
-            hypothesis_template=hypothesis_template,
-            multi_class=multi_class)
- # {{'sequence' ..., 'labels': ..., 'scores': ...}}
- ```"""
-
- model_ids = {
-     'Bart MNLI': 'facebook/bart-large-mnli',
-     'Bart MNLI + Yahoo Answers': 'joeddav/bart-large-mnli-yahoo-answers',
-     'XLM Roberta XNLI (cross-lingual)': 'joeddav/xlm-roberta-large-xnli'
- }
-
- device = 0 if torch.cuda.is_available() else -1
-
- @st.cache(allow_output_mutation=True)
- def load_models():
-     return {id: AutoModelForSequenceClassification.from_pretrained(id) for id in model_ids.values()}
-
- models = load_models()
-
-
- @st.cache(allow_output_mutation=True, show_spinner=False)
- def load_tokenizer(tok_id):
-     return AutoTokenizer.from_pretrained(tok_id)
-
- @st.cache(allow_output_mutation=True, show_spinner=False)
- def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class, do_print_code):
-     classifier = pipeline('zero-shot-classification', model=models[nli_model_id], tokenizer=load_tokenizer(nli_model_id), device=device)
-     outputs = classifier(sequence, labels, hypothesis_template, multi_class)
-     return outputs['labels'], outputs['scores']
-
- def load_examples(model_id):
-     model_id_stripped = model_id.split('/')[-1]
-     df = pd.read_json(f'texts-{model_id_stripped}.json')
-     names = df.name.values.tolist()
-     mapping = {df['name'].iloc[i]: (df['text'].iloc[i], df['labels'].iloc[i]) for i in range(len(names))}
-     names.append('Custom')
-     mapping['Custom'] = ('', '')
-     return names, mapping
-
- def plot_result(top_topics, scores):
-     top_topics = np.array(top_topics)
-     scores = np.array(scores)
-     scores *= 100
-     fig = px.bar(x=scores, y=top_topics, orientation='h',
-                  labels={'x': 'Confidence', 'y': 'Label'},
-                  text=scores,
-                  range_x=(0,115),
-                  title='Top Predictions',
-                  color=np.linspace(0,1,len(scores)),
-                  color_continuous_scale='GnBu')
-     fig.update(layout_coloraxis_showscale=False)
-     fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
-     st.plotly_chart(fig)


-
- def main():
-     with open("style.css") as f:
-         st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)
-
-     logo = Image.open('huggingface_logo.png')
-     st.sidebar.image(logo, width=120)
-     st.sidebar.markdown(ZSL_DESC)
-     model_desc = st.sidebar.selectbox('Model', list(MODEL_DESC.keys()), 0)
-     do_print_code = st.sidebar.checkbox('Show code snippet', False)
-     st.sidebar.markdown('#### Model Description')
-     st.sidebar.markdown(MODEL_DESC[model_desc])
-     st.sidebar.markdown('Originally proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161). Read more in our [blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html).')
-
-     model_id = model_ids[model_desc]
-     ex_names, ex_map = load_examples(model_id)
-
-     st.title('Zero Shot Topic Classification')
-     example = st.selectbox('Choose an example', ex_names)
-     height = min((len(ex_map[example][0].split()) + 1) * 2, 200)
-     sequence = st.text_area('Text', ex_map[example][0], key='sequence', height=height)
-     labels = st.text_input('Possible topics (separated by `,`)', ex_map[example][1], max_chars=1000)
-     multi_class = st.checkbox('Allow multiple correct topics', value=True)
-
-     hypothesis_template = "This text is about {}."
-
-     labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
-     if len(labels) == 0 or len(sequence) == 0:
-         st.write('Enter some text and at least one possible topic to see predictions.')
-         return
-
-     if do_print_code:
-         st.markdown(CODE_DESC.format(model_id))
-
-     with st.spinner('Classifying...'):
-         top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class, do_print_code)
-
-     plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
-
-     if "socat" not in [p.name() for p in psutil.process_iter()]:
-         os.system('socat tcp-listen:8000,reuseaddr,fork tcp:localhost:8001 &')
-
-
-
-
-
-
- if __name__ == '__main__':
-     main()
-
  import streamlit as st
+ from transformers import AutoTokenizer, AutoModel
+ from torch.nn import functional as F
  import pandas as pd

+ tokenizer = AutoTokenizer.from_pretrained('deepset/sentence_bert')
+ model = AutoModel.from_pretrained('deepset/sentence_bert')

+ st.title('Semantic Similarity Checker')


+ sentence = st.text_area('Enter')
+ user_labels = st.text_area('Enter label')
+ labels = [label.strip() for label in user_labels.split(',')]

+ # run inputs through model and mean-pool over the sequence
+ # dimension to get sequence-level representations


+ if st.button("Calculate Similarities"):
+     if sentence and labels:
+         inputs = tokenizer.batch_encode_plus([sentence] + labels,
+                                              return_tensors='pt',
+                                              pad_to_max_length=True)
+         input_ids = inputs['input_ids']
+         attention_mask = inputs['attention_mask']
+         output = model(input_ids, attention_mask=attention_mask)[0]
+         sentence_rep = output[:1].mean(dim=1)
+         label_reps = output[1:].mean(dim=1)

+         similarities = F.cosine_similarity(sentence_rep, label_reps)
+         similarities = similarities.cpu().detach().numpy()  # Convert to numpy array for easier handling
+
+         # Sorting indices for displaying in order
+         sorted_indices = similarities.argsort()[::-1]
+         sorted_labels = [labels[idx] for idx in sorted_indices]
+         sorted_similarities = similarities[sorted_indices]
+
+         # Display results in bar chart
+         df = pd.DataFrame({'Label': sorted_labels, 'Similarity': sorted_similarities})
+         st.bar_chart(df.set_index('Label'))
+     else:
+         st.error("Please enter both a sentence and some labels.")
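
Below is a minimal standalone sketch of the similarity computation the updated app.py performs, for trying the approach outside Streamlit. It is an illustration, not part of the commit: it reuses the same `deepset/sentence_bert` checkpoint, but swaps the deprecated `pad_to_max_length=True` for `padding=True` and masks padding tokens out of the mean-pool, which the app's plain `.mean(dim=1)` does not do. The example sentence and labels are placeholders.

```python
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModel

# Same checkpoint the updated app.py loads
tokenizer = AutoTokenizer.from_pretrained('deepset/sentence_bert')
model = AutoModel.from_pretrained('deepset/sentence_bert')

# Placeholder inputs standing in for the Streamlit text areas
sentence = "Who are you voting for in 2020?"
labels = ["politics", "public health", "economics"]

inputs = tokenizer([sentence] + labels, return_tensors='pt', padding=True)
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state           # (batch, seq_len, hidden)

# Mean-pool over real tokens only: padding positions are excluded from the
# average, unlike the plain .mean(dim=1) used in the app
mask = inputs['attention_mask'].unsqueeze(-1)             # (batch, seq_len, 1)
pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)     # (batch, hidden)

sentence_rep, label_reps = pooled[:1], pooled[1:]
similarities = F.cosine_similarity(sentence_rep, label_reps)

for label, score in sorted(zip(labels, similarities.tolist()), key=lambda x: -x[1]):
    print(f"{label}: {score:.3f}")
```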