Spaces: Running on Zero
DmitryRyumin committed
Commit · f16bb9f · 1 Parent(s): c6923ed
Summary
- .flake8 +5 -0
- .gitignore +177 -0
- CODE_OF_CONDUCT.md +80 -0
- LICENSE +21 -0
- README.md +8 -5
- app.css +40 -0
- app.py +48 -0
- app/__init__.py +0 -0
- app/components.py +18 -0
- app/config.py +53 -0
- app/data_init.py +56 -0
- app/description.py +17 -0
- app/event_handlers/__init__.py +0 -0
- app/event_handlers/clear.py +31 -0
- app/event_handlers/event_handlers.py +61 -0
- app/event_handlers/submit.py +172 -0
- app/event_handlers/video.py +26 -0
- app/gpu_init.py +10 -0
- app/load_models.py +909 -0
- app/plots.py +115 -0
- app/requirements_app.py +37 -0
- app/tabs.py +154 -0
- app/utils.py +287 -0
- config.toml +66 -0
- images/clear.ico +0 -0
- images/submit.ico +0 -0
- requirements.txt +13 -0
.flake8
ADDED
@@ -0,0 +1,5 @@
; https://www.flake8rules.com/

[flake8]
max-line-length = 120
ignore = E203, E402, E741, W503
.gitignore
ADDED
@@ -0,0 +1,177 @@
# Compiled source #
###################
*.com
*.class
*.dll
*.exe
*.o
*.so
*.pyc

# Packages #
############
# it's better to unpack these files and commit the raw source
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.rar
#*.tar
*.zip

# Logs and databases #
######################
*.log
*.sqlite

# OS generated files #
######################
.DS_Store
ehthumbs.db
Icon
Thumbs.db
.tmtags
.idea
.vscode
tags
vendor.tags
tmtagsHistory
*.sublime-project
*.sublime-workspace
.bundle

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
node_modules/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Custom
.gradio/
data/
models/
fonts/
notebooks/
weights/
project_structure.txt
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <[email protected]>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at <https://www.contributor-covenant.org/version/1/4/code-of-conduct.html>

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
<https://www.contributor-covenant.org/faq>
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 HSE

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,14 +1,17 @@
 ---
 title: MASAI
-emoji:
-colorFrom:
-colorTo:
+emoji: π
+colorFrom: gray
+colorTo: red
 sdk: gradio
+python_version: 3.12
 sdk_version: 5.4.0
 app_file: app.py
-
+app_port: 7860
+header: default
+pinned: true
 license: mit
 short_description: Intelligent system for Multimodal Affective States Analysis
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
app.css
ADDED
@@ -0,0 +1,40 @@
.noti_err {
    color: var(--color-accent);
}
.noti_true {
    color: #006900;
}

div.app-flex-container {
    display: flex;
    align-items: left;
    gap: 6px;
}

button.submit {
    display: flex;
    border: var(--button-border-width) solid var(--button-primary-border-color);
    background: var(--button-primary-background-fill);
    color: var(--button-primary-text-color);
    border-radius: 8px;
    transition: all 0.3s ease;
}

button.submit[disabled],
button.clear[disabled] {
    cursor: not-allowed;
    opacity: 0.6;
}

button.submit:hover:not([disabled]) {
    border-color: var(--button-primary-border-color-hover);
    background: var(--button-primary-background-fill-hover);
    color: var(--button-primary-text-color-hover);
}

div.audio:hover label[data-testid="block-label"],
div.imgs:hover label[data-testid="block-label"],
div.emo-stats:hover label[data-testid="block-label"],
div.sent-stats:hover label[data-testid="block-label"] {
    display: none;
}
app.py
ADDED
@@ -0,0 +1,48 @@
"""
File: app.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Main application file.
The file defines the Gradio interface, sets up the main blocks and tabs,
and includes event handlers for various components.
License: MIT License
"""

import gradio as gr

# Importing necessary components for the Gradio app
from app.config import CONFIG_NAME, config_data, load_tab_creators
from app.event_handlers.event_handlers import setup_app_event_handlers
import app.tabs


gr.set_static_paths(paths=[config_data.Path_APP / config_data.StaticPaths_IMAGES])


def create_gradio_app() -> gr.Blocks:
    with gr.Blocks(
        theme=gr.themes.Default(), css_paths=config_data.AppSettings_CSS_PATH
    ) as gradio_app:
        tab_results = {}

        available_functions = {
            attr: getattr(app.tabs, attr)
            for attr in dir(app.tabs)
            if callable(getattr(app.tabs, attr)) and attr.endswith("_tab")
        }

        tab_creators = load_tab_creators(CONFIG_NAME, available_functions)

        for tab_name, create_tab_function in tab_creators.items():
            with gr.Tab(tab_name):
                app_instance = create_tab_function()
                tab_results[tab_name] = app_instance

        keys = list(tab_results.keys())

        setup_app_event_handlers(*(tab_results[keys[0]]))

    return gradio_app


if __name__ == "__main__":
    create_gradio_app().queue(api_open=False).launch(share=False)
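A minimal sketch of the tab-discovery step in app.py: any callable in app.tabs whose name ends in "_tab" is collected into available_functions. The real builders live in app/tabs.py (added in this commit but not shown in this hunk); demo_tab below is a hypothetical stand-in used only to illustrate the mechanism.

import types

# Stand-in for the app.tabs module added elsewhere in this commit.
tabs_stub = types.ModuleType("tabs_stub")

def demo_tab():
    # A real *_tab builder would create and return Gradio components.
    return ("video", "clear", "submit")

tabs_stub.demo_tab = demo_tab

available_functions = {
    attr: getattr(tabs_stub, attr)
    for attr in dir(tabs_stub)
    if callable(getattr(tabs_stub, attr)) and attr.endswith("_tab")
}

print(available_functions)  # {'demo_tab': <function demo_tab at ...>}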
app/__init__.py
ADDED
File without changes
app/components.py
ADDED
@@ -0,0 +1,18 @@
"""
File: components.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Utility functions for creating Gradio components.
License: MIT License
"""

import gradio as gr

# Importing necessary components for the Gradio app


def html_message(
    message: str = "", error: bool = True, visible: bool = True
) -> gr.HTML:
    css_class = "noti_err" if error else "noti_true"

    return gr.HTML(value=f"<h3 class='{css_class}'>{message}</h3>", visible=visible)
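A usage sketch for html_message (assumes gradio is installed; the noti_err and noti_true classes it emits are styled in app.css above, and the message strings are illustrative):

from app.components import html_message

ok_note = html_message(message="Results are ready", error=False)   # rendered with .noti_true
err_note = html_message(message="Video not uploaded", error=True)  # rendered with .noti_err
hidden = html_message(visible=False)                               # created but not displayed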
app/config.py
ADDED
@@ -0,0 +1,53 @@
"""
File: config.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Configuration module for handling settings.
License: MIT License
"""

import tomllib
from pathlib import Path
from collections.abc import Callable
from types import SimpleNamespace

CONFIG_NAME = "config.toml"


def flatten_dict(prefix: str, d: dict) -> dict:
    result = {}

    for k, v in d.items():
        result.update(
            flatten_dict(f"{prefix}{k}_", v)
            if isinstance(v, dict)
            else {f"{prefix}{k}": v}
        )

    return result


def load_tab_creators(
    file_path: str, available_functions: dict[str, Callable]
) -> dict[str, Callable]:
    with open(file_path, "rb") as f:
        config = tomllib.load(f)

    tab_creators_data = config.get("TabCreators", {})

    return {key: available_functions[value] for key, value in tab_creators_data.items()}


def load_config(file_path: str) -> SimpleNamespace:
    with open(file_path, "rb") as f:
        config = tomllib.load(f)

    config_data = flatten_dict("", config)

    config_namespace = SimpleNamespace(**config_data)

    setattr(config_namespace, "Path_APP", Path(__file__).parent.parent.resolve())

    return config_namespace


config_data = load_config(CONFIG_NAME)
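A short sketch of what flatten_dict does to a nested TOML structure: section names become attribute prefixes, which is where flat names such as config_data.General_SR used elsewhere in this commit come from. The sample keys and values below are illustrative; the real ones live in config.toml.

from types import SimpleNamespace

from app.config import flatten_dict

nested = {
    "General": {"SR": 16000, "WIN_MAX_LENGTH": 4},
    "StaticPaths": {"WEIGHTS": "weights/"},
}

flat = flatten_dict("", nested)
# {'General_SR': 16000, 'General_WIN_MAX_LENGTH': 4, 'StaticPaths_WEIGHTS': 'weights/'}

cfg = SimpleNamespace(**flat)
print(cfg.General_SR, cfg.StaticPaths_WEIGHTS)  # 16000 weights/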
app/data_init.py
ADDED
@@ -0,0 +1,56 @@
"""
File: data_init.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Initial data loading.
License: MIT License
"""

import torch

# Importing necessary components for the Gradio app
from app.config import config_data
from app.gpu_init import device
from app.load_models import (
    AudioFeatureExtractor,
    VideoModelLoader,
    TextFeatureExtractor,
)
from app.utils import ASRModel

vad_model, vad_utils = torch.hub.load(
    repo_or_dir=config_data.StaticPaths_VAD_MODEL,
    model="silero_vad",
    force_reload=False,
    onnx=False,
)

get_speech_timestamps, _, read_audio, _, _ = vad_utils

audio_model = AudioFeatureExtractor(
    checkpoint_url=config_data.StaticPaths_HF_MODELS
    + config_data.StaticPaths_EMO_SENT_AUDIO_WEIGHTS,
    folder_path=config_data.StaticPaths_WEIGHTS,
    device=device,
    with_features=False,
)

video_model = VideoModelLoader(
    face_checkpoint_url=config_data.StaticPaths_HF_MODELS
    + config_data.StaticPaths_YOLOV8N_FACE,
    emotion_checkpoint_url=config_data.StaticPaths_HF_MODELS
    + config_data.StaticPaths_EMO_AFFECTNET_WEIGHTS,
    emo_sent_checkpoint_url=config_data.StaticPaths_HF_MODELS
    + config_data.StaticPaths_EMO_SENT_VIDEO_WEIGHTS,
    folder_path=config_data.StaticPaths_WEIGHTS,
    device=device,
)

text_model = TextFeatureExtractor(
    checkpoint_url=config_data.StaticPaths_HF_MODELS
    + config_data.StaticPaths_EMO_SENT_TEXT_WEIGHTS,
    folder_path=config_data.StaticPaths_WEIGHTS,
    device=device,
    with_features=False,
)

asr = ASRModel(checkpoint_path=config_data.StaticPaths_OPENAI_WHISPER, device=device)
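A rough usage sketch for the Silero VAD helpers unpacked above, continuing from the definitions in data_init.py. "sample.wav" is a placeholder path, and the keyword names follow the published silero-vad utilities; the returned segments are dicts of sample offsets, which is the shape readetect_speech relies on.

wav = read_audio("sample.wav", sampling_rate=config_data.General_SR)
speech_segments = get_speech_timestamps(wav, vad_model, sampling_rate=config_data.General_SR)
print(speech_segments[:2])  # e.g. [{'start': 4640, 'end': 32800}, ...]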
app/description.py
ADDED
@@ -0,0 +1,17 @@
"""
File: description.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Project description for the Gradio app.
License: MIT License
"""

# Importing necessary components for the Gradio app
from app.config import config_data

DESCRIPTION = f"""\
# Intelligent system for Multimodal Affective States Analysis (MASAI)

<div class="app-flex-container">
<img src="https://img.shields.io/badge/version-v{config_data.AppSettings_APP_VERSION}-stable" alt="Version">
</div>
"""
app/event_handlers/__init__.py
ADDED
File without changes
app/event_handlers/clear.py
ADDED
@@ -0,0 +1,31 @@
"""
File: clear.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Event handler for Gradio app to clear.
License: MIT License
"""

import gradio as gr

# Importing necessary components for the Gradio app
from app.config import config_data
from app.components import html_message


def event_handler_clear() -> (
    tuple[gr.Video, gr.Button, gr.Button, gr.HTML, gr.Plot, gr.Plot, gr.Plot, gr.Plot]
):
    return (
        gr.Video(value=None),
        gr.Button(interactive=False),
        gr.Button(interactive=False),
        html_message(
            message=config_data.InformationMessages_NOTI_RESULTS[0],
            error=True,
            visible=True,
        ),
        gr.Plot(value=None, visible=False),
        gr.Plot(value=None, visible=False),
        gr.Plot(value=None, visible=False),
        gr.Plot(value=None, visible=False),
    )
app/event_handlers/event_handlers.py
ADDED
@@ -0,0 +1,61 @@
"""
File: event_handlers.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: File containing functions for configuring event handlers for Gradio components.
License: MIT License
"""

import gradio as gr

# Importing necessary components for the Gradio app
from app.event_handlers.video import event_handler_video
from app.event_handlers.submit import event_handler_submit
from app.event_handlers.clear import event_handler_clear


def setup_app_event_handlers(
    video,
    clear,
    submit,
    noti_results,
    waveform,
    faces,
    emotion_stats,
    sent_stats,
):
    gr.on(
        triggers=[video.change, video.upload, video.stop_recording, video.clear],
        fn=event_handler_video,
        inputs=[video],
        outputs=[clear, submit, noti_results],
        queue=True,
    )

    submit.click(
        fn=event_handler_submit,
        inputs=[video],
        outputs=[
            noti_results,
            waveform,
            faces,
            emotion_stats,
            sent_stats,
        ],
        queue=True,
    )

    clear.click(
        fn=event_handler_clear,
        inputs=[],
        outputs=[
            video,
            clear,
            submit,
            noti_results,
            waveform,
            faces,
            emotion_stats,
            sent_stats,
        ],
        queue=True,
    )
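A small consistency check implied by the wiring above: each handler must return exactly as many values as the components listed in its outputs. A sketch, assuming the package imports cleanly with config.toml present:

from app.event_handlers.clear import event_handler_clear

# clear.click lists 8 outputs, and event_handler_clear returns 8 components.
assert len(event_handler_clear()) == 8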
app/event_handlers/submit.py
ADDED
@@ -0,0 +1,172 @@
"""
File: submit.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Event handler for Gradio app to submit.
License: MIT License
"""

import torch
import pandas as pd
import cv2
import gradio as gr

# Importing necessary components for the Gradio app
from app.config import config_data
from app.utils import (
    convert_video_to_audio,
    readetect_speech,
    slice_audio,
    find_intersections,
    calculate_mode,
    find_nearest_frames,
)
from app.plots import (
    get_evenly_spaced_frame_indices,
    plot_audio,
    display_frame_info,
    plot_images,
    plot_predictions,
)
from app.data_init import (
    read_audio,
    get_speech_timestamps,
    vad_model,
    video_model,
    asr,
    audio_model,
    text_model,
)
from app.load_models import VideoFeatureExtractor
from app.components import html_message


def event_handler_submit(
    video: str,
) -> tuple[gr.HTML, gr.Plot, gr.Plot, gr.Plot, gr.Plot]:
    audio_file_path = convert_video_to_audio(file_path=video, sr=config_data.General_SR)
    wav, vad_info = readetect_speech(
        file_path=audio_file_path,
        read_audio=read_audio,
        get_speech_timestamps=get_speech_timestamps,
        vad_model=vad_model,
        sr=config_data.General_SR,
    )

    audio_windows = slice_audio(
        start_time=config_data.General_START_TIME,
        end_time=int(len(wav)),
        win_max_length=int(config_data.General_WIN_MAX_LENGTH * config_data.General_SR),
        win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
        win_min_length=int(config_data.General_WIN_MIN_LENGTH * config_data.General_SR),
    )

    intersections = find_intersections(
        x=audio_windows,
        y=vad_info,
        min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
    )

    vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
    vfe.preprocess_video()

    transcriptions, total_text = asr(wav, audio_windows)

    window_frames = []
    preds_emo = []
    preds_sen = []
    for w_idx, window in enumerate(audio_windows):
        a_w = intersections[w_idx]
        if not a_w["speech"]:
            a_pred = None
        else:
            wave = wav[a_w["start"] : a_w["end"]].clone()
            a_pred, _ = audio_model(wave)

        v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)

        t_pred, _ = text_model(transcriptions[w_idx][0])

        if a_pred:
            pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
            pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
        else:
            pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
            pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2

        frames = list(
            range(
                int(window["start"] * vfe.fps / config_data.General_SR) + 1,
                int(window["end"] * vfe.fps / config_data.General_SR) + 2,
            )
        )
        preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
        preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
        window_frames.extend(frames)

    if max(window_frames) < vfe.frame_number:
        missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
        window_frames.extend(missed_frames)
        preds_emo.extend([preds_emo[-1]] * len(missed_frames))
        preds_sen.extend([preds_sen[-1]] * len(missed_frames))

    df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
    df_pred["frames"] = window_frames
    df_pred["pred_emo"] = preds_emo
    df_pred["pred_sent"] = preds_sen

    df_pred = df_pred.groupby("frames").agg(
        {
            "pred_emo": calculate_mode,
            "pred_sent": calculate_mode,
        }
    )

    frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
    num_frames = len(wav)
    time_axis = [i / config_data.General_SR for i in range(num_frames)]
    plt_audio = plot_audio(time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2))

    all_idx_faces = list(vfe.faces[1].keys())
    need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
    faces = []
    for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
        cur_face = cv2.resize(
            vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
        )
        faces.append(
            display_frame_info(
                cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
            )
        )
    plt_faces = plot_images(faces)

    plt_emo = plot_predictions(
        df_pred,
        "pred_emo",
        "Emotion",
        list(config_data.General_DICT_EMO),
        (12, 2.5),
        [i + 1 for i in frame_indices],
        2,
    )
    plt_sent = plot_predictions(
        df_pred,
        "pred_sent",
        "Sentiment",
        list(config_data.General_DICT_SENT),
        (12, 1.5),
        [i + 1 for i in frame_indices],
        2,
    )

    return (
        html_message(
            message=config_data.InformationMessages_NOTI_RESULTS[1],
            error=False,
            visible=False,
        ),
        gr.Plot(value=plt_audio, visible=True),
        gr.Plot(value=plt_faces, visible=True),
        gr.Plot(value=plt_emo, visible=True),
        gr.Plot(value=plt_sent, visible=True),
    )
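A worked example of the window-to-frame mapping used in the loop above, with assumed values of 25 fps video and 16 kHz audio (the real values come from vfe.fps and config.toml):

SR, FPS = 16000, 25
window = {"start": 0, "end": 4 * SR}  # a 4-second audio window, in samples

frames = list(
    range(
        int(window["start"] * FPS / SR) + 1,
        int(window["end"] * FPS / SR) + 2,
    )
)

print(frames[0], frames[-1], len(frames))  # 1 101 101 -> 1-indexed, end-inclusive frame span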
app/event_handlers/video.py
ADDED
@@ -0,0 +1,26 @@
"""
File: video.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Event handler for Gradio app to video.
License: MIT License
"""

import gradio as gr

# Importing necessary components for the Gradio app
from app.config import config_data
from app.components import html_message


def event_handler_video(video: str) -> gr.HTML:
    is_video_valid = bool(video)

    return (
        gr.Button(interactive=is_video_valid),
        gr.Button(interactive=is_video_valid),
        html_message(
            message=config_data.InformationMessages_NOTI_RESULTS[int(is_video_valid)],
            error=not is_video_valid,
            visible=True,
        ),
    )
app/gpu_init.py
ADDED
@@ -0,0 +1,10 @@
"""
File: gpu_init.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: GPU initialization.
License: MIT License
"""

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
app/load_models.py
ADDED
@@ -0,0 +1,909 @@
"""
File: load_models.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Load pretrained models.
License: MIT License
"""

import math
import numpy as np
import cv2

import torch.nn.functional as F
import torch.nn as nn
import torch
from typing import Optional
from PIL import Image
from ultralytics import YOLO
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

from transformers import AutoConfig, Wav2Vec2Processor, AutoTokenizer, AutoModel

from app.utils import pth_processing, get_idx_frames_in_windows

# Importing necessary components for the Gradio app
from app.utils import load_model


class ScaledDotProductAttention_MultiHead(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention_MultiHead, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            raise ValueError("Mask is not supported yet")

        # key, query, value shapes: [batch_size, num_heads, seq_len, dim]
        emb_dim = key.shape[-1]

        # Calculate attention weights
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            emb_dim
        )

        # masking
        if mask is not None:
            raise ValueError("Mask is not supported yet")

        # Softmax
        attention_weights = self.softmax(attention_weights)

        # modify value
        value = torch.matmul(attention_weights, value)
        return value, attention_weights


class PositionWiseFeedForward(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout: float = 0.1):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # feed-forward network
        x = self.layer_1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x


class Add_and_Norm(nn.Module):
    def __init__(self, input_dim, dropout: Optional[float] = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, x1, residual):
        x = x1
        # apply dropout of needed
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        # add and then norm
        x = x + residual
        x = self.layer_norm(x)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads, dropout: Optional[float] = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        if input_dim % num_heads != 0:
            raise ValueError("input_dim must be divisible by num_heads")
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        # initialize weights
        self.query_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.keys_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.values_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False)
        self.ff_layer_after_concat = nn.Linear(
            self.num_heads * self.head_dim, input_dim, bias=False
        )

        self.attention = ScaledDotProductAttention_MultiHead()

        if self.dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, mask=None):
        # query, keys, values shapes: [batch_size, seq_len, input_dim]
        batch_size, len_query, len_keys, len_values = (
            queries.size(0),
            queries.size(1),
            keys.size(1),
            values.size(1),
        )

        # linear transformation before attention
        queries = (
            self.query_w(queries)
            .view(batch_size, len_query, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # [batch_size, num_heads, seq_len, dim]
        keys = (
            self.keys_w(keys)
            .view(batch_size, len_keys, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # [batch_size, num_heads, seq_len, dim]
        values = (
            self.values_w(values)
            .view(batch_size, len_values, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # [batch_size, num_heads, seq_len, dim]

        # attention itself
        values, attention_weights = self.attention(
            queries, keys, values, mask=mask
        )  # values shape:[batch_size, num_heads, seq_len, dim]

        # concatenation
        out = (
            values.transpose(1, 2)
            .contiguous()
            .view(batch_size, len_values, self.num_heads * self.head_dim)
        )  # [batch_size, seq_len, num_heads * dim = input_dim]
        # go through last linear layer
        out = self.ff_layer_after_concat(out)
        return out


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        pe = pe.permute(
            1, 0, 2
        )  # [seq_len, batch_size, embedding_dim] -> [batch_size, seq_len, embedding_dim]
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerLayer(nn.Module):
    def __init__(
        self,
        input_dim,
        num_heads,
        dropout: Optional[float] = 0.1,
        positional_encoding: bool = True,
    ):
        super(TransformerLayer, self).__init__()
        self.positional_encoding = positional_encoding
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.dropout = dropout

        # initialize layers
        self.self_attention = MultiHeadAttention(input_dim, num_heads, dropout=dropout)
        self.feed_forward = PositionWiseFeedForward(
            input_dim, input_dim, dropout=dropout
        )
        self.add_norm_after_attention = Add_and_Norm(input_dim, dropout=dropout)
        self.add_norm_after_ff = Add_and_Norm(input_dim, dropout=dropout)

        # calculate positional encoding
        if self.positional_encoding:
            self.positional_encoding = PositionalEncoding(input_dim)

    def forward(self, key, value, query, mask=None):
        # key, value, and query shapes: [batch_size, seq_len, input_dim]
        # positional encoding
        if self.positional_encoding:
            key = self.positional_encoding(key)
            value = self.positional_encoding(value)
            query = self.positional_encoding(query)

        # multi-head attention
        residual = query
        x = self.self_attention(queries=query, keys=key, values=value, mask=mask)
        x = self.add_norm_after_attention(x, residual)

        # feed forward
        residual = x
        x = self.feed_forward(x)
        x = self.add_norm_after_ff(x, residual)

        return x


class SelfTransformer(nn.Module):
    def __init__(self, input_size: int = int(1024), num_heads=1, dropout=0.1):
        super(SelfTransformer, self).__init__()
        self.att = torch.nn.MultiheadAttention(
            input_size, num_heads, dropout, bias=True, batch_first=True
        )
        self.norm1 = nn.LayerNorm(input_size)
        self.fcl = nn.Linear(input_size, input_size)
        self.norm2 = nn.LayerNorm(input_size)

    def forward(self, video):
        represent, _ = self.att(video, video, video)
        represent_norm = self.norm1(video + represent)
        represent_fcl = self.fcl(represent_norm)
        represent = self.norm1(represent_norm + represent_fcl)
        return represent


class SmallClassificationHead(nn.Module):
    """ClassificationHead"""

    def __init__(self, input_size=256, out_emo=6, out_sen=3):
        super(SmallClassificationHead, self).__init__()
        self.fc_emo = nn.Linear(input_size, out_emo)
        self.fc_sen = nn.Linear(input_size, out_sen)

    def forward(self, x):
        x_emo = self.fc_emo(x)
        x_sen = self.fc_sen(x)
        return {"emo": x_emo, "sen": x_sen}


class AudioModelWT(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)

        self.f_size = 1024

        self.tl1 = TransformerLayer(
            input_dim=self.f_size, num_heads=4, dropout=0.1, positional_encoding=True
        )
        self.tl2 = TransformerLayer(
            input_dim=self.f_size, num_heads=4, dropout=0.1, positional_encoding=True
        )

        self.fc1 = nn.Linear(1024, 1)
        self.dp = nn.Dropout(p=0.5)

        self.selu = nn.SELU()
        self.relu = nn.ReLU()
        self.cl_head = SmallClassificationHead(
            input_size=199, out_emo=config.out_emo, out_sen=config.out_sen
        )

        self.init_weights()

        # freeze conv
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        for param in self.wav2vec2.feature_extractor.conv_layers.parameters():
            param.requires_grad = False

    def forward(self, x, with_features=False):
        outputs = self.wav2vec2(x)

        x = self.tl1(outputs[0], outputs[0], outputs[0])
        x = self.selu(x)

        features = self.tl2(x, x, x)
        x = self.selu(features)

        x = self.fc1(x)
        x = self.relu(x)
        x = self.dp(x)

        x = x.view(x.size(0), -1)

        if with_features:
            return self.cl_head(x), features
        else:
            return self.cl_head(x)


class AudioFeatureExtractor:
    def __init__(
        self,
        checkpoint_url: str,
        folder_path: str,
        device: torch.device,
        sr: int = 16000,
        win_max_length: int = 4,
        with_features: bool = True,
    ) -> None:
        """
        Args:
            sr (int, optional): Sample rate of audio. Defaults to 16000.
            win_max_length (int, optional): Max length of window. Defaults to 4.
            with_features (bool, optional): Extract features or not
        """
        self.device = device
        self.sr = sr
        self.win_max_length = win_max_length
        self.with_features = with_features

        checkpoint_path = load_model(checkpoint_url, folder_path)

        model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
        model_config = AutoConfig.from_pretrained(model_name)

        model_config.out_emo = 7
        model_config.out_sen = 3
        model_config.context_length = 199

        self.processor = Wav2Vec2Processor.from_pretrained(model_name)

        self.model = AudioModelWT.from_pretrained(
            pretrained_model_name_or_path=model_name, config=model_config
        )

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(self.device)

    def preprocess_wave(self, x: torch.Tensor) -> torch.Tensor:
        """Extracts features for wav2vec
        Apply padding to max length of audio

        Args:
            x (torch.Tensor): Input data

        Returns:
            np.ndarray: Preprocessed data
        """
        a_data = self.processor(
            x,
            sampling_rate=self.sr,
            return_tensors="pt",
            padding="max_length",
            max_length=self.sr * self.win_max_length,
        )
        return a_data["input_values"][0]

    def __call__(
        self, waveform: torch.Tensor
    ) -> tuple[dict[torch.Tensor], torch.Tensor]:
        """Extracts acoustic features
        Apply padding to max length of audio

        Args:
            wave (torch.Tensor): wave

        Returns:
            torch.Tensor: Extracted features
        """
        waveform = self.preprocess_wave(waveform).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if self.with_features:
                preds, features = self.model(waveform, with_features=self.with_features)
            else:
                preds = self.model(waveform, with_features=self.with_features)

        predicts = {
            "emo": F.softmax(preds["emo"], dim=-1).detach().cpu().squeeze(),
            "sen": F.softmax(preds["sen"], dim=-1).detach().cpu().squeeze(),
        }

        return (
            (predicts, features.detach().cpu().squeeze())
            if self.with_features
            else (predicts, None)
        )


class Tmodel(nn.Module):
    def __init__(
        self,
        input_size: int = int(1024),
        activation=nn.SELU(),
        feature_size1=256,
        feature_size2=64,
        num_heads=1,
        num_layers=2,
        n_emo=7,
        n_sent=3,
    ):
        super(Tmodel, self).__init__()
        self.feature_text_dynamic = nn.ModuleList(
            [
                SelfTransformer(input_size=input_size, num_heads=num_heads)
                for i in range(num_layers)
            ]
        )
        self.fcl = nn.Linear(input_size, feature_size1)
        self.activation = activation
        self.feature_emo = nn.Linear(feature_size1, feature_size2)
        self.feature_sent = nn.Linear(feature_size1, feature_size2)
        self.fc_emo = nn.Linear(feature_size2, n_emo)
        self.fc_sent = nn.Linear(feature_size2, n_sent)

    def get_features(self, t):
        for i, l in enumerate(self.feature_text_dynamic):
            self.features = l(t)

    def forward(self, t):
        self.get_features(t)
        represent = self.activation(torch.mean(t, axis=1))
        represent = self.activation(self.fcl(represent))
        represent_emo = self.activation(self.feature_emo(represent))
        represent_sent = self.activation(self.feature_sent(represent))
        prob_emo = self.fc_emo(represent_emo)
        prob_sent = self.fc_sent(represent_sent)
        return prob_emo, prob_sent


class TextFeatureExtractor:
    def __init__(
        self,
        checkpoint_url: str,
        folder_path: str,
        device: torch.device,
        with_features: bool = True,
    ) -> None:

        self.device = device
        self.with_features = with_features

        model_name_bert = "julian-schelb/roberta-ner-multilingual"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_bert, add_prefix_space=True
        )
        self.model_bert = AutoModel.from_pretrained(model_name_bert)

        checkpoint_path = load_model(checkpoint_url, folder_path)

        self.model = Tmodel()
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location=self.device)
        )
        self.model.to(self.device)

    def preprocess_text(self, text: torch.Tensor) -> torch.Tensor:
        if text != "" and str(text) != "nan":
            inputs = self.tokenizer(
                text.lower(),
                padding="max_length",
                truncation="longest_first",
                return_tensors="pt",
                max_length=6,
            ).to(self.device)
            with torch.no_grad():
                self.model_bert = self.model_bert.to(self.device)
                outputs = (
                    self.model_bert(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                    )
                    .last_hidden_state.cpu()
                    .detach()
                )
        else:
            outputs = torch.zeros((1, 6, 1024))
        return outputs

    def __call__(self, text: torch.Tensor) -> tuple[dict[torch.Tensor], torch.Tensor]:
        text_features = self.preprocess_text(text)

        with torch.no_grad():
            if self.with_features:
                pred_emo, pred_sent = self.model(text_features.float().to(self.device))
                temporal_features = self.model.features
            else:
                pred_emo, pred_sent = self.model(text_features.float().to(self.device))

        predicts = {
            "emo": F.softmax(pred_emo, dim=-1).detach().cpu().squeeze(),
            "sen": F.softmax(pred_sent, dim=-1).detach().cpu().squeeze(),
        }

        return (
            (predicts, temporal_features.detach().cpu().squeeze())
            if self.with_features
            else (predicts, None)
        )


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=False,
        )
        self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding="same", bias=False
        )
        self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)

        self.conv3 = nn.Conv2d(
            out_channels,
            out_channels * self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.batch_norm3 = nn.BatchNorm2d(
            out_channels * self.expansion, eps=0.001, momentum=0.99
        )

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()
        x = self.relu(self.batch_norm1(self.conv1(x)))

        x = self.relu(self.batch_norm2(self.conv2(x)))

        x = self.conv3(x)
        x = self.batch_norm3(x)

        # downsample if needed
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        # add identity
        x += identity
        x = self.relu(x)

        return x


class Conv2dSame(torch.nn.Conv2d):
    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ih, iw = x.size()[-2:]

        pad_h = self.calc_same_pad(
            i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]
        )
        pad_w = self.calc_same_pad(
            i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]
        )

        if pad_h > 0 or pad_w > 0:
            x = F.pad(
                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
            )
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class ResNet(nn.Module):
    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.conv_layer_s2_same = Conv2dSame(
            num_channels, 64, 7, stride=2, groups=1, bias=False
        )
        self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64, stride=1)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * ResBlock.expansion, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def extract_features_four(self, x):
        x = self.relu(self.batch_norm1(self.conv_layer_s2_same(x)))
        x = self.max_pool(x)
        # print(x.shape)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def extract_features(self, x):
        x = self.extract_features_four(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x

    def forward(self, x):
        x = self.extract_features(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        ii_downsample = None
        layers = []

        if stride != 1 or self.in_channels != planes * ResBlock.expansion:
            ii_downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    planes * ResBlock.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                    padding=0,
                ),
                nn.BatchNorm2d(planes * ResBlock.expansion, eps=0.001, momentum=0.99),
            )
|
670 |
+
layers.append(
|
671 |
+
ResBlock(
|
672 |
+
self.in_channels, planes, i_downsample=ii_downsample, stride=stride
|
673 |
+
)
|
674 |
+
)
|
675 |
+
self.in_channels = planes * ResBlock.expansion
|
676 |
+
|
677 |
+
for i in range(blocks - 1):
|
678 |
+
layers.append(ResBlock(self.in_channels, planes))
|
679 |
+
|
680 |
+
return nn.Sequential(*layers)
|
681 |
+
|
682 |
+
|
683 |
+
def ResNet50(num_classes, channels=3):
|
684 |
+
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)
|
685 |
+
|
686 |
+
|
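The ResNet-50 above is used mainly as a frame-level feature extractor: extract_features pools the last stage and projects it to 512 dimensions, which is the per-frame vector size that Vmodel and the zero-feature fallback expect. A quick shape check using the classes defined above (illustrative only, random weights):

import torch

backbone = ResNet50(num_classes=7, channels=3).eval()
dummy_faces = torch.randn(2, 3, 224, 224)  # batch of two normalized face crops
with torch.no_grad():
    feats = backbone.extract_features(dummy_faces)  # (2, 512) face embeddings
    logits = backbone(dummy_faces)                  # (2, 7) static emotion logits
print(feats.shape, logits.shape)
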

class Vmodel(nn.Module):
    def __init__(
        self,
        input_size=512,
        activation=nn.SELU(),
        feature_size=64,
        num_heads=1,
        num_layers=1,
        positional_encoding=False,
        n_emo=7,
        n_sent=3,
    ):
        super(Vmodel, self).__init__()

        self.feature_video_dynamic = nn.ModuleList(
            [
                TransformerLayer(
                    input_dim=input_size,
                    num_heads=num_heads,
                    positional_encoding=positional_encoding,
                )
                for i in range(num_layers)
            ]
        )

        self.fcl = nn.Linear(input_size, feature_size)
        self.activation = activation
        self.feature_emo = nn.Linear(feature_size, feature_size)
        self.feature_sent = nn.Linear(feature_size, feature_size)
        self.fc_emo = nn.Linear(feature_size, n_emo)
        self.fc_sent = nn.Linear(feature_size, n_sent)

    def forward(self, x, with_features=False):
        for i, l in enumerate(self.feature_video_dynamic):
            x = l(x, x, x)

        represent = self.activation(torch.mean(x, axis=1))
        represent = self.activation(self.fcl(represent))
        represent_emo = self.activation(self.feature_emo(represent))
        represent_sent = self.activation(self.feature_sent(represent))
        prob_emo = self.fc_emo(represent_emo)
        prob_sent = self.fc_sent(represent_sent)

        if with_features:
            return {"emo": prob_emo, "sen": prob_sent}, x
        else:
            return {"emo": prob_emo, "sen": prob_sent}

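Vmodel consumes a window of per-frame 512-d face embeddings (with the defaults above and the config values WIN_MAX_LENGTH = 4 and target_fps = 5, that is a (1, 20, 512) tensor) and returns separate emotion and sentiment logits. A shape-level sketch, assuming the surrounding module (including TransformerLayer, defined earlier in this file) is importable as app.load_models; the shapes in the comment are the expected ones, not measured output:

import torch
from app.load_models import Vmodel  # assumes the module above is on the path

video_head = Vmodel().eval()                 # random weights, illustration only
window = torch.randn(1, 20, 512)             # one 4 s window at 5 fps, 512-d per frame
with torch.no_grad():
    preds, features = video_head(window, with_features=True)
print(preds["emo"].shape, preds["sen"].shape)  # expected: (1, 7) and (1, 3)
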

class VideoModelLoader:
    def __init__(
        self,
        face_checkpoint_url: str,
        emotion_checkpoint_url: str,
        emo_sent_checkpoint_url: str,
        folder_path: str,
        device: torch.device,
    ) -> None:
        self.device = device

        # YOLO face recognition model initialization
        face_model_path = load_model(face_checkpoint_url, folder_path)
        emotion_video_model_path = load_model(emotion_checkpoint_url, folder_path)
        emo_sent_video_model_path = load_model(emo_sent_checkpoint_url, folder_path)

        self.face_model = YOLO(face_model_path)

        # EmoAffectNet model initialization (static model)
        self.emo_affectnet_model = ResNet50(num_classes=7, channels=3)
        self.emo_affectnet_model.load_state_dict(
            torch.load(emotion_video_model_path, map_location=self.device)
        )
        self.emo_affectnet_model.to(self.device).eval()

        # Visual emotion and sentiment recognition model (dynamic model)
        self.emo_sent_video_model = Vmodel()
        self.emo_sent_video_model.load_state_dict(
            torch.load(emo_sent_video_model_path, map_location=self.device)
        )
        self.emo_sent_video_model.to(self.device).eval()

    def extract_zeros_features(self):
        zeros = torch.unsqueeze(torch.zeros((3, 224, 224)), 0).to(self.device)
        zeros_features = self.emo_affectnet_model.extract_features(zeros)
        return zeros_features.cpu().detach().numpy()[0]


class VideoFeatureExtractor:
    def __init__(
        self,
        model_loader: VideoModelLoader,
        file_path: str,
        target_fps: int = 5,
        with_features: bool = True,
    ) -> None:
        self.model_loader = model_loader
        self.with_features = with_features

        # Video options
        self.cap = cv2.VideoCapture(file_path)
        self.w, self.h, self.fps, self.frame_number = (
            int(self.cap.get(x))
            for x in (
                cv2.CAP_PROP_FRAME_WIDTH,
                cv2.CAP_PROP_FRAME_HEIGHT,
                cv2.CAP_PROP_FPS,
                cv2.CAP_PROP_FRAME_COUNT,
            )
        )
        self.dur = self.frame_number / self.fps
        self.target_fps = target_fps
        self.frame_interval = int(self.fps / target_fps)

        # Extract zero features if no face found in frame
        self.zeros_features = self.model_loader.extract_zeros_features()

        # Dictionaries with facial features and faces
        self.facial_features = {}
        self.faces = {}

    def preprocess_frame(self, frame: np.ndarray, counter: int) -> None:
        curr_fr = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.model_loader.face_model.track(
            curr_fr,
            persist=True,
            imgsz=640,
            conf=0.01,
            iou=0.5,
            augment=False,
            device=self.model_loader.device,
            verbose=False,
        )

        need_features = np.zeros(512)
        count_face = 0

        if results[0].boxes.xyxy.cpu().tolist() != []:
            for i in results[0].boxes:
                idx_box = i.id.int().cpu().tolist()[0] if i.id else -1
                box = i.xyxy.int().cpu().tolist()[0]
                startX, startY = max(0, box[0]), max(0, box[1])
                endX, endY = min(self.w - 1, box[2]), min(self.h - 1, box[3])

                face_region = curr_fr[startY:endY, startX:endX]
                norm_face_region = pth_processing(Image.fromarray(face_region))
                with torch.no_grad():
                    curr_features = (
                        self.model_loader.emo_affectnet_model.extract_features(
                            norm_face_region.to(self.model_loader.device)
                        )
                    )
                need_features += curr_features.cpu().detach().numpy()[0]
                count_face += 1

                # face_region = cv2.resize(face_region, (224,224), interpolation = cv2.INTER_AREA)
                # face_region = display_frame_info(face_region, 'Frame: {}'.format(count_face), box_scale=.3)

                if idx_box in self.faces:
                    self.faces[idx_box].update({counter: face_region})
                else:
                    self.faces[idx_box] = {counter: face_region}

            need_features /= count_face
            self.facial_features[counter] = need_features
        else:
            if counter - 1 in self.facial_features:
                self.facial_features[counter] = self.facial_features[counter - 1]
            else:
                self.facial_features[counter] = self.zeros_features

    def preprocess_video(self) -> None:
        counter = 0

        while True:
            ret, frame = self.cap.read()
            if not ret:
                break
            if counter % self.frame_interval == 0:
                self.preprocess_frame(frame, counter)
            counter += 1

    def __call__(
        self, window: dict, win_max_length: int, sr: int = 16000
    ) -> tuple[dict[torch.Tensor], torch.Tensor]:

        curr_idx_frames = get_idx_frames_in_windows(
            list(self.facial_features.keys()), window, self.fps, sr
        )

        video_features = np.array(list(self.facial_features.values()))

        curr_features = video_features[curr_idx_frames, :]

        if len(curr_features) < self.target_fps * win_max_length:
            diff = self.target_fps * win_max_length - len(curr_features)
            curr_features = np.concatenate(
                [curr_features, [curr_features[-1]] * diff], axis=0
            )

        curr_features = (
            torch.FloatTensor(curr_features).unsqueeze(0).to(self.model_loader.device)
        )

        with torch.no_grad():
            if self.with_features:
                preds, features = self.model_loader.emo_sent_video_model(
                    curr_features, with_features=self.with_features
                )
            else:
                preds = self.model_loader.emo_sent_video_model(
                    curr_features, with_features=self.with_features
                )

        predicts = {
            "emo": F.softmax(preds["emo"], dim=-1).detach().cpu().squeeze(),
            "sen": F.softmax(preds["sen"], dim=-1).detach().cpu().squeeze(),
        }

        return (
            (predicts, features.detach().cpu().squeeze())
            if self.with_features
            else (predicts, None)
        )
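Taken together, the video branch is wired as: YOLO face tracking, per-face ResNet-50 embeddings averaged per frame, then windowed Vmodel predictions. A hedged usage sketch — not part of the commit; the checkpoint URLs are placeholders assembled in the style of the config.toml entries further below, the example video comes from the tabs module, and the window dict is expressed in audio samples at 16 kHz as get_idx_frames_in_windows expects:

import torch
from app.load_models import VideoModelLoader, VideoFeatureExtractor

hf = "https://huggingface.co/ElenaRyumina/MASAI_models/resolve/main/"
loader = VideoModelLoader(
    face_checkpoint_url=hf + "yolov8n-face.pt",
    emotion_checkpoint_url=hf + "emo_affectnet_weights.pt",
    emo_sent_checkpoint_url=hf + "emo_sent_video_weights.pth",
    folder_path="models",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

extractor = VideoFeatureExtractor(loader, "videos/1.mp4", target_fps=5)
extractor.preprocess_video()  # face detection + per-frame features over the whole clip

# One 4-second window, given in audio samples (sr=16000), as the caller does elsewhere
predicts, features = extractor(window={"start": 0, "end": 4 * 16000}, win_max_length=4)
print(predicts["emo"], predicts["sen"])
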
app/plots.py
ADDED
@@ -0,0 +1,115 @@
"""
File: plots.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Plotting functions.
License: MIT License
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2


def plot_audio(time_axis, waveform, frame_indices, fps, figsize=(10, 4)) -> plt.Figure:
    frame_times = np.array(frame_indices) / fps

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(time_axis, waveform[0])
    ax.set_xlabel("Time (frames)")
    ax.set_ylabel("Amplitude")
    ax.grid(True)

    ax.set_xticks(frame_times)
    ax.set_xticklabels([f"{int(frame_time * fps) + 1}" for frame_time in frame_times])

    fig.tight_layout()

    return fig


def plot_images(image_paths):
    fig, axes = plt.subplots(1, len(image_paths), figsize=(12, 2))

    for ax, img_path in zip(axes, image_paths):
        ax.imshow(img_path)
        ax.axis("off")

    fig.tight_layout()
    return fig


def get_evenly_spaced_frame_indices(total_frames, num_frames=10):
    if total_frames <= num_frames:
        return list(range(total_frames))

    step = total_frames / num_frames
    return [int(np.round(i * step)) for i in range(num_frames)]


def plot_predictions(
    df: pd.DataFrame,
    column: str,
    title: str,
    y_labels: list[str],
    figsize: tuple[int, int],
    x_ticks: list[int],
    line_width: float = 2.0,
) -> plt.Figure:
    fig, ax = plt.subplots(figsize=figsize)

    ax.plot(df.index, df[column], linestyle="dotted", linewidth=line_width)
    ax.set_title(title)
    ax.set_xlabel("Frames")
    ax.set_ylabel(title)

    ax.set_xticks(x_ticks)
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)

    ax.grid(True)
    fig.tight_layout()
    return fig

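plot_predictions draws one dotted curve of per-frame class indices with the class names on the y-axis. A small illustrative call (the labels are the DICT_EMO values from config.toml below; the DataFrame here is synthetic):

import pandas as pd

emo_labels = ["Neutral", "Happy", "Sad", "Anger", "Surprise", "Disgust", "Fear"]
df_pred = pd.DataFrame({"emo": [0, 0, 1, 1, 3, 3, 0, 0, 2, 2]})  # class index per frame

fig = plot_predictions(
    df=df_pred,
    column="emo",
    title="Emotion",
    y_labels=emo_labels,
    figsize=(10, 4),
    x_ticks=[0, 3, 6, 9],
)
fig.savefig("emotion_predictions.png")
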

def display_frame_info(img, text, margin=1.0, box_scale=1.0, scale=1.5):
    img_copy = img.copy()
    img_h, img_w, _ = img_copy.shape
    line_width = int(min(img_h, img_w) * 0.001)
    thickness = max(int(line_width / 3), 1)

    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_color = (0, 0, 0)
    font_scale = thickness / scale

    t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]

    margin_n = int(t_h * margin)
    sub_img = img_copy[
        0 + margin_n : 0 + margin_n + t_h + int(2 * t_h * box_scale),
        img_w - t_w - margin_n - int(2 * t_h * box_scale) : img_w - margin_n,
    ]

    white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255

    img_copy[
        0 + margin_n : 0 + margin_n + t_h + int(2 * t_h * box_scale),
        img_w - t_w - margin_n - int(2 * t_h * box_scale) : img_w - margin_n,
    ] = cv2.addWeighted(sub_img, 0.5, white_rect, 0.5, 1.0)

    cv2.putText(
        img=img_copy,
        text=text,
        org=(
            img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
            0 + margin_n + t_h + int(2 * t_h * box_scale) // 2,
        ),
        fontFace=font_face,
        fontScale=font_scale,
        color=font_color,
        thickness=thickness,
        lineType=cv2.LINE_AA,
        bottomLeftOrigin=False,
    )

    return img_copy
app/requirements_app.py
ADDED
@@ -0,0 +1,37 @@
"""
File: requirements_app.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Project requirements for the Gradio app.
License: MIT License
"""

import polars as pl

# Importing necessary components for the Gradio app
from app.config import config_data


def read_requirements(file_path="requirements.txt"):
    with open(file_path, "r") as file:
        lines = file.readlines()

    data = []

    pypi = (
        lambda x: f"<a href='https://pypi.org/project/{x}' target='_blank'>"
        + f"<img src='https://img.shields.io/pypi/v/{x}' alt='PyPI' /></a>"
    )

    data = [
        {
            config_data.Requirements_LIBRARY: split_line[0],
            config_data.Requirements_RECOMMENDED_VERSION: split_line[1],
            config_data.Requirements_CURRENT_VERSION: pypi(split_line[0]),
        }
        for line in lines
        if (split_line := line.strip().split("==")) and len(split_line) == 2
    ]

    df = pl.DataFrame(data)

    return df
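For a pinned line such as gradio==5.4.0, read_requirements emits one row whose last column is an HTML badge linking to PyPI. A minimal check of the parsing idea, standalone and without the config-driven column names:

line = "gradio==5.4.0\n"
if (split_line := line.strip().split("==")) and len(split_line) == 2:
    library, recommended = split_line
    badge = (
        f"<a href='https://pypi.org/project/{library}' target='_blank'>"
        f"<img src='https://img.shields.io/pypi/v/{library}' alt='PyPI' /></a>"
    )
    print(library, recommended)  # gradio 5.4.0
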
app/tabs.py
ADDED
@@ -0,0 +1,154 @@
"""
File: tabs.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Gradio app tabs - Contains the definition of various tabs for the Gradio app interface.
License: MIT License
"""

import gradio as gr

# Importing necessary components for the Gradio app
from app.description import DESCRIPTION
from app.config import config_data
from app.components import html_message
from app.requirements_app import read_requirements


def app_tab():
    gr.Markdown(value=DESCRIPTION)

    with gr.Row(
        visible=True,
        render=True,
        variant="default",
        elem_classes="app-container",
    ):
        with gr.Column(
            visible=True,
            render=True,
            variant="default",
            elem_classes="video-container",
        ):
            video = gr.Video(
                label=config_data.Labels_VIDEO,
                show_label=True,
                interactive=True,
                visible=True,
                mirror_webcam=True,
                include_audio=True,
                elem_classes="video",
                autoplay=False,
            )

            with gr.Row(
                visible=True,
                render=True,
                variant="default",
                elem_classes="submit-container",
            ):
                clear = gr.Button(
                    value=config_data.OtherMessages_CLEAR,
                    interactive=False,
                    icon=config_data.Path_APP
                    / config_data.StaticPaths_IMAGES
                    / "clear.ico",
                    visible=True,
                    elem_classes="clear",
                )
                submit = gr.Button(
                    value=config_data.OtherMessages_SUBMIT,
                    interactive=False,
                    icon=config_data.Path_APP
                    / config_data.StaticPaths_IMAGES
                    / "submit.ico",
                    visible=True,
                    elem_classes="submit",
                )

            gr.Examples(
                [
                    "videos/1.mp4",
                    "videos/2.mp4",
                ],
                [video],
            )

        with gr.Column(
            visible=True,
            render=True,
            variant="default",
            elem_classes="results-container",
        ):
            noti_results = html_message(
                message=config_data.InformationMessages_NOTI_RESULTS[0],
                error=True,
                visible=True,
            )

            waveform = gr.Plot(
                value=None,
                label=config_data.Labels_WAVEFORM,
                show_label=True,
                visible=False,
                elem_classes="audio",
            )

            faces = gr.Plot(
                value=None,
                label=config_data.Labels_FACE_IMAGES,
                show_label=True,
                visible=False,
                elem_classes="imgs",
            )

            emotion_stats = gr.Plot(
                value=None,
                label=config_data.Labels_EMO_STATS,
                show_label=True,
                visible=False,
                elem_classes="emo-stats",
            )

            sent_stats = gr.Plot(
                value=None,
                label=config_data.Labels_SENT_STATS,
                show_label=True,
                visible=False,
                elem_classes="sent-stats",
            )

    return (
        video,
        clear,
        submit,
        noti_results,
        waveform,
        faces,
        emotion_stats,
        sent_stats,
    )


def settings_app_tab():
    pass


def about_app_tab():
    pass


def about_authors_tab():
    pass


def requirements_app_tab():
    reqs = read_requirements()

    return gr.Dataframe(
        headers=reqs.columns,
        value=reqs,
        datatype=["markdown"] * len(reqs.columns),
        visible=True,
        elem_classes="requirements-dataframe",
        type="polars",
    )
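app_tab only builds and returns the components; arranging the tab functions into a page and wiring the buttons to handlers happens in app.py and app/event_handlers, which are outside this excerpt. A hedged sketch of the kind of assembly the [TabCreators] table in config.toml below suggests (illustrative, not the actual app.py):

import gradio as gr
from app.tabs import app_tab, requirements_app_tab

with gr.Blocks() as demo:
    with gr.Tab("App"):
        components = app_tab()  # (video, clear, submit, noti_results, plots, ...)
    with gr.Tab("Requirements"):
        requirements_app_tab()

demo.launch()
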
app/utils.py
ADDED
@@ -0,0 +1,287 @@
"""
File: utils.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Utility functions.
License: MIT License
"""

import torch
import os
import subprocess
import bisect
import re
import requests
from torchvision import transforms
from PIL import Image
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from pathlib import Path
from contextlib import suppress
from urllib.parse import urlparse

from typing import Callable


def load_model(
    model_url: str, folder_path: str, force_reload: bool = False
) -> str | None:

    file_name = Path(urlparse(model_url).path).name
    file_path = Path(folder_path) / file_name

    if file_path.exists() and not force_reload:
        return str(file_path)

    with suppress(Exception), requests.get(model_url, stream=True) as response:
        file_path.parent.mkdir(parents=True, exist_ok=True)

        with file_path.open("wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)

        return str(file_path)

    return None


def readetect_speech(
    file_path: str,
    read_audio: Callable,
    get_speech_timestamps: Callable,
    vad_model: torch.jit.ScriptModule,
    sr: int = 16000,
) -> list[dict]:
    wav = read_audio(file_path, sampling_rate=sr)
    # get speech timestamps from full audio file
    speech_timestamps = get_speech_timestamps(wav, vad_model, sampling_rate=sr)

    return wav, speech_timestamps


def calculate_mode(series):
    mode = series.mode()
    return mode[0] if not mode.empty else None


def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def __init__(self):
            super(PreprocessInput, self).__init__()

        def forward(self, x):
            x = x.to(torch.float32)
            x = torch.flip(x, dims=(0,))
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img, target_size=(224, 224)):
        transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
        img = img.resize(target_size, Image.Resampling.NEAREST)
        img = transform(img)
        img = torch.unsqueeze(img, 0)
        return img

    return get_img_torch(fp)


def get_idx_frames_in_windows(
    frames: list[int], window: dict, fps: int, sr: int = 16000
) -> list[list]:

    frames_in_windows = [
        idx
        for idx, frame in enumerate(frames)
        if window["start"] * fps / sr <= frame < window["end"] * fps / sr
    ]
    return frames_in_windows


# Maxim code
def slice_audio(
    start_time: float,
    end_time: float,
    win_max_length: float,
    win_shift: float,
    win_min_length: float,
) -> list[dict]:
    """Slices audio into windows

    Args:
        start_time (float): Start time of audio
        end_time (float): End time of audio
        win_max_length (float): Window max length
        win_shift (float): Window shift
        win_min_length (float): Window min length

    Returns:
        list[dict]: List of dicts with timings, e.g. {'start': 0, 'end': 12}
    """

    if end_time < start_time:
        return []
    elif (end_time - start_time) > win_max_length:
        timings = []
        while start_time < end_time:
            end_time_chunk = start_time + win_max_length
            if end_time_chunk < end_time:
                timings.append({"start": start_time, "end": end_time_chunk})
            elif end_time_chunk == end_time:  # if the tail is exactly `win_max_length` seconds
                timings.append({"start": start_time, "end": end_time_chunk})
                break
            else:  # if the tail is shorter than `win_max_length` seconds
                if (
                    end_time - start_time < win_min_length
                ):  # if the tail is shorter than `win_min_length` seconds
                    break

                timings.append({"start": start_time, "end": end_time})
                break

            start_time += win_shift
        return timings
    else:
        return [{"start": start_time, "end": end_time}]


def convert_video_to_audio(file_path: str, sr: int = 16000) -> str:
    path_save = file_path.split(".")[0] + ".wav"
    if not os.path.exists(path_save):
        ffmpeg_command = f"ffmpeg -y -i {file_path} -async 1 -vn -acodec pcm_s16le -ar {sr} {path_save}"
        subprocess.call(ffmpeg_command, shell=True)

    return path_save


def find_nearest_frames(target_frames, all_frames):
    nearest_frames = []
    for frame in target_frames:
        pos = bisect.bisect_left(all_frames, frame)
        if pos == 0:
            nearest_frame = all_frames[0]
        elif pos == len(all_frames):
            nearest_frame = all_frames[-1]
        else:
            before = all_frames[pos - 1]
            after = all_frames[pos]
            nearest_frame = before if frame - before <= after - frame else after
        nearest_frames.append(nearest_frame)
    return nearest_frames


def find_intersections(
    x: list[dict], y: list[dict], min_length: float = 0
) -> list[dict]:
    """Find intersections of two lists of interval dicts, preserving the structure of `x` and adding intersection info

    Args:
        x (list[dict]): First list of intervals
        y (list[dict]): Second list of intervals
        min_length (float, optional): Minimum length of intersection. Defaults to 0.

    Returns:
        list[dict]: Windows with intersections, maintaining the structure of `x` and indicating intersection presence.
    """
    timings = []
    j = 0

    for interval_x in x:
        original_start = int(interval_x["start"])
        original_end = int(interval_x["end"])
        intersections_found = False

        while j < len(y) and y[j]["end"] < original_start:
            j += 1  # Skip any intervals in `y` that end before the current interval in `x` starts

        # Check for all overlapping intervals in `y`
        temp_j = (
            j  # Temporary pointer to check intersections within `y` for current `x`
        )
        while temp_j < len(y) and y[temp_j]["start"] <= original_end:
            # Calculate the intersection between `x[i]` and `y[j]`
            intersection_start = max(original_start, y[temp_j]["start"])
            intersection_end = min(original_end, y[temp_j]["end"])

            if (
                intersection_start < intersection_end
                and (intersection_end - intersection_start) >= min_length
            ):
                timings.append(
                    {
                        "original_start": original_start,
                        "original_end": original_end,
                        "start": intersection_start,
                        "end": intersection_end,
                        "speech": True,
                    }
                )
                intersections_found = True

            temp_j += 1  # Move to the next interval in `y` for further intersections

        # If no intersections were found, add the interval with `speech` set to False
        if not intersections_found:
            timings.append(
                {
                    "original_start": original_start,
                    "original_end": original_end,
                    "start": None,
                    "end": None,
                    "speech": False,
                }
            )

    return timings


# Anastasia code
class ASRModel:
    def __init__(self, checkpoint_path: str, device: torch.device):
        self.processor = WhisperProcessor.from_pretrained(checkpoint_path)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            checkpoint_path
        ).to(device)
        self.device = device
        self.model.config.forced_decoder_ids = None

    def __call__(
        self, sample: torch.Tensor, audio_windows: dict, sr: int = 16000
    ) -> tuple:
        texts = []

        for t in range(len(audio_windows)):
            input_features = self.processor(
                sample[audio_windows[t]["start"] : audio_windows[t]["end"]],
                sampling_rate=sr,
                return_tensors="pt",
            ).input_features
            predicted_ids = self.model.generate(input_features.to(self.device))
            transcription = self.processor.batch_decode(
                predicted_ids, skip_special_tokens=False
            )
            texts.append(re.findall(r"> ([^<>]+)", transcription[0]))

        # for drawing
        input_features = self.processor(
            sample, sampling_rate=sr, return_tensors="pt"
        ).input_features
        predicted_ids = self.model.generate(input_features.to(self.device))
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=False
        )
        total_text = re.findall(r"> ([^<>]+)", transcription[0])

        return texts, total_text


def convert_webm_to_mp4(input_file):

    path_save = input_file.split(".")[0] + ".mp4"

    if not os.path.exists(path_save):
        ff_video = "ffmpeg -i {} -c:v copy -c:a aac -strict experimental {}".format(
            input_file, path_save
        )
        subprocess.call(ff_video, shell=True)

    return path_save
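The two interval helpers above are meant to work together: slice_audio produces fixed-length windows for the whole recording, and find_intersections marks which of those windows actually overlap the VAD speech segments. A small worked example with the config.toml defaults (4 s windows, 2 s shift; times are in seconds here, whereas the app itself passes sample counts):

windows = slice_audio(
    start_time=0, end_time=10, win_max_length=4, win_shift=2, win_min_length=2
)
# -> [{'start': 0, 'end': 4}, {'start': 2, 'end': 6}, {'start': 4, 'end': 8}, {'start': 6, 'end': 10}]

speech = [{"start": 1, "end": 5}]  # e.g. one VAD speech segment
marked = find_intersections(windows, speech, min_length=0)
for w in marked:
    print(w["original_start"], w["original_end"], w["speech"])
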
config.toml
ADDED
@@ -0,0 +1,66 @@
[AppSettings]
APP_VERSION = "0.0.1"
CSS_PATH = "app.css"

[General]
SR = 16000
START_TIME = 0
WIN_MAX_LENGTH = 4
WIN_SHIFT = 2
WIN_MIN_LENGTH = 2
DICT_EMO = [
    "Neutral",
    "Happy",
    "Sad",
    "Anger",
    "Surprise",
    "Disgust",
    "Fear",
]
DICT_SENT = [
    "Negative",
    "Neutral",
    "Positive",
]

[InformationMessages]
NOTI_RESULTS = [
    "Upload or record video",
    "Video uploaded, you can perform calculations",
]

[OtherMessages]
CLEAR = "Clear"
SUBMIT = "Calculate"

[Labels]
VIDEO = "Video"
FACE_IMAGES = "Face images"
WAVEFORM = "Waveform"
EMO_STATS = "Statistics of emotions"
SENT_STATS = "Statistics of sentiments"

[TabCreators]
"β App" = "app_tab"
"βοΈ Settings" = "settings_app_tab"
"π‘ About App" = "about_app_tab"
"π Authors" = "about_authors_tab"
"π Requirements" = "requirements_app_tab"

[StaticPaths]
MODELS = "models"
IMAGES = "images"
WEIGHTS = "weights"
VAD_MODEL = "snakers4/silero-vad"
HF_MODELS = "https://huggingface.co/ElenaRyumina/MASAI_models/resolve/main/"
EMO_AFFECTNET_WEIGHTS = "emo_affectnet_weights.pt"
EMO_SENT_AUDIO_WEIGHTS = "emo_sent_audio_weights.pth"
EMO_SENT_TEXT_WEIGHTS = "emo_sent_text_weights.pth"
EMO_SENT_VIDEO_WEIGHTS = "emo_sent_video_weights.pth"
YOLOV8N_FACE = "yolov8n-face.pt"
OPENAI_WHISPER = "openai/whisper-base"

[Requirements]
LIBRARY = "Library"
RECOMMENDED_VERSION = "Recommended version"
CURRENT_VERSION = "Current version"
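The flattened attribute names used throughout the app (config_data.Labels_VIDEO, config_data.Requirements_LIBRARY, and so on) suggest that app/config.py joins the TOML section and key with an underscore; that file is not part of this excerpt, so the loader below is only an assumed sketch of the idea:

import tomllib  # Python 3.11+; the third-party 'toml' package would work on older versions
from types import SimpleNamespace

with open("config.toml", "rb") as f:
    raw = tomllib.load(f)

# Hypothetical flattening: [Labels] VIDEO = "Video" -> config_data.Labels_VIDEO
config_data = SimpleNamespace(
    **{
        f"{section}_{key}": value
        for section, table in raw.items()
        for key, value in table.items()
    }
)
print(config_data.General_SR, config_data.Labels_VIDEO)
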
images/clear.ico
ADDED
images/submit.ico
ADDED
requirements.txt
ADDED
@@ -0,0 +1,13 @@
gradio==5.4.0
gradio_client==1.4.2
polars==1.12.0
torch==2.2.2
torchaudio==2.2.2
opencv-contrib-python==4.10.0.84
ultralytics==8.3.26
lapx==0.5.11
transformers==4.46.1
pillow==11.0.0
pandas==2.2.3
numpy==1.26.4
matplotlib==3.9.2