Emanuel commited on
Commit
6160ca6
·
1 Parent(s): 0894938

First commit

Browse files
Files changed (4) hide show
  1. .gitignore +140 -0
  2. app.py +99 -0
  3. requirements.txt +66 -0
  4. setup.cfg +17 -0
.gitignore ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+
134
+ # pytype static type analyzer
135
+ .pytype/
136
+
137
+ # Cython debug symbols
138
+ cython_debug/
139
+
140
+ .vscode
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
4
+
5
+
6
+ class TwitterEmotionClassifier:
7
+ def __init__(self, model_name: str, model_type: str):
8
+ self.is_gpu = False
9
+ self.model_type = model_type
10
+ device = torch.device("cuda") if self.is_gpu else torch.device("cpu")
11
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ model.to(device)
14
+ model.eval()
15
+ self.bertweet = pipeline(
16
+ "text-classification",
17
+ model=model,
18
+ tokenizer=tokenizer,
19
+ device=self.is_gpu - 1,
20
+ )
21
+ self.deberta = None
22
+ self.emotions = {
23
+ "LABEL_0": "sadness",
24
+ "LABEL_1": "joy",
25
+ "LABEL_2": "love",
26
+ "LABEL_3": "anger",
27
+ "LABEL_4": "fear",
28
+ "LABEL_5": "surprise",
29
+ }
30
+
31
+ def get_model(self, model_type: str):
32
+ if self.model_type == "bertweet" and model_type == self.model_type:
33
+ return self.bertweet
34
+ elif model_type == "deberta":
35
+ if self.deberta:
36
+ return self.deberta
37
+ model = AutoModelForSequenceClassification.from_pretrained(
38
+ "Emanuel/twitter-emotion-deberta-v3-base"
39
+ )
40
+ tokenizer = AutoTokenizer.from_pretrained(
41
+ "Emanuel/twitter-emotion-deberta-v3-base"
42
+ )
43
+ self.deberta = pipeline(
44
+ "text-classification",
45
+ model=model,
46
+ tokenizer=tokenizer,
47
+ device=self.is_gpu - 1,
48
+ )
49
+ return self.deberta
50
+
51
+ def predict(self, twitter: str, model_type: str):
52
+ classifier = self.get_model(model_type)
53
+ preds = classifier(twitter, return_all_scores=True)
54
+ if preds:
55
+ pred = preds[0]
56
+ res = {
57
+ "Sadness 😢": pred[0]["score"],
58
+ "Joy 😂": pred[1]["score"],
59
+ "Love 💛": pred[2]["score"],
60
+ "Anger 😠": pred[3]["score"],
61
+ "Fear 😱": pred[4]["score"],
62
+ "Surprise 😮": pred[5]["score"],
63
+ }
64
+ return res
65
+ return None
66
+
67
+
68
+ def main():
69
+
70
+ model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
71
+ interFace = gr.Interface(
72
+ fn=model.predict,
73
+ inputs=[
74
+ gr.inputs.Textbox(
75
+ placeholder="What's happenning?", label="Tweet content", lines=5
76
+ ),
77
+ gr.inputs.Radio(["bertweet", "deberta"], label="Model"),
78
+ ],
79
+ outputs=gr.outputs.Label(num_top_classes=6, label="Emotions of this tweet is "),
80
+ verbose=True,
81
+ examples=[
82
+ ["This GOT show just remember LOTR times!", "bertweet"],
83
+ ["Man, that my 30 days of training just got a NaN loss!!!", "bertweet"],
84
+ ["I couldn't see 3 Tom Hollands coming...", "bertweet"],
85
+ [
86
+ "There is nothing better than a soul-warming coffee in the morning",
87
+ "bertweet",
88
+ ],
89
+ ["I fear the vanishing gradient a lot", "deberta"],
90
+ ],
91
+ title="Emotion classification with DeBERTa-v3 🤖",
92
+ description="",
93
+ theme="huggingface",
94
+ )
95
+ interFace.launch()
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()
requirements.txt ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ analytics-python==1.4.0
2
+ backoff==1.10.0
3
+ bcrypt==3.2.0
4
+ black==21.11b1
5
+ certifi==2021.10.8
6
+ cffi==1.15.0
7
+ charset-normalizer==2.0.7
8
+ click==8.0.3
9
+ cryptography==36.0.0
10
+ cycler==0.11.0
11
+ emoji==1.6.1
12
+ ffmpy==0.3.0
13
+ filelock==3.4.0
14
+ flake8==4.0.1
15
+ Flask==2.0.2
16
+ Flask-CacheBuster==1.0.0
17
+ Flask-Cors==3.0.10
18
+ Flask-Login==0.5.0
19
+ fonttools==4.28.1
20
+ gradio==2.4.5
21
+ huggingface-hub==0.1.2
22
+ idna==3.3
23
+ isort==5.10.1
24
+ itsdangerous==2.0.1
25
+ Jinja2==3.0.3
26
+ joblib==1.1.0
27
+ kiwisolver==1.3.2
28
+ markdown2==2.4.1
29
+ MarkupSafe==2.0.1
30
+ matplotlib==3.5.0
31
+ mccabe==0.6.1
32
+ monotonic==1.6
33
+ mypy-extensions==0.4.3
34
+ numpy==1.21.4
35
+ packaging==21.3
36
+ pandas==1.3.4
37
+ paramiko==2.8.0
38
+ pathspec==0.9.0
39
+ Pillow==8.4.0
40
+ platformdirs==2.4.0
41
+ pycodestyle==2.8.0
42
+ pycparser==2.21
43
+ pycryptodome==3.11.0
44
+ pydub==0.25.1
45
+ pyflakes==2.4.0
46
+ PyNaCl==1.4.0
47
+ pyparsing==3.0.6
48
+ python-dateutil==2.8.2
49
+ pytz==2021.3
50
+ PyYAML==6.0
51
+ regex==2021.11.10
52
+ requests==2.26.0
53
+ sacremoses==0.0.46
54
+ sentencepiece==0.1.96
55
+ setuptools-scm==6.3.2
56
+ six==1.16.0
57
+ tokenizers==0.10.3
58
+ tomli==1.2.2
59
+ torch==1.10.0+cpu
60
+ torchaudio==0.10.0+cpu
61
+ torchvision==0.11.1+cpu
62
+ tqdm==4.62.3
63
+ transformers==4.12.5
64
+ typing-extensions==4.0.0
65
+ urllib3==1.26.7
66
+ Werkzeug==2.0.2
setup.cfg ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [isort]
2
+ default_section = FIRSTPARTY
3
+ ensure_newline_before_comments = True
4
+ force_grid_wrap = 0
5
+ include_trailing_comma = True
6
+
7
+ line_length = 119
8
+ lines_after_imports = 2
9
+ multi_line_output = 3
10
+ use_parentheses = True
11
+
12
+ [flake8]
13
+ ignore = E203, E501, E741, W503, W605
14
+ max-line-length = 119
15
+
16
+ [tool:pytest]
17
+ doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS