saattrupdan
committed on
Commit
•
7420aa9
1
Parent(s):
a98d1c8
feat: Initial commit, add sentiment app
Browse files- .gitignore +1 -0
- app.py +58 -0
- requirements.txt +66 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.venv/
|
app.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Gradio app that showcases Scandinavian zero-shot text classification models."""
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from transformers import pipeline
|
5 |
+
from luga import language as detect_language
|
6 |
+
|
7 |
+
|
8 |
+
# Build the zero-shot classification pipeline once, at import time, so every
# request handled by the app reuses the same loaded model
classifier = pipeline(
    task="zero-shot-classification",
    model="alexandrainst/scandi-nli-large",
)
|
12 |
+
|
13 |
+
|
14 |
+
def sentiment_classification(doc: str) -> str:
    """Classify text into sentiment categories.

    The language of the text is detected first, and a hypothesis template and
    candidate labels written in that language are passed to the zero-shot
    classifier.

    Args:
        doc (str):
            Text to classify.

    Returns:
        str:
            The predicted sentiment category, taken as the first label in the
            classifier output and written in the detected language.

    Raises:
        ValueError:
            If the detected language is not Danish, Swedish or Norwegian.
    """
    # Hypothesis template and candidate labels for each supported language code
    language_settings = {
        "da": ("Dette eksempel er {}.", ["positivt", "negativt", "neutralt"]),
        "sv": ("Detta exempel är {}.", ["positivt", "negativt", "neutralt"]),
        "no": ("Dette eksemplet er {}.", ["positivt", "negativt", "nøytralt"]),
    }

    # Detect the language of the text
    language = detect_language(doc).name

    # Fail with a clear message instead of hitting unbound local variables when
    # the detected language has no template
    try:
        hypothesis_template, candidate_labels = language_settings[language]
    except KeyError:
        supported = ", ".join(language_settings)
        raise ValueError(
            f"Unsupported language {language!r}; expected one of: {supported}."
        ) from None

    # Run the classifier on the text
    result = classifier(
        doc, candidate_labels=candidate_labels, hypothesis_template=hypothesis_template
    )

    # Return the predicted label
    return result["labels"][0]
|
46 |
+
|
47 |
+
|
48 |
+
# Wire the sentiment classifier into a Gradio interface: a multi-line text box
# as input and a text label as output
interface = gr.Interface(
    title="Scandinavian Zero-Shot Text Classification",
    description="Classify text into sentiment categories.",
    fn=sentiment_classification,
    inputs=gr.inputs.Textbox(lines=5, label="Text"),
    outputs=gr.outputs.Label(type="text"),
)

# Start serving the app
interface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohttp==3.8.3
|
2 |
+
aiosignal==1.3.1
|
3 |
+
anyio==3.6.2
|
4 |
+
async-timeout==4.0.2
|
5 |
+
attrs==22.1.0
|
6 |
+
bcrypt==4.0.1
|
7 |
+
certifi==2022.9.24
|
8 |
+
cffi==1.15.1
|
9 |
+
charset-normalizer==2.1.1
|
10 |
+
click==8.1.3
|
11 |
+
contourpy==1.0.6
|
12 |
+
cryptography==38.0.4
|
13 |
+
cycler==0.11.0
|
14 |
+
fastapi==0.88.0
|
15 |
+
ffmpy==0.3.0
|
16 |
+
filelock==3.8.0
|
17 |
+
fonttools==4.38.0
|
18 |
+
frozenlist==1.3.3
|
19 |
+
fsspec==2022.11.0
|
20 |
+
gradio==3.12.0
|
21 |
+
h11==0.12.0
|
22 |
+
httpcore==0.15.0
|
23 |
+
httpx==0.23.1
|
24 |
+
huggingface-hub==0.11.1
|
25 |
+
idna==3.4
|
26 |
+
Jinja2==3.1.2
|
27 |
+
kiwisolver==1.4.4
|
28 |
+
linkify-it-py==1.0.3
|
29 |
+
markdown-it-py==2.1.0
|
30 |
+
MarkupSafe==2.1.1
|
31 |
+
matplotlib==3.6.2
|
32 |
+
mdit-py-plugins==0.3.1
|
33 |
+
mdurl==0.1.2
|
34 |
+
multidict==6.0.2
|
35 |
+
numpy==1.23.5
|
36 |
+
orjson==3.8.2
|
37 |
+
packaging==21.3
|
38 |
+
pandas==1.5.2
|
39 |
+
paramiko==2.12.0
|
40 |
+
Pillow==9.3.0
|
41 |
+
pycparser==2.21
|
42 |
+
pycryptodome==3.16.0
|
43 |
+
pydantic==1.10.2
|
44 |
+
pydub==0.25.1
|
45 |
+
PyNaCl==1.5.0
|
46 |
+
pyparsing==3.0.9
|
47 |
+
python-dateutil==2.8.2
|
48 |
+
python-multipart==0.0.5
|
49 |
+
pytz==2022.6
|
50 |
+
PyYAML==6.0
|
51 |
+
regex==2022.10.31
|
52 |
+
requests==2.28.1
|
53 |
+
rfc3986==1.5.0
|
54 |
+
six==1.16.0
|
55 |
+
sniffio==1.3.0
|
56 |
+
starlette==0.22.0
|
57 |
+
tokenizers==0.13.2
|
58 |
+
torch==1.12.1
|
59 |
+
tqdm==4.64.1
|
60 |
+
transformers==4.24.0
|
61 |
+
typing_extensions==4.4.0
|
62 |
+
uc-micro-py==1.0.1
|
63 |
+
urllib3==1.26.13
|
64 |
+
uvicorn==0.20.0
|
65 |
+
websockets==10.4
|
66 |
+
yarl==1.8.1
|