diff --git a/.flake8 b/.flake8
index ceb6dd59daad4b169d864b5d04e83b2a9c3b9769..4ad227243f71bcc84b6253f65f4cc06d98bc9f42 100644
--- a/.flake8
+++ b/.flake8
@@ -18,4 +18,4 @@ exclude =
dist,
.venv
pad*.py
-max-complexity = 25
+max-complexity = 25
\ No newline at end of file
diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..1e8a20b6bd8a30656a0d54968fa8b6ee5461b5bf 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -25,7 +25,6 @@
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/BFM_Fitting/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text
+checkpoints/BFM_Fitting/BFM09_model_info.mat filter=lfs diff=lfs merge=lfs -text
+checkpoints/facevid2vid_00189-model.pth.tar filter=lfs diff=lfs merge=lfs -text
+checkpoints/mapping_00229-model.pth.tar filter=lfs diff=lfs merge=lfs -text
+checkpoints/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
+examples/driven_audio/chinese_news.wav filter=lfs diff=lfs merge=lfs -text
+examples/driven_audio/deyu.wav filter=lfs diff=lfs merge=lfs -text
+examples/driven_audio/eluosi.wav filter=lfs diff=lfs merge=lfs -text
+examples/driven_audio/fayu.wav filter=lfs diff=lfs merge=lfs -text
+examples/driven_audio/imagine.wav filter=lfs diff=lfs merge=lfs -text
+examples/driven_audio/japanese.wav filter=lfs diff=lfs merge=lfs -text
+examples/source_image/art_16.png filter=lfs diff=lfs merge=lfs -text
+examples/source_image/art_17.png filter=lfs diff=lfs merge=lfs -text
+examples/source_image/art_3.png filter=lfs diff=lfs merge=lfs -text
+examples/source_image/art_4.png filter=lfs diff=lfs merge=lfs -text
+examples/source_image/art_5.png filter=lfs diff=lfs merge=lfs -text
+examples/source_image/art_8.png filter=lfs diff=lfs merge=lfs -text
+examples/source_image/art_9.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..400ce70510811cd910fdd17d2f2ce1fb97123562 100644
--- a/.gitignore
+++ b/.gitignore
@@ -0,0 +1,159 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+results/
+checkpoints/
+gradio_cached_examples/
+gfpgan/
+start.sh
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 97ad13d964d051e4bfdd255a668c209120b1ada4..5ddc6e3d8b246534a58f9612a88b309fa7e10795 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,28 +1,59 @@
-# This Dockerfile is for building the environment without local models; if you need local models such as ChatGLM, see docs/Dockerfile+ChatGLM
-# How to build: edit `config.py` first, then run docker build -t gpt-academic .
-# How to run: docker run --rm -it --net=host gpt-academic
-FROM python:3.11
-
-RUN echo '[global]' > /etc/pip.conf && \
- echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
- echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
-
-
-WORKDIR /gpt
-
-
-
-
-# Install dependencies
-COPY requirements.txt ./
-COPY ./docs/gradio-3.32.2-py3-none-any.whl ./docs/gradio-3.32.2-py3-none-any.whl
-RUN pip3 install -r requirements.txt
-# Copy project files
-COPY . .
-RUN pip3 install -r requirements.txt
-
-# Optional step: warm up modules
-RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
-
-# Start
-CMD ["python3", "-u", "main.py"]
+FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && \
+ apt-get upgrade -y && \
+ apt-get install -y --no-install-recommends \
+ git \
+ zip \
+ unzip \
+ git-lfs \
+ wget \
+ curl \
+ # ffmpeg \
+ ffmpeg \
+ x264 \
+ # python build dependencies \
+ build-essential \
+ libssl-dev \
+ zlib1g-dev \
+ libbz2-dev \
+ libreadline-dev \
+ libsqlite3-dev \
+ libncursesw5-dev \
+ xz-utils \
+ tk-dev \
+ libxml2-dev \
+ libxmlsec1-dev \
+ libffi-dev \
+ liblzma-dev && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:${PATH}
+WORKDIR ${HOME}/app
+
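+# Install pyenv for the non-root user, build the requested Python version, and upgrade the packaging tools.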
+RUN curl https://pyenv.run | bash
+ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
+ENV PYTHON_VERSION=3.10.9
+RUN pyenv install ${PYTHON_VERSION} && \
+ pyenv global ${PYTHON_VERSION} && \
+ pyenv rehash && \
+ pip install --no-cache-dir -U pip setuptools wheel
+
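+# Install PyTorch first, then the rest of the Python dependencies from requirements.txt.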
+RUN pip install --no-cache-dir -U torch==1.12.1 torchvision==0.13.1
+COPY --chown=1000 requirements.txt /tmp/requirements.txt
+RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
+
+COPY --chown=1000 . ${HOME}/app
+RUN ls -a
+ENV PYTHONPATH=${HOME}/app \
+ PYTHONUNBUFFERED=1 \
+ GRADIO_ALLOW_FLAGGING=never \
+ GRADIO_NUM_PORTS=1 \
+ GRADIO_SERVER_NAME=0.0.0.0 \
+ GRADIO_THEME=huggingface \
+ SYSTEM=spaces
+CMD ["python", "app.py"]
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index f288702d2fa16d3cdf0035b15a9fcbc552cd88e7..b2a615ac931ce1e81df51deb56c3df2414b59e63 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,674 +1,21 @@
- GNU GENERAL PUBLIC LICENSE
- Version 3, 29 June 2007
-
- Copyright (C) 2007 Free Software Foundation, Inc.
- Everyone is permitted to copy and distribute verbatim copies
- of this license document, but changing it is not allowed.
-
- Preamble
-
- The GNU General Public License is a free, copyleft license for
-software and other kinds of works.
-
- The licenses for most software and other practical works are designed
-to take away your freedom to share and change the works. By contrast,
-the GNU General Public License is intended to guarantee your freedom to
-share and change all versions of a program--to make sure it remains free
-software for all its users. We, the Free Software Foundation, use the
-GNU General Public License for most of our software; it applies also to
-any other work released this way by its authors. You can apply it to
-your programs, too.
-
- When we speak of free software, we are referring to freedom, not
-price. Our General Public Licenses are designed to make sure that you
-have the freedom to distribute copies of free software (and charge for
-them if you wish), that you receive source code or can get it if you
-want it, that you can change the software or use pieces of it in new
-free programs, and that you know you can do these things.
-
- To protect your rights, we need to prevent others from denying you
-these rights or asking you to surrender the rights. Therefore, you have
-certain responsibilities if you distribute copies of the software, or if
-you modify it: responsibilities to respect the freedom of others.
-
- For example, if you distribute copies of such a program, whether
-gratis or for a fee, you must pass on to the recipients the same
-freedoms that you received. You must make sure that they, too, receive
-or can get the source code. And you must show them these terms so they
-know their rights.
-
- Developers that use the GNU GPL protect your rights with two steps:
-(1) assert copyright on the software, and (2) offer you this License
-giving you legal permission to copy, distribute and/or modify it.
-
- For the developers' and authors' protection, the GPL clearly explains
-that there is no warranty for this free software. For both users' and
-authors' sake, the GPL requires that modified versions be marked as
-changed, so that their problems will not be attributed erroneously to
-authors of previous versions.
-
- Some devices are designed to deny users access to install or run
-modified versions of the software inside them, although the manufacturer
-can do so. This is fundamentally incompatible with the aim of
-protecting users' freedom to change the software. The systematic
-pattern of such abuse occurs in the area of products for individuals to
-use, which is precisely where it is most unacceptable. Therefore, we
-have designed this version of the GPL to prohibit the practice for those
-products. If such problems arise substantially in other domains, we
-stand ready to extend this provision to those domains in future versions
-of the GPL, as needed to protect the freedom of users.
-
- Finally, every program is threatened constantly by software patents.
-States should not allow patents to restrict development and use of
-software on general-purpose computers, but in those that do, we wish to
-avoid the special danger that patents applied to a free program could
-make it effectively proprietary. To prevent this, the GPL assures that
-patents cannot be used to render the program non-free.
-
- The precise terms and conditions for copying, distribution and
-modification follow.
-
- TERMS AND CONDITIONS
-
- 0. Definitions.
-
- "This License" refers to version 3 of the GNU General Public License.
-
- "Copyright" also means copyright-like laws that apply to other kinds of
-works, such as semiconductor masks.
-
- "The Program" refers to any copyrightable work licensed under this
-License. Each licensee is addressed as "you". "Licensees" and
-"recipients" may be individuals or organizations.
-
- To "modify" a work means to copy from or adapt all or part of the work
-in a fashion requiring copyright permission, other than the making of an
-exact copy. The resulting work is called a "modified version" of the
-earlier work or a work "based on" the earlier work.
-
- A "covered work" means either the unmodified Program or a work based
-on the Program.
-
- To "propagate" a work means to do anything with it that, without
-permission, would make you directly or secondarily liable for
-infringement under applicable copyright law, except executing it on a
-computer or modifying a private copy. Propagation includes copying,
-distribution (with or without modification), making available to the
-public, and in some countries other activities as well.
-
- To "convey" a work means any kind of propagation that enables other
-parties to make or receive copies. Mere interaction with a user through
-a computer network, with no transfer of a copy, is not conveying.
-
- An interactive user interface displays "Appropriate Legal Notices"
-to the extent that it includes a convenient and prominently visible
-feature that (1) displays an appropriate copyright notice, and (2)
-tells the user that there is no warranty for the work (except to the
-extent that warranties are provided), that licensees may convey the
-work under this License, and how to view a copy of this License. If
-the interface presents a list of user commands or options, such as a
-menu, a prominent item in the list meets this criterion.
-
- 1. Source Code.
-
- The "source code" for a work means the preferred form of the work
-for making modifications to it. "Object code" means any non-source
-form of a work.
-
- A "Standard Interface" means an interface that either is an official
-standard defined by a recognized standards body, or, in the case of
-interfaces specified for a particular programming language, one that
-is widely used among developers working in that language.
-
- The "System Libraries" of an executable work include anything, other
-than the work as a whole, that (a) is included in the normal form of
-packaging a Major Component, but which is not part of that Major
-Component, and (b) serves only to enable use of the work with that
-Major Component, or to implement a Standard Interface for which an
-implementation is available to the public in source code form. A
-"Major Component", in this context, means a major essential component
-(kernel, window system, and so on) of the specific operating system
-(if any) on which the executable work runs, or a compiler used to
-produce the work, or an object code interpreter used to run it.
-
- The "Corresponding Source" for a work in object code form means all
-the source code needed to generate, install, and (for an executable
-work) run the object code and to modify the work, including scripts to
-control those activities. However, it does not include the work's
-System Libraries, or general-purpose tools or generally available free
-programs which are used unmodified in performing those activities but
-which are not part of the work. For example, Corresponding Source
-includes interface definition files associated with source files for
-the work, and the source code for shared libraries and dynamically
-linked subprograms that the work is specifically designed to require,
-such as by intimate data communication or control flow between those
-subprograms and other parts of the work.
-
- The Corresponding Source need not include anything that users
-can regenerate automatically from other parts of the Corresponding
-Source.
-
- The Corresponding Source for a work in source code form is that
-same work.
-
- 2. Basic Permissions.
-
- All rights granted under this License are granted for the term of
-copyright on the Program, and are irrevocable provided the stated
-conditions are met. This License explicitly affirms your unlimited
-permission to run the unmodified Program. The output from running a
-covered work is covered by this License only if the output, given its
-content, constitutes a covered work. This License acknowledges your
-rights of fair use or other equivalent, as provided by copyright law.
-
- You may make, run and propagate covered works that you do not
-convey, without conditions so long as your license otherwise remains
-in force. You may convey covered works to others for the sole purpose
-of having them make modifications exclusively for you, or provide you
-with facilities for running those works, provided that you comply with
-the terms of this License in conveying all material for which you do
-not control copyright. Those thus making or running the covered works
-for you must do so exclusively on your behalf, under your direction
-and control, on terms that prohibit them from making any copies of
-your copyrighted material outside their relationship with you.
-
- Conveying under any other circumstances is permitted solely under
-the conditions stated below. Sublicensing is not allowed; section 10
-makes it unnecessary.
-
- 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
-
- No covered work shall be deemed part of an effective technological
-measure under any applicable law fulfilling obligations under article
-11 of the WIPO copyright treaty adopted on 20 December 1996, or
-similar laws prohibiting or restricting circumvention of such
-measures.
-
- When you convey a covered work, you waive any legal power to forbid
-circumvention of technological measures to the extent such circumvention
-is effected by exercising rights under this License with respect to
-the covered work, and you disclaim any intention to limit operation or
-modification of the work as a means of enforcing, against the work's
-users, your or third parties' legal rights to forbid circumvention of
-technological measures.
-
- 4. Conveying Verbatim Copies.
-
- You may convey verbatim copies of the Program's source code as you
-receive it, in any medium, provided that you conspicuously and
-appropriately publish on each copy an appropriate copyright notice;
-keep intact all notices stating that this License and any
-non-permissive terms added in accord with section 7 apply to the code;
-keep intact all notices of the absence of any warranty; and give all
-recipients a copy of this License along with the Program.
-
- You may charge any price or no price for each copy that you convey,
-and you may offer support or warranty protection for a fee.
-
- 5. Conveying Modified Source Versions.
-
- You may convey a work based on the Program, or the modifications to
-produce it from the Program, in the form of source code under the
-terms of section 4, provided that you also meet all of these conditions:
-
- a) The work must carry prominent notices stating that you modified
- it, and giving a relevant date.
-
- b) The work must carry prominent notices stating that it is
- released under this License and any conditions added under section
- 7. This requirement modifies the requirement in section 4 to
- "keep intact all notices".
-
- c) You must license the entire work, as a whole, under this
- License to anyone who comes into possession of a copy. This
- License will therefore apply, along with any applicable section 7
- additional terms, to the whole of the work, and all its parts,
- regardless of how they are packaged. This License gives no
- permission to license the work in any other way, but it does not
- invalidate such permission if you have separately received it.
-
- d) If the work has interactive user interfaces, each must display
- Appropriate Legal Notices; however, if the Program has interactive
- interfaces that do not display Appropriate Legal Notices, your
- work need not make them do so.
-
- A compilation of a covered work with other separate and independent
-works, which are not by their nature extensions of the covered work,
-and which are not combined with it such as to form a larger program,
-in or on a volume of a storage or distribution medium, is called an
-"aggregate" if the compilation and its resulting copyright are not
-used to limit the access or legal rights of the compilation's users
-beyond what the individual works permit. Inclusion of a covered work
-in an aggregate does not cause this License to apply to the other
-parts of the aggregate.
-
- 6. Conveying Non-Source Forms.
-
- You may convey a covered work in object code form under the terms
-of sections 4 and 5, provided that you also convey the
-machine-readable Corresponding Source under the terms of this License,
-in one of these ways:
-
- a) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by the
- Corresponding Source fixed on a durable physical medium
- customarily used for software interchange.
-
- b) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by a
- written offer, valid for at least three years and valid for as
- long as you offer spare parts or customer support for that product
- model, to give anyone who possesses the object code either (1) a
- copy of the Corresponding Source for all the software in the
- product that is covered by this License, on a durable physical
- medium customarily used for software interchange, for a price no
- more than your reasonable cost of physically performing this
- conveying of source, or (2) access to copy the
- Corresponding Source from a network server at no charge.
-
- c) Convey individual copies of the object code with a copy of the
- written offer to provide the Corresponding Source. This
- alternative is allowed only occasionally and noncommercially, and
- only if you received the object code with such an offer, in accord
- with subsection 6b.
-
- d) Convey the object code by offering access from a designated
- place (gratis or for a charge), and offer equivalent access to the
- Corresponding Source in the same way through the same place at no
- further charge. You need not require recipients to copy the
- Corresponding Source along with the object code. If the place to
- copy the object code is a network server, the Corresponding Source
- may be on a different server (operated by you or a third party)
- that supports equivalent copying facilities, provided you maintain
- clear directions next to the object code saying where to find the
- Corresponding Source. Regardless of what server hosts the
- Corresponding Source, you remain obligated to ensure that it is
- available for as long as needed to satisfy these requirements.
-
- e) Convey the object code using peer-to-peer transmission, provided
- you inform other peers where the object code and Corresponding
- Source of the work are being offered to the general public at no
- charge under subsection 6d.
-
- A separable portion of the object code, whose source code is excluded
-from the Corresponding Source as a System Library, need not be
-included in conveying the object code work.
-
- A "User Product" is either (1) a "consumer product", which means any
-tangible personal property which is normally used for personal, family,
-or household purposes, or (2) anything designed or sold for incorporation
-into a dwelling. In determining whether a product is a consumer product,
-doubtful cases shall be resolved in favor of coverage. For a particular
-product received by a particular user, "normally used" refers to a
-typical or common use of that class of product, regardless of the status
-of the particular user or of the way in which the particular user
-actually uses, or expects or is expected to use, the product. A product
-is a consumer product regardless of whether the product has substantial
-commercial, industrial or non-consumer uses, unless such uses represent
-the only significant mode of use of the product.
-
- "Installation Information" for a User Product means any methods,
-procedures, authorization keys, or other information required to install
-and execute modified versions of a covered work in that User Product from
-a modified version of its Corresponding Source. The information must
-suffice to ensure that the continued functioning of the modified object
-code is in no case prevented or interfered with solely because
-modification has been made.
-
- If you convey an object code work under this section in, or with, or
-specifically for use in, a User Product, and the conveying occurs as
-part of a transaction in which the right of possession and use of the
-User Product is transferred to the recipient in perpetuity or for a
-fixed term (regardless of how the transaction is characterized), the
-Corresponding Source conveyed under this section must be accompanied
-by the Installation Information. But this requirement does not apply
-if neither you nor any third party retains the ability to install
-modified object code on the User Product (for example, the work has
-been installed in ROM).
-
- The requirement to provide Installation Information does not include a
-requirement to continue to provide support service, warranty, or updates
-for a work that has been modified or installed by the recipient, or for
-the User Product in which it has been modified or installed. Access to a
-network may be denied when the modification itself materially and
-adversely affects the operation of the network or violates the rules and
-protocols for communication across the network.
-
- Corresponding Source conveyed, and Installation Information provided,
-in accord with this section must be in a format that is publicly
-documented (and with an implementation available to the public in
-source code form), and must require no special password or key for
-unpacking, reading or copying.
-
- 7. Additional Terms.
-
- "Additional permissions" are terms that supplement the terms of this
-License by making exceptions from one or more of its conditions.
-Additional permissions that are applicable to the entire Program shall
-be treated as though they were included in this License, to the extent
-that they are valid under applicable law. If additional permissions
-apply only to part of the Program, that part may be used separately
-under those permissions, but the entire Program remains governed by
-this License without regard to the additional permissions.
-
- When you convey a copy of a covered work, you may at your option
-remove any additional permissions from that copy, or from any part of
-it. (Additional permissions may be written to require their own
-removal in certain cases when you modify the work.) You may place
-additional permissions on material, added by you to a covered work,
-for which you have or can give appropriate copyright permission.
-
- Notwithstanding any other provision of this License, for material you
-add to a covered work, you may (if authorized by the copyright holders of
-that material) supplement the terms of this License with terms:
-
- a) Disclaiming warranty or limiting liability differently from the
- terms of sections 15 and 16 of this License; or
-
- b) Requiring preservation of specified reasonable legal notices or
- author attributions in that material or in the Appropriate Legal
- Notices displayed by works containing it; or
-
- c) Prohibiting misrepresentation of the origin of that material, or
- requiring that modified versions of such material be marked in
- reasonable ways as different from the original version; or
-
- d) Limiting the use for publicity purposes of names of licensors or
- authors of the material; or
-
- e) Declining to grant rights under trademark law for use of some
- trade names, trademarks, or service marks; or
-
- f) Requiring indemnification of licensors and authors of that
- material by anyone who conveys the material (or modified versions of
- it) with contractual assumptions of liability to the recipient, for
- any liability that these contractual assumptions directly impose on
- those licensors and authors.
-
- All other non-permissive additional terms are considered "further
-restrictions" within the meaning of section 10. If the Program as you
-received it, or any part of it, contains a notice stating that it is
-governed by this License along with a term that is a further
-restriction, you may remove that term. If a license document contains
-a further restriction but permits relicensing or conveying under this
-License, you may add to a covered work material governed by the terms
-of that license document, provided that the further restriction does
-not survive such relicensing or conveying.
-
- If you add terms to a covered work in accord with this section, you
-must place, in the relevant source files, a statement of the
-additional terms that apply to those files, or a notice indicating
-where to find the applicable terms.
-
- Additional terms, permissive or non-permissive, may be stated in the
-form of a separately written license, or stated as exceptions;
-the above requirements apply either way.
-
- 8. Termination.
-
- You may not propagate or modify a covered work except as expressly
-provided under this License. Any attempt otherwise to propagate or
-modify it is void, and will automatically terminate your rights under
-this License (including any patent licenses granted under the third
-paragraph of section 11).
-
- However, if you cease all violation of this License, then your
-license from a particular copyright holder is reinstated (a)
-provisionally, unless and until the copyright holder explicitly and
-finally terminates your license, and (b) permanently, if the copyright
-holder fails to notify you of the violation by some reasonable means
-prior to 60 days after the cessation.
-
- Moreover, your license from a particular copyright holder is
-reinstated permanently if the copyright holder notifies you of the
-violation by some reasonable means, this is the first time you have
-received notice of violation of this License (for any work) from that
-copyright holder, and you cure the violation prior to 30 days after
-your receipt of the notice.
-
- Termination of your rights under this section does not terminate the
-licenses of parties who have received copies or rights from you under
-this License. If your rights have been terminated and not permanently
-reinstated, you do not qualify to receive new licenses for the same
-material under section 10.
-
- 9. Acceptance Not Required for Having Copies.
-
- You are not required to accept this License in order to receive or
-run a copy of the Program. Ancillary propagation of a covered work
-occurring solely as a consequence of using peer-to-peer transmission
-to receive a copy likewise does not require acceptance. However,
-nothing other than this License grants you permission to propagate or
-modify any covered work. These actions infringe copyright if you do
-not accept this License. Therefore, by modifying or propagating a
-covered work, you indicate your acceptance of this License to do so.
-
- 10. Automatic Licensing of Downstream Recipients.
-
- Each time you convey a covered work, the recipient automatically
-receives a license from the original licensors, to run, modify and
-propagate that work, subject to this License. You are not responsible
-for enforcing compliance by third parties with this License.
-
- An "entity transaction" is a transaction transferring control of an
-organization, or substantially all assets of one, or subdividing an
-organization, or merging organizations. If propagation of a covered
-work results from an entity transaction, each party to that
-transaction who receives a copy of the work also receives whatever
-licenses to the work the party's predecessor in interest had or could
-give under the previous paragraph, plus a right to possession of the
-Corresponding Source of the work from the predecessor in interest, if
-the predecessor has it or can get it with reasonable efforts.
-
- You may not impose any further restrictions on the exercise of the
-rights granted or affirmed under this License. For example, you may
-not impose a license fee, royalty, or other charge for exercise of
-rights granted under this License, and you may not initiate litigation
-(including a cross-claim or counterclaim in a lawsuit) alleging that
-any patent claim is infringed by making, using, selling, offering for
-sale, or importing the Program or any portion of it.
-
- 11. Patents.
-
- A "contributor" is a copyright holder who authorizes use under this
-License of the Program or a work on which the Program is based. The
-work thus licensed is called the contributor's "contributor version".
-
- A contributor's "essential patent claims" are all patent claims
-owned or controlled by the contributor, whether already acquired or
-hereafter acquired, that would be infringed by some manner, permitted
-by this License, of making, using, or selling its contributor version,
-but do not include claims that would be infringed only as a
-consequence of further modification of the contributor version. For
-purposes of this definition, "control" includes the right to grant
-patent sublicenses in a manner consistent with the requirements of
-this License.
-
- Each contributor grants you a non-exclusive, worldwide, royalty-free
-patent license under the contributor's essential patent claims, to
-make, use, sell, offer for sale, import and otherwise run, modify and
-propagate the contents of its contributor version.
-
- In the following three paragraphs, a "patent license" is any express
-agreement or commitment, however denominated, not to enforce a patent
-(such as an express permission to practice a patent or covenant not to
-sue for patent infringement). To "grant" such a patent license to a
-party means to make such an agreement or commitment not to enforce a
-patent against the party.
-
- If you convey a covered work, knowingly relying on a patent license,
-and the Corresponding Source of the work is not available for anyone
-to copy, free of charge and under the terms of this License, through a
-publicly available network server or other readily accessible means,
-then you must either (1) cause the Corresponding Source to be so
-available, or (2) arrange to deprive yourself of the benefit of the
-patent license for this particular work, or (3) arrange, in a manner
-consistent with the requirements of this License, to extend the patent
-license to downstream recipients. "Knowingly relying" means you have
-actual knowledge that, but for the patent license, your conveying the
-covered work in a country, or your recipient's use of the covered work
-in a country, would infringe one or more identifiable patents in that
-country that you have reason to believe are valid.
-
- If, pursuant to or in connection with a single transaction or
-arrangement, you convey, or propagate by procuring conveyance of, a
-covered work, and grant a patent license to some of the parties
-receiving the covered work authorizing them to use, propagate, modify
-or convey a specific copy of the covered work, then the patent license
-you grant is automatically extended to all recipients of the covered
-work and works based on it.
-
- A patent license is "discriminatory" if it does not include within
-the scope of its coverage, prohibits the exercise of, or is
-conditioned on the non-exercise of one or more of the rights that are
-specifically granted under this License. You may not convey a covered
-work if you are a party to an arrangement with a third party that is
-in the business of distributing software, under which you make payment
-to the third party based on the extent of your activity of conveying
-the work, and under which the third party grants, to any of the
-parties who would receive the covered work from you, a discriminatory
-patent license (a) in connection with copies of the covered work
-conveyed by you (or copies made from those copies), or (b) primarily
-for and in connection with specific products or compilations that
-contain the covered work, unless you entered into that arrangement,
-or that patent license was granted, prior to 28 March 2007.
-
- Nothing in this License shall be construed as excluding or limiting
-any implied license or other defenses to infringement that may
-otherwise be available to you under applicable patent law.
-
- 12. No Surrender of Others' Freedom.
-
- If conditions are imposed on you (whether by court order, agreement or
-otherwise) that contradict the conditions of this License, they do not
-excuse you from the conditions of this License. If you cannot convey a
-covered work so as to satisfy simultaneously your obligations under this
-License and any other pertinent obligations, then as a consequence you may
-not convey it at all. For example, if you agree to terms that obligate you
-to collect a royalty for further conveying from those to whom you convey
-the Program, the only way you could satisfy both those terms and this
-License would be to refrain entirely from conveying the Program.
-
- 13. Use with the GNU Affero General Public License.
-
- Notwithstanding any other provision of this License, you have
-permission to link or combine any covered work with a work licensed
-under version 3 of the GNU Affero General Public License into a single
-combined work, and to convey the resulting work. The terms of this
-License will continue to apply to the part which is the covered work,
-but the special requirements of the GNU Affero General Public License,
-section 13, concerning interaction through a network will apply to the
-combination as such.
-
- 14. Revised Versions of this License.
-
- The Free Software Foundation may publish revised and/or new versions of
-the GNU General Public License from time to time. Such new versions will
-be similar in spirit to the present version, but may differ in detail to
-address new problems or concerns.
-
- Each version is given a distinguishing version number. If the
-Program specifies that a certain numbered version of the GNU General
-Public License "or any later version" applies to it, you have the
-option of following the terms and conditions either of that numbered
-version or of any later version published by the Free Software
-Foundation. If the Program does not specify a version number of the
-GNU General Public License, you may choose any version ever published
-by the Free Software Foundation.
-
- If the Program specifies that a proxy can decide which future
-versions of the GNU General Public License can be used, that proxy's
-public statement of acceptance of a version permanently authorizes you
-to choose that version for the Program.
-
- Later license versions may give you additional or different
-permissions. However, no additional obligations are imposed on any
-author or copyright holder as a result of your choosing to follow a
-later version.
-
- 15. Disclaimer of Warranty.
-
- THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
-APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
-HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
-OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
-THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
-IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
-ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
-
- 16. Limitation of Liability.
-
- IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
-WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
-THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
-GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
-USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
-DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
-PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
-EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
-SUCH DAMAGES.
-
- 17. Interpretation of Sections 15 and 16.
-
- If the disclaimer of warranty and limitation of liability provided
-above cannot be given local legal effect according to their terms,
-reviewing courts shall apply local law that most closely approximates
-an absolute waiver of all civil liability in connection with the
-Program, unless a warranty or assumption of liability accompanies a
-copy of the Program in return for a fee.
-
- END OF TERMS AND CONDITIONS
-
- How to Apply These Terms to Your New Programs
-
- If you develop a new program, and you want it to be of the greatest
-possible use to the public, the best way to achieve this is to make it
-free software which everyone can redistribute and change under these terms.
-
- To do so, attach the following notices to the program. It is safest
-to attach them to the start of each source file to most effectively
-state the exclusion of warranty; and each file should have at least
-the "copyright" line and a pointer to where the full notice is found.
-
-
- Copyright (C)
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
-
- You should have received a copy of the GNU General Public License
- along with this program. If not, see .
-
-Also add information on how to contact you by electronic and paper mail.
-
- If the program does terminal interaction, make it output a short
-notice like this when it starts in an interactive mode:
-
- Copyright (C)
- This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
- This is free software, and you are welcome to redistribute it
- under certain conditions; type `show c' for details.
-
-The hypothetical commands `show w' and `show c' should show the appropriate
-parts of the General Public License. Of course, your program's commands
-might be different; for a GUI interface, you would use an "about box".
-
- You should also get your employer (if you work as a programmer) or school,
-if any, to sign a "copyright disclaimer" for the program, if necessary.
-For more information on this, and how to apply and follow the GNU GPL, see
-.
-
- The GNU General Public License does not permit incorporating your program
-into proprietary programs. If your program is a subroutine library, you
-may consider it more useful to permit linking proprietary applications with
-the library. If this is what you want to do, use the GNU Lesser General
-Public License instead of this License. But first, please read
-.
+MIT License
+
+Copyright (c) 2023 Tencent AI Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 8adc9979dec225f8d9d33017159b9ebe1251e601..11ac270bb1189d87dda3577f18fc4aedc986e1c8 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,14 @@
---
-title: Bofan Chatglm Fitness RLHF lora
-emoji: 🍀
-colorFrom: yellow
-colorTo: yellow
+title: ChatGLM2-SadTalker
+emoji: 📺
+colorFrom: purple
+colorTo: green
sdk: gradio
-sdk_version: 3.38.0
+sdk_version: 3.23.0
app_file: app.py
pinned: false
-license: apache-2.0
+license: mit
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app old.py b/app old.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ad13e89e12a98057380b4c508fb6b2e142c442e
--- /dev/null
+++ b/app old.py
@@ -0,0 +1,608 @@
+import os, sys
+import tempfile
+import gradio as gr
+from src.gradio_demo import SadTalker
+# from src.utils.text2speech import TTSTalker
+from huggingface_hub import snapshot_download
+
+import torch
+import librosa
+from scipy.io.wavfile import write
+from transformers import WavLMModel
+
+import utils
+from models import SynthesizerTrn
+from mel_processing import mel_spectrogram_torch
+from speaker_encoder.voice_encoder import SpeakerEncoder
+
+import time
+from textwrap import dedent
+
+import mdtex2html
+from loguru import logger
+from transformers import AutoModel, AutoTokenizer
+
+from tts_voice import tts_order_voice
+import edge_tts
+import tempfile
+import anyio
+
+
+def get_source_image(image):
+ return image
+
+try:
+ import webui # in webui
+ in_webui = True
+except ImportError:
+ in_webui = False
+
+
+def toggle_audio_file(choice):
+ if choice == False:
+ return gr.update(visible=True), gr.update(visible=False)
+ else:
+ return gr.update(visible=False), gr.update(visible=True)
+
+def ref_video_fn(path_of_ref_video):
+ if path_of_ref_video is not None:
+ return gr.update(value=True)
+ else:
+ return gr.update(value=False)
+
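+# Fetch the pretrained SadTalker checkpoints from the Hugging Face Hub into ./checkpoints.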
+def download_model():
+ REPO_ID = 'vinthony/SadTalker-V002rc'
+ snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
+
+def sadtalker_demo():
+
+ download_model()
+
+ sad_talker = SadTalker(lazy_load=True)
+ # tts_talker = TTSTalker()
+
+download_model()
+sad_talker = SadTalker(lazy_load=True)
+
+
+# ChatGLM2 & FreeVC
+
+'''
+def get_wavlm():
+ os.system('gdown https://drive.google.com/uc?id=12-cB34qCTvByWT-QtOcZaqwwO21FLSqU')
+ shutil.move('WavLM-Large.pt', 'wavlm')
+'''
+
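+# Run the speech models on GPU when available, otherwise fall back to CPU.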
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
+
+print("Loading FreeVC(24k)...")
+hps = utils.get_hparams_from_file("configs/freevc-24.json")
+freevc_24 = SynthesizerTrn(
+ hps.data.filter_length // 2 + 1,
+ hps.train.segment_size // hps.data.hop_length,
+ **hps.model).to(device)
+_ = freevc_24.eval()
+_ = utils.load_checkpoint("checkpoint/freevc-24.pth", freevc_24, None)
+
+print("Loading WavLM for content...")
+cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
+
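+# FreeVC voice conversion: embed the target speaker (or compute its mel spectrogram),
+# extract content features from the source audio with WavLM, synthesize, and write out.wav.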
+def convert(model, src, tgt):
+ with torch.no_grad():
+ # tgt
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
+ g_tgt = smodel.embed_utterance(wav_tgt)
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
+ else:
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
+ mel_tgt = mel_spectrogram_torch(
+ wav_tgt,
+ hps.data.filter_length,
+ hps.data.n_mel_channels,
+ hps.data.sampling_rate,
+ hps.data.hop_length,
+ hps.data.win_length,
+ hps.data.mel_fmin,
+ hps.data.mel_fmax
+ )
+ # src
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
+ # infer
+ if model == "FreeVC":
+ audio = freevc.infer(c, g=g_tgt)
+ elif model == "FreeVC-s":
+ audio = freevc_s.infer(c, mel=mel_tgt)
+ else:
+ audio = freevc_24.infer(c, g=g_tgt)
+ audio = audio[0][0].data.cpu().float().numpy()
+ if model == "FreeVC" or model == "FreeVC-s":
+ write("out.wav", hps.data.sampling_rate, audio)
+ else:
+ write("out.wav", 24000, audio)
+ out = "out.wav"
+ return out
+
+# GLM2
+
+language_dict = tts_order_voice
+
+# fix timezone in Linux
+os.environ["TZ"] = "Asia/Shanghai"
+try:
+ time.tzset() # type: ignore # pylint: disable=no-member
+except Exception:
+ # Windows
+    logger.warning("Windows, can't run time.tzset()")
+
+# model_name = "THUDM/chatglm2-6b"
+model_name = "THUDM/chatglm2-6b-int4"
+
+RETRY_FLAG = False
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+# model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
+
+# 4/8 bit
+# model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
+
+has_cuda = torch.cuda.is_available()
+
+# has_cuda = False # force cpu
+
+if has_cuda:
+ model_glm = (
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
+ ) # 3.92G
+else:
+ model_glm = AutoModel.from_pretrained(
+ model_name, trust_remote_code=True
+ ).float() # .float() .half().float()
+
+model_glm = model_glm.eval()
+
+_ = """Override Chatbot.postprocess"""
+
+
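+# Render chat messages through mdtex2html so Markdown and LaTeX display correctly in the Chatbot.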
+def postprocess(self, y):
+ if y is None:
+ return []
+ for i, (message, response) in enumerate(y):
+ y[i] = (
+ None if message is None else mdtex2html.convert((message)),
+ None if response is None else mdtex2html.convert(response),
+ )
+ return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
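+# Convert model output to HTML: wrap fenced code blocks in <pre><code> and escape special characters elsewhere.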
+def parse_text(text):
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+ lines = text.split("\n")
+ lines = [line for line in lines if line != ""]
+ count = 0
+ for i, line in enumerate(lines):
+ if "```" in line:
+ count += 1
+ items = line.split("`")
+ if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = "<br></code></pre>"
+ else:
+ if i > 0:
+ if count % 2 == 1:
+ line = line.replace("`", r"\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>" + line
+ text = "".join(lines)
+ return text
+
+
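+# Streaming chat: append the user turn, then yield the updated chatbot, history, past_key_values
+# and the latest response text as ChatGLM2 generates.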
+def predict(
+ RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
+):
+ try:
+ chatbot.append((parse_text(input), ""))
+ except Exception as exc:
+ logger.error(exc)
+ logger.debug(f"{chatbot=}")
+ _ = """
+ if chatbot:
+ chatbot[-1] = (parse_text(input), str(exc))
+ yield chatbot, history, past_key_values
+ # """
+ yield chatbot, history, past_key_values
+
+ for response, history, past_key_values in model_glm.stream_chat(
+ tokenizer,
+ input,
+ history,
+ past_key_values=past_key_values,
+ return_past_key_values=True,
+ max_length=max_length,
+ top_p=top_p,
+ temperature=temperature,
+ ):
+ chatbot[-1] = (parse_text(input), parse_text(response))
+ # chatbot[-1][-1] = parse_text(response)
+
+ yield chatbot, history, past_key_values, parse_text(response)
+
+
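+# Single-turn chat helper exposed as the "tr" API endpoint; out-of-range parameters fall back to defaults.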
+def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
+ if max_length < 10:
+ max_length = 4096
+ if top_p < 0.1 or top_p > 1:
+ top_p = 0.85
+ if temperature <= 0 or temperature > 1:
+ temperature = 0.01
+ try:
+ res, _ = model_glm.chat(
+ tokenizer,
+ input,
+ history=[],
+ past_key_values=None,
+ max_length=max_length,
+ top_p=top_p,
+ temperature=temperature,
+ )
+ # logger.debug(f"{res=} \n{_=}")
+ except Exception as exc:
+ logger.error(f"{exc=}")
+ res = str(exc)
+
+ return res
+
+
+def reset_user_input():
+ return gr.update(value="")
+
+
+def reset_state():
+ return [], [], None, ""
+
+
+# Delete last turn
+def delete_last_turn(chat, history):
+ if chat and history:
+ chat.pop(-1)
+ history.pop(-1)
+ return chat, history
+
+
+# Regenerate response
+def retry_last_answer(
+ user_input, chatbot, max_length, top_p, temperature, history, past_key_values
+):
+ if chatbot and history:
+ # Removing the previous conversation from chat
+ chatbot.pop(-1)
+ # Setting up a flag to capture a retry
+ RETRY_FLAG = True
+ # Getting last message from user
+ user_input = history[-1][0]
+ # Removing bot response from the history
+ history.pop(-1)
+
+ yield from predict(
+ RETRY_FLAG, # type: ignore
+ user_input,
+ chatbot,
+ max_length,
+ top_p,
+ temperature,
+ history,
+ past_key_values,
+ )
+
+# print
+
+def print(text):
+ return text
+
+# TTS
+
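+# Synthesize speech with edge-tts for the selected voice and return the path of a temporary MP3 file.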
+async def text_to_speech_edge(text, language_code):
+ voice = language_dict[language_code]
+ communicate = edge_tts.Communicate(text, voice)
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
+ tmp_path = tmp_file.name
+
+ await communicate.save(tmp_path)
+
+ return tmp_path
+
+
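+# Build the Gradio UI: a ChatGLM2 chat tab with Edge-TTS and FreeVC voice cloning, plus a SadTalker video tab.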
+with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm"), analytics_enabled=False) as demo:
+    gr.HTML("<center>"
+            "<h1>📺💕🎶 - ChatGLM2+声音克隆+视频对话:和喜欢的角色畅所欲言吧!</h1>"
+            "</center>")
+ gr.Markdown("## 🥳 - ChatGLM2+FreeVC+SadTalker,为您打造沉浸式的视频对话体验,支持中英双语")
+ gr.Markdown("## 🌊 - 更多精彩应用,尽在[滔滔AI](http://www.talktalkai.com);滔滔AI,为爱滔滔!💕")
+ gr.Markdown("### ⭐ - 如果您喜欢这个程序,欢迎给我的[GitHub项目](https://github.com/KevinWang676/ChatGLM2-Voice-Cloning)点赞支持!")
+
+ with gr.Tab("🍻 - ChatGLM2聊天区"):
+ with gr.Accordion("📒 相关信息", open=False):
+ _ = f""" ChatGLM2的可选参数信息:
+ * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
+ * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
+ * Top P controls dynamic vocabulary selection based on context.\n
+ 如果您想让ChatGLM2进行角色扮演并与之对话,请先输入恰当的提示词,如“请你扮演成动漫角色蜡笔小新并和我进行对话”;您也可以为ChatGLM2提供自定义的角色设定\n
+ 当您使用声音克隆功能时,请先在此程序的对应位置上传一段您喜欢的音频
+ """
+ gr.Markdown(dedent(_))
+ chatbot = gr.Chatbot(height=300)
+ with gr.Row():
+ with gr.Column(scale=4):
+ with gr.Column(scale=12):
+ user_input = gr.Textbox(
+ label="请在此处和GLM2聊天 (按回车键即可发送)",
+ placeholder="聊点什么吧",
+ )
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
+ with gr.Column(min_width=32, scale=1):
+ with gr.Row():
+ submitBtn = gr.Button("开始和GLM2交流吧", variant="primary")
+ deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
+ retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
+
+ with gr.Accordion("🔧 更多设置", open=False):
+ with gr.Row():
+ emptyBtn = gr.Button("清空所有聊天记录")
+ max_length = gr.Slider(
+ 0,
+ 32768,
+ value=8192,
+ step=1.0,
+ label="Maximum length",
+ interactive=True,
+ )
+ top_p = gr.Slider(
+ 0, 1, value=0.85, step=0.01, label="Top P", interactive=True
+ )
+ temperature = gr.Slider(
+ 0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
+ )
+
+
+ with gr.Row():
+ test1 = gr.Textbox(label="GLM2的最新回答 (可编辑)", lines = 3)
+ with gr.Column():
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
+
+ tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
+
+ with gr.Row():
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
+ audio1 = output_audio
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
+
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
+
+ history = gr.State([])
+ past_key_values = gr.State(None)
+
+ user_input.submit(
+ predict,
+ [
+ RETRY_FLAG,
+ user_input,
+ chatbot,
+ max_length,
+ top_p,
+ temperature,
+ history,
+ past_key_values,
+ ],
+ [chatbot, history, past_key_values, test1],
+ show_progress="full",
+ )
+ submitBtn.click(
+ predict,
+ [
+ RETRY_FLAG,
+ user_input,
+ chatbot,
+ max_length,
+ top_p,
+ temperature,
+ history,
+ past_key_values,
+ ],
+ [chatbot, history, past_key_values, test1],
+ show_progress="full",
+ api_name="predict",
+ )
+ submitBtn.click(reset_user_input, [], [user_input])
+
+ emptyBtn.click(
+ reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
+ )
+
+ retryBtn.click(
+ retry_last_answer,
+ inputs=[
+ user_input,
+ chatbot,
+ max_length,
+ top_p,
+ temperature,
+ history,
+ past_key_values,
+ ],
+ # outputs = [chatbot, history, last_user_message, user_message]
+ outputs=[chatbot, history, past_key_values, test1],
+ )
+ deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
+
+ with gr.Accordion("📔 提示词示例", open=False):
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+ examples = gr.Examples(
+ examples=[
+ ["Explain the plot of Cinderella in a sentence."],
+ [
+ "How long does it take to become proficient in French, and what are the best methods for retaining information?"
+ ],
+ ["What are some common mistakes to avoid when writing code?"],
+ ["Build a prompt to generate a beautiful portrait of a horse"],
+ ["Suggest four metaphors to describe the benefits of AI"],
+ ["Write a pop song about leaving home for the sandy beaches."],
+ ["Write a summary demonstrating my ability to tame lions"],
+ ["鲁迅和周树人什么关系"],
+ ["从前有一头牛,这头牛后面有什么?"],
+ ["正无穷大加一大于正无穷大吗?"],
+ ["正无穷大加正无穷大大于正无穷大吗?"],
+ ["-2的平方根等于什么"],
+ ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
+ ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
+ ["鲁迅和周树人什么关系 用英文回答"],
+ ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
+ [f"{etext} 翻成中文,列出3个版本"],
+ [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
+ ["js 判断一个数是不是质数"],
+ ["js 实现python 的 range(10)"],
+ ["js 实现python 的 [*(range(10)]"],
+ ["假定 1 + 2 = 4, 试求 7 + 8"],
+ ["Erkläre die Handlung von Cinderella in einem Satz."],
+ ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
+ ],
+ inputs=[user_input],
+ examples_per_page=30,
+ )
+
+ with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+ input_text = gr.Text()
+ tr_btn = gr.Button("Go", variant="primary")
+ out_text = gr.Text()
+ tr_btn.click(
+ trans_api,
+ [input_text, max_length, top_p, temperature],
+ out_text,
+ # show_progress="full",
+ api_name="tr",
+ )
+ _ = """
+ input_text.submit(
+ trans_api,
+ [input_text, max_length, top_p, temperature],
+ out_text,
+ show_progress="full",
+ api_name="tr1",
+ )
+ # """
+ with gr.Tab("📺 - 视频聊天区"):
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="sadtalker_source_image"):
+ with gr.TabItem('图片上传'):
+ with gr.Row():
+ source_image = gr.Image(label="请上传一张您喜欢角色的图片", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
+
+
+ with gr.Tabs(elem_id="sadtalker_driven_audio"):
+ with gr.TabItem('💡您还可以将视频下载到本地'):
+
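+                    # the cloned audio produced in the chat tab (audio_cloned) drives the SadTalker video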
+ with gr.Row():
+ driven_audio = audio_cloned
+ driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)
+
+ with gr.Column():
+ use_idle_mode = gr.Checkbox(label="Use Idle Animation", visible=False)
+ length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.", visible=False)
+ use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo
+
+ with gr.Row():
+ ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref", visible=False).style(width=512)
+
+ with gr.Column():
+ use_ref_video = gr.Checkbox(label="Use Reference Video", visible=False)
+ ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video',info="How to borrow from reference Video?((fully transfer, aka, video driving mode))", visible=False)
+
+ ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo
+
+
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
+ with gr.TabItem('视频设置'):
+ with gr.Column(variant='panel'):
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
+ with gr.Row():
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0, visible=False) #
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1, visible=False) #
+ blink_every = gr.Checkbox(label="use eye blink", value=True, visible=False)
+
+ with gr.Row():
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?", visible=False) #
+ preprocess_type = gr.Radio(['crop', 'full'], value='crop', label='是否聚焦角色面部', info="crop:视频会聚焦角色面部;full:视频会显示图片全貌")
+
+ with gr.Row():
+ is_still_mode = gr.Checkbox(label="静态模式 (开启静态模式,角色的面部动作会减少;默认开启)", value=True)
+ facerender = gr.Radio(['facevid2vid','pirender'], value='facevid2vid', label='facerender', info="which face render?", visible=False)
+
+ with gr.Row():
+ batch_size = gr.Slider(label="Batch size (数值越大,生成速度越快;若显卡性能好,可增大数值)", step=1, maximum=32, value=2)
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer", value=True, visible=False)
+
+ submit = gr.Button('开始视频聊天吧', elem_id="sadtalker_generate", variant='primary')
+
+ with gr.Tabs(elem_id="sadtalker_genearted"):
+ gen_video = gr.Video(label="为您生成的专属视频", format="mp4").style(width=256)
+
+
+
+ submit.click(
+ fn=sad_talker.test,
+ inputs=[source_image,
+ driven_audio,
+ preprocess_type,
+ is_still_mode,
+ enhancer,
+ batch_size,
+ size_of_image,
+ pose_style,
+ facerender,
+ exp_weight,
+ use_ref_video,
+ ref_video,
+ ref_info,
+ use_idle_mode,
+ length_of_audio,
+ blink_every
+ ],
+ outputs=[gen_video]
+ )
+ gr.Markdown("### 注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。")
+ gr.Markdown("💡- 如何使用此程序:输入您对ChatGLM的提问后,依次点击“开始和GLM2交流吧”、“生成对应的音频吧”、“开始AI声音克隆吧”、“开始视频聊天吧”四个按键即可;使用声音克隆功能时,请先上传一段您喜欢的音频")
+ gr.HTML('''
+
+ ''')
+
+
+demo.queue().launch(show_error=True, debug=True)
diff --git a/app.py b/app.py
index d6c249caa5da996c0819065003f470d76b5455b4..8bc22c37fb1b175ad289de4d0c1985084f957d3d 100644
--- a/app.py
+++ b/app.py
@@ -1,22 +1,137 @@
-"""Credit to https://github.com/THUDM/ChatGLM2-6B/blob/main/web_demo.py while mistakes are mine."""
-# pylint: disable=broad-exception-caught, redefined-outer-name, missing-function-docstring, missing-module-docstring, too-many-arguments, line-too-long, invalid-name, redefined-builtin, redefined-argument-from-local
-# import gradio as gr
+import os, sys
+import tempfile
+import gradio as gr
+from src.gradio_demo import SadTalker
+# from src.utils.text2speech import TTSTalker
+from huggingface_hub import snapshot_download
-# model_name = "models/THUDM/chatglm2-6b-int4"
-# gr.load(model_name).lauch()
+import torch
+import librosa
+from scipy.io.wavfile import write
+from transformers import WavLMModel
-# %%writefile demo-4bit.py
+import utils
+from models import SynthesizerTrn
+from mel_processing import mel_spectrogram_torch
+from speaker_encoder.voice_encoder import SpeakerEncoder
-import os
import time
from textwrap import dedent
-import gradio as gr
import mdtex2html
-import torch
from loguru import logger
from transformers import AutoModel, AutoTokenizer
+from tts_voice import tts_order_voice
+import edge_tts
+import tempfile
+import anyio
+
+
+def get_source_image(image):
+ return image
+
+try:
+ import webui # in webui
+ in_webui = True
+except Exception:
+ in_webui = False
+
+
+def toggle_audio_file(choice):
+    if not choice:
+ return gr.update(visible=True), gr.update(visible=False)
+ else:
+ return gr.update(visible=False), gr.update(visible=True)
+
+def ref_video_fn(path_of_ref_video):
+ if path_of_ref_video is not None:
+ return gr.update(value=True)
+ else:
+ return gr.update(value=False)
+
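+# Fetch the SadTalker V0.0.2-rc weights from the Hugging Face Hub into ./checkpoints.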
+def download_model():
+ REPO_ID = 'vinthony/SadTalker-V002rc'
+ snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
+
+def sadtalker_demo():
+
+ download_model()
+
+ sad_talker = SadTalker(lazy_load=True)
+    # tts_talker = TTSTalker()
+    return sad_talker
+
+download_model()
+sad_talker = SadTalker(lazy_load=True)
+
+
+# ChatGLM2 & FreeVC
+
+'''
+def get_wavlm():
+ os.system('gdown https://drive.google.com/uc?id=12-cB34qCTvByWT-QtOcZaqwwO21FLSqU')
+ shutil.move('WavLM-Large.pt', 'wavlm')
+'''
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
+
+print("Loading FreeVC(24k)...")
+hps = utils.get_hparams_from_file("configs/freevc-24.json")
+freevc_24 = SynthesizerTrn(
+ hps.data.filter_length // 2 + 1,
+ hps.train.segment_size // hps.data.hop_length,
+ **hps.model).to(device)
+_ = freevc_24.eval()
+_ = utils.load_checkpoint("checkpoint/freevc-24.pth", freevc_24, None)
+
+print("Loading WavLM for content...")
+cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
+
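+# FreeVC voice conversion: WavLM extracts content features from the source (TTS) audio, the speaker
+# encoder embeds the target voice sample, and the synthesizer re-generates the content in that voice.
+# Only the 24 kHz FreeVC checkpoint is loaded above and the hidden dropdown defaults to "FreeVC (24kHz)",
+# so the "FreeVC"/"FreeVC-s" branches below are effectively unused in this app.
+# Example call (hypothetical file paths, for illustration only):
+#   cloned_path = convert("FreeVC (24kHz)", "tts_output.mp3", "my_voice_sample.wav")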
+def convert(model, src, tgt):
+ with torch.no_grad():
+ # tgt
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
+ g_tgt = smodel.embed_utterance(wav_tgt)
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
+ else:
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
+ mel_tgt = mel_spectrogram_torch(
+ wav_tgt,
+ hps.data.filter_length,
+ hps.data.n_mel_channels,
+ hps.data.sampling_rate,
+ hps.data.hop_length,
+ hps.data.win_length,
+ hps.data.mel_fmin,
+ hps.data.mel_fmax
+ )
+ # src
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
+ # infer
+ if model == "FreeVC":
+ audio = freevc.infer(c, g=g_tgt)
+ elif model == "FreeVC-s":
+ audio = freevc_s.infer(c, mel=mel_tgt)
+ else:
+ audio = freevc_24.infer(c, g=g_tgt)
+ audio = audio[0][0].data.cpu().float().numpy()
+ if model == "FreeVC" or model == "FreeVC-s":
+ write("out.wav", hps.data.sampling_rate, audio)
+ else:
+ write("out.wav", 24000, audio)
+ out = "out.wav"
+ return out
+
+# BofanAi
+
+language_dict = tts_order_voice
+
# fix timezone in Linux
os.environ["TZ"] = "Asia/Shanghai"
try:
@@ -25,16 +140,32 @@ except Exception:
# Windows
logger.warning("Windows, cant run time.tzset()")
-
-
-model_name = "fb700/chatglm-fitness-RLHF"
+# model_name = "THUDM/chatglm2-6b"
+model_name = "fb700/chatglm-fitness-RLHF"
RETRY_FLAG = False
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-#model = AutoModel.from_pretrained(model_name, trust_remote_code=True).quantize(4).half().cuda()
-model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda()
-model = model.eval()
+
+# model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
+
+# 4/8 bit
+# model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
+
+has_cuda = torch.cuda.is_available()
+
+# has_cuda = False # force cpu
+
+if has_cuda:
+ model_glm = (
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
+ ) # 3.92G
+else:
+ model_glm = AutoModel.from_pretrained(
+ model_name, trust_remote_code=True
+ ).float() # .float() .half().float()
+
+model_glm = model_glm.eval()
_ = """Override Chatbot.postprocess"""
@@ -54,6 +185,7 @@ gr.Chatbot.postprocess = postprocess
def parse_text(text):
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
@@ -99,8 +231,8 @@ def predict(
yield chatbot, history, past_key_values
# """
yield chatbot, history, past_key_values
- """
- for response, history, past_key_values in model.stream_chat(
+
+ for response, history, past_key_values in model_glm.stream_chat(
tokenizer,
input,
history,
@@ -110,23 +242,21 @@ def predict(
top_p=top_p,
temperature=temperature,
):
- """
- for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
- temperature=temperature):
chatbot[-1] = (parse_text(input), parse_text(response))
+ # chatbot[-1][-1] = parse_text(response)
- yield chatbot, history, past_key_values
+ yield chatbot, history, past_key_values, parse_text(response)
-def trans_api(input, max_length=40960, top_p=0.8, temperature=0.2):
+def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
if max_length < 10:
- max_length = 40960
+ max_length = 4096
if top_p < 0.1 or top_p > 1:
top_p = 0.85
if temperature <= 0 or temperature > 1:
temperature = 0.01
try:
- res, _ = model.chat(
+ res, _ = model_glm.chat(
tokenizer,
input,
history=[],
@@ -148,7 +278,7 @@ def reset_user_input():
def reset_state():
- return [], [], None
+ return [], [], None, ""
# Delete last turn
@@ -184,131 +314,177 @@ def retry_last_answer(
past_key_values,
)
+# print
-with gr.Blocks(title="Bofan Ai", theme=gr.themes.Soft(text_size="sm")) as demo:
-    # gr.HTML("""ChatGLM2-6B-int4""")
- gr.HTML(
- """It's beyond Fitness,模型由[帛凡]基于ChatGLM-6b进行微调后,在健康(全科)、心理等领域达至少60分的专业水准,而且中文总结能力超越了GPT3.5各版本。"""
- """特别声明:本应用仅为模型能力演示,无任何商业行为,部署资源为Huggingface官方免费提供,任何通过此项目产生的知识仅用于学术参考,作者和网站均不承担任何责任。"""
-        """帛凡 Fitness AI 演示"""
- """Bofan基于chatglm-6的微调模型如果喜欢请给个 ❤ 。遇到任何问题可邮件和我联系👉 fb700@qq.com"""
- )
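+# Shadow the builtin print with an identity function: later print(...) calls in this module simply
+# return their argument instead of writing to stdout.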
+def print(text):
+ return text
- with gr.Accordion("🎈 Info", open=False):
- _ = f"""
- ## {model_name}
+# TTS
- ChatGLM-6B 是开源中英双语对话模型,本次训练基于ChatGLM-6B 的第一代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上开展训练。
-
- 本项目经过多位网友实测,中文总结能力超越了GPT3.5各版本,健康咨询水平优于其它同量级模型,且经优化目前可以支持无限context,远大于4k、8K、16K......,可能是任何个人和中小企业首选模型。
-
- *首先,用40万条高质量数据进行强化训练,以提高模型的基础能力;
-
- *第二,使用30万条人类反馈数据,构建一个表达方式规范优雅的语言模式(RM模型);
-
- *第三,在保留SFT阶段三分之一训练数据的同时,增加了30万条fitness数据,叠加RM模型,对ChatGLM-6B进行强化训练。
-
- 通过训练我们对模型有了更深刻的认知,LLM在一直在进化,好的方法和数据可以挖掘出模型的更大潜能。
- 训练中特别强化了中英文学术论文的翻译和总结,可以成为普通用户和科研人员的得力助手。
-
- 免责声明:本应用仅为模型能力演示,无任何商业行为,部署资源为huggingface官方免费提供,任何通过此项目产生的知识仅用于学术参考,作者和网站均不承担任何责任 。
-
- The T4 GPU is sponsored by a community GPU grant from Huggingface. Thanks a lot!
-
- [模型下载地址](https://huggingface.co/fb700/chatglm-fitness-RLHF)
-
-
- """
- gr.Markdown(dedent(_))
- chatbot = gr.Chatbot()
- with gr.Row():
- with gr.Column(scale=4):
- with gr.Column(scale=12):
- user_input = gr.Textbox(
- show_label=False,
- placeholder="请输入内容Input...",
- ).style(container=False)
- RETRY_FLAG = gr.Checkbox(value=False, visible=False)
- with gr.Column(min_width=32, scale=1):
- with gr.Row():
- submitBtn = gr.Button("发送Submit", variant="primary")
- deleteBtn = gr.Button("删除最后一条对话", variant="secondary")
- retryBtn = gr.Button("重新生成Regenerate", variant="secondary")
- with gr.Column(scale=1):
- emptyBtn = gr.Button("清空对话Clear History")
- max_length = gr.Slider(
- 0,
- 32768,
- value=8192,
- step=1.0,
- label="Maximum length",
- interactive=True,
- )
- top_p = gr.Slider(
- 0, 1, value=0.2, step=0.01, label="Top P", interactive=True
- )
- temperature = gr.Slider(
- 0.01, 1, value=0.85, step=0.01, label="Temperature", interactive=True
- )
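+# edge-tts helper: synthesize `text` with the voice mapped to the selected language/speaker label and
+# return the path of a temporary .mp3 file, which feeds the audio player and the voice-clone step.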
+async def text_to_speech_edge(text, language_code):
+ voice = language_dict[language_code]
+ communicate = edge_tts.Communicate(text, voice)
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
+ tmp_path = tmp_file.name
- history = gr.State([])
- past_key_values = gr.State(None)
-
- user_input.submit(
- predict,
- [
- RETRY_FLAG,
- user_input,
- chatbot,
- max_length,
- top_p,
- temperature,
- history,
- past_key_values,
- ],
- [chatbot, history, past_key_values],
- show_progress="full",
- )
- submitBtn.click(
- predict,
- [
- RETRY_FLAG,
- user_input,
- chatbot,
- max_length,
- top_p,
- temperature,
- history,
- past_key_values,
- ],
- [chatbot, history, past_key_values],
- show_progress="full",
- api_name="predict",
- )
- submitBtn.click(reset_user_input, [], [user_input])
+ await communicate.save(tmp_path)
- emptyBtn.click(
- reset_state, outputs=[chatbot, history, past_key_values], show_progress="full"
- )
+ return tmp_path
- retryBtn.click(
- retry_last_answer,
- inputs=[
- user_input,
- chatbot,
- max_length,
- top_p,
- temperature,
- history,
- past_key_values,
- ],
- # outputs = [chatbot, history, last_user_message, user_message]
- outputs=[chatbot, history, past_key_values],
- )
- deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
- with gr.Accordion("Example inputs", open=True):
- etext0 = """ "act": "作为基于文本的冒险游戏",\n "prompt": "我想让你扮演一个基于文本的冒险游戏。我在这个基于文本的冒险游戏中扮演一个角色。请尽可能具体地描述角色所看到的内容和环境,并在游戏输出1、2、3让用户选择进行回复,而不是其它方式。我将输入命令来告诉角色该做什么,而你需要回复角色的行动结果以推动游戏的进行。我的第一个命令是'醒来',请从这里开始故事 “ """
+with gr.Blocks(title="Bofan Ai", theme=gr.themes.Soft(text_size="sm"), analytics_enabled=False) as demo:
+ gr.HTML(""
+ "📺💕🎶 - BofanAi+声音克隆+视频对话:和喜欢的角色畅所欲言吧!
"
+ ""
+ """Bofan基于chatglm-6的微调模型如果喜欢请给个 ❤ 。遇到任何问题可邮件和我联系👉 fb700@qq.com"""
+ )
+ gr.Markdown("## 帛凡 Fitness AI 演示"
+ """特别声明:本应用仅为模型能力演示,无任何商业行为,部署资源为Huggingface官方免费提供,任何通过此项目产生的知识仅用于学术参考,作者和网站均不承担任何责任。"""
+ )
+
+ with gr.Tab("🍻 - BofanAi聊天区"):
+ with gr.Accordion("📒 相关信息", open=False):
+ _ = f""" BofanAi的可选参数信息:
+ * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
+ * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
+ * Top P controls dynamic vocabulary selection based on context.\n
+ 如果您想让BofanAi进行角色扮演并与之对话,请先输入恰当的提示词,如“请你扮演成动漫角色蜡笔小新并和我进行对话”;您也可以为BofanAi提供自定义的角色设定\n
+ 当您使用声音克隆功能时,请先在此程序的对应位置上传一段您喜欢的音频
+ ## {model_name}
+
+ ChatGLM-6B 是开源中英双语对话模型,本次训练基于ChatGLM-6B 的第一代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上开展训练。
+
+ 本项目经过多位网友实测,中文总结能力超越了GPT3.5各版本,健康咨询水平优于其它同量级模型,且经优化目前可以支持无限context,远大于4k、8K、16K......,可能是任何个人和中小企业首选模型。
+
+ *首先,用40万条高质量数据进行强化训练,以提高模型的基础能力;
+
+ *第二,使用30万条人类反馈数据,构建一个表达方式规范优雅的语言模式(RM模型);
+
+ *第三,在保留SFT阶段三分之一训练数据的同时,增加了30万条fitness数据,叠加RM模型,对ChatGLM-6B进行强化训练。
+
+ 通过训练我们对模型有了更深刻的认知,LLM在一直在进化,好的方法和数据可以挖掘出模型的更大潜能。
+ 训练中特别强化了中英文学术论文的翻译和总结,可以成为普通用户和科研人员的得力助手。
+
+ 免责声明:本应用仅为模型能力演示,无任何商业行为,部署资源为huggingface官方免费提供,任何通过此项目产生的知识仅用于学术参考,作者和网站均不承担任何责任 。
+
+ The T4 GPU is sponsored by a community GPU grant from Huggingface. Thanks a lot!
+
+ [模型下载地址](https://huggingface.co/fb700/chatglm-fitness-RLHF)
+ """
+ gr.Markdown(dedent(_))
+ chatbot = gr.Chatbot(height=300)
+ with gr.Row():
+ with gr.Column(scale=4):
+ with gr.Column(scale=12):
+ user_input = gr.Textbox(
+ label="请在此处和BofanAi聊天 (按回车键即可发送)",
+ placeholder="聊点什么吧",
+ )
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
+ with gr.Column(min_width=32, scale=1):
+ with gr.Row():
+ submitBtn = gr.Button("开始和BofanAi交流吧", variant="primary")
+ deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
+ retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
+
+ with gr.Accordion("🔧 更多设置", open=False):
+ with gr.Row():
+ emptyBtn = gr.Button("清空所有聊天记录")
+ max_length = gr.Slider(
+ 0,
+ 32768,
+ value=8192,
+ step=1.0,
+ label="Maximum length",
+ interactive=True,
+ )
+ top_p = gr.Slider(
+ 0, 1, value=0.2, step=0.01, label="Top P", interactive=True
+ )
+ temperature = gr.Slider(
+ 0.01, 1, value=0.85, step=0.01, label="Temperature", interactive=True
+ )
+
+
+ with gr.Row():
+ test1 = gr.Textbox(label="BofanAi的最新回答 (可编辑)", lines = 3)
+ with gr.Column():
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
+
+ tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
+
+ with gr.Row():
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
+ audio1 = output_audio
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
+
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
+
+ history = gr.State([])
+ past_key_values = gr.State(None)
+
+ user_input.submit(
+ predict,
+ [
+ RETRY_FLAG,
+ user_input,
+ chatbot,
+ max_length,
+ top_p,
+ temperature,
+ history,
+ past_key_values,
+ ],
+ [chatbot, history, past_key_values, test1],
+ show_progress="full",
+ )
+ submitBtn.click(
+ predict,
+ [
+ RETRY_FLAG,
+ user_input,
+ chatbot,
+ max_length,
+ top_p,
+ temperature,
+ history,
+ past_key_values,
+ ],
+ [chatbot, history, past_key_values, test1],
+ show_progress="full",
+ api_name="predict",
+ )
+ submitBtn.click(reset_user_input, [], [user_input])
+
+ emptyBtn.click(
+ reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
+ )
+
+ retryBtn.click(
+ retry_last_answer,
+ inputs=[
+ user_input,
+ chatbot,
+ max_length,
+ top_p,
+ temperature,
+ history,
+ past_key_values,
+ ],
+ # outputs = [chatbot, history, last_user_message, user_message]
+ outputs=[chatbot, history, past_key_values, test1],
+ )
+ deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
+
+ with gr.Accordion("📔 模型应用测试", open=True):
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+        etext0 = """ "act": "作为基于文本的冒险游戏",\n "prompt": "我想让你扮演一个基于文本的冒险游戏。我在这个基于文本的冒险游戏中扮演一个角色。请尽可能具体地描述角色所看到的内容和环境,并在游戏输出1、2、3让用户选择进行回复,而不是其它方式。我将输入命令来告诉角色该做什么,而你需要回复角色的行动结果以推动游戏的进行。我的第一个命令是'醒来',请从这里开始故事 “ """
+        examples = gr.Examples(
+            examples=[
etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
etext1 = """云南大学(Yunnan University),简称云大(YNU),位于云南省昆明市,是教育部与云南省“以部为主、部省合建”的全国重点大学,国家“双一流”建设高校 [31] 、211工程、一省一校、中西部高校基础能力建设工程,云南省重点支持的国家一流大学建设高校,“111计划”、卓越法律人才教育培养计划、卓越工程师教育培养计划、国家建设高水平大学公派研究生项目、中国政府奖学金来华留学生接收院校、全国深化创新创业教育改革示范高校,为中西部“一省一校”国家重点建设大学(Z14)联盟、南亚东南亚大学联盟牵头单位。 [1]
云南大学始建于1922年,时为私立东陆大学。1930年,改为省立东陆大学。1934年更名为省立云南大学。1938年改为国立云南大学。1946年,《不列颠百科全书》将云南大学列为中国15所在世界最具影响的大学之一。1950年定名为云南大学。1958年,云南大学由中央高教部划归云南省管理。1978年,云南大学被国务院确定为88所全国重点大学之一。1996年首批列入国家“211工程”重点建设大学。1999年,云南政法高等专科学校并入云南大学。 [2] [23]
@@ -370,37 +546,120 @@ with gr.Blocks(title="Bofan Ai", theme=gr.themes.Soft(text_size="sm")) as demo:
["Erkläre die Handlung von Cinderella in einem Satz."],
["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
],
- inputs=[user_input],
- examples_per_page=50,
+ inputs=[user_input],
+ examples_per_page=50,
+ )
+
+ with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+ input_text = gr.Text()
+ tr_btn = gr.Button("Go", variant="primary")
+ out_text = gr.Text()
+ tr_btn.click(
+ trans_api,
+ [input_text, max_length, top_p, temperature],
+ out_text,
+ # show_progress="full",
+ api_name="tr",
)
-
- with gr.Accordion("For Chat/Translation API", open=False, visible=False):
- input_text = gr.Text()
- tr_btn = gr.Button("Go", variant="primary")
- out_text = gr.Text()
- tr_btn.click(
- trans_api,
- [input_text, max_length, top_p, temperature],
- out_text,
- # show_progress="full",
- api_name="tr",
- )
- _ = """
- input_text.submit(
- trans_api,
- [input_text, max_length, top_p, temperature],
- out_text,
- show_progress="full",
- api_name="tr1",
- )
- # """
-
-# demo.queue().launch(share=False, inbrowser=True)
-# demo.queue().launch(share=True, inbrowser=True, debug=True)
-
-# concurrency_count > 1 requires more memory, max_size: queue size
-# T4 medium: 30GB, model size: ~4G concurrency_count = 6
-# leave one for api access
-# reduce to 5 if OOM occurs to often
-
-demo.queue(concurrency_count=6, max_size=30).launch(debug=True)
+ _ = """
+ input_text.submit(
+ trans_api,
+ [input_text, max_length, top_p, temperature],
+ out_text,
+ show_progress="full",
+ api_name="tr1",
+ )
+ # """
+ with gr.Tab("📺 - 视频聊天区"):
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="sadtalker_source_image"):
+ with gr.TabItem('图片上传'):
+ with gr.Row():
+ source_image = gr.Image(label="请上传一张您喜欢角色的图片", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
+
+
+ with gr.Tabs(elem_id="sadtalker_driven_audio"):
+ with gr.TabItem('💡您还可以将视频下载到本地'):
+
+ with gr.Row():
+ driven_audio = audio_cloned
+ driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)
+
+ with gr.Column():
+ use_idle_mode = gr.Checkbox(label="Use Idle Animation", visible=False)
+ length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.", visible=False)
+ use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo
+
+ with gr.Row():
+ ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref", visible=False).style(width=512)
+
+ with gr.Column():
+ use_ref_video = gr.Checkbox(label="Use Reference Video", visible=False)
+ ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video',info="How to borrow from reference Video?((fully transfer, aka, video driving mode))", visible=False)
+
+ ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo
+
+
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
+ with gr.TabItem('视频设置'):
+ with gr.Column(variant='panel'):
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
+ with gr.Row():
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0, visible=False) #
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1, visible=False) #
+ blink_every = gr.Checkbox(label="use eye blink", value=True, visible=False)
+
+ with gr.Row():
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?", visible=False) #
+ preprocess_type = gr.Radio(['crop', 'full'], value='crop', label='是否聚焦角色面部', info="crop:视频会聚焦角色面部;full:视频会显示图片全貌")
+
+ with gr.Row():
+ is_still_mode = gr.Checkbox(label="静态模式 (开启静态模式,角色的面部动作会减少;默认开启)", value=True)
+ facerender = gr.Radio(['facevid2vid','pirender'], value='facevid2vid', label='facerender', info="which face render?", visible=False)
+
+ with gr.Row():
+ batch_size = gr.Slider(label="Batch size (数值越大,生成速度越快;若显卡性能好,可增大数值)", step=1, maximum=32, value=2)
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer", value=True, visible=False)
+
+ submit = gr.Button('开始视频聊天吧', elem_id="sadtalker_generate", variant='primary')
+
+ with gr.Tabs(elem_id="sadtalker_genearted"):
+ gen_video = gr.Video(label="为您生成的专属视频", format="mp4").style(width=256)
+
+
+
+ submit.click(
+ fn=sad_talker.test,
+ inputs=[source_image,
+ driven_audio,
+ preprocess_type,
+ is_still_mode,
+ enhancer,
+ batch_size,
+ size_of_image,
+ pose_style,
+ facerender,
+ exp_weight,
+ use_ref_video,
+ ref_video,
+ ref_info,
+ use_idle_mode,
+ length_of_audio,
+ blink_every
+ ],
+ outputs=[gen_video]
+ )
+ gr.Markdown("### 注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。")
+ gr.Markdown("💡- 如何使用此程序:输入您对ChatGLM的提问后,依次点击“开始和BofanAi交流吧”、“生成对应的音频吧”、“开始AI声音克隆吧”、“开始视频聊天吧”四个按键即可;使用声音克隆功能时,请先上传一段您喜欢的音频")
+ gr.HTML('''
+
+ ''')
+
+
+demo.queue().launch(show_error=True, debug=True)
diff --git a/checkpoint/__init__.py b/checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoint/freevc-24.pth b/checkpoint/freevc-24.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d256c31ffd327d6002980112b891f0db5f7ae849
--- /dev/null
+++ b/checkpoint/freevc-24.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20
+size 472644351
diff --git a/checkpoints/BFM_Fitting/01_MorphableModel.mat b/checkpoints/BFM_Fitting/01_MorphableModel.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/BFM_Fitting/BFM09_model_info.mat b/checkpoints/BFM_Fitting/BFM09_model_info.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/BFM_Fitting/BFM_exp_idx.mat b/checkpoints/BFM_Fitting/BFM_exp_idx.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/BFM_Fitting/BFM_front_idx.mat b/checkpoints/BFM_Fitting/BFM_front_idx.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/BFM_Fitting/facemodel_info.mat b/checkpoints/BFM_Fitting/facemodel_info.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/BFM_Fitting/select_vertex_id.mat b/checkpoints/BFM_Fitting/select_vertex_id.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/BFM_Fitting/similarity_Lm3D_all.mat b/checkpoints/BFM_Fitting/similarity_Lm3D_all.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/BFM_Fitting/std_exp.txt b/checkpoints/BFM_Fitting/std_exp.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/shape_predictor_68_face_landmarks.dat b/checkpoints/shape_predictor_68_face_landmarks.dat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/commons.py b/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc384912618494475bda9d68fa76530f4fe2a27b
--- /dev/null
+++ b/commons.py
@@ -0,0 +1,171 @@
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size*dilation - dilation)/2)
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def intersperse(lst, item):
+ result = [item] * (len(lst) * 2 + 1)
+ result[1::2] = lst
+ return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ """KL(P||Q)"""
+ kl = (logs_q - logs_p) - 0.5
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
+ return kl
+
+
+def rand_gumbel(shape):
+ """Sample from the Gumbel distribution, protect from overflows."""
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+ return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+ return g
+
+
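+# Slice fixed-length segments from a batch: slice_segments gathers [ids_str, ids_str + segment_size)
+# per example; rand_slice_segments / rand_spec_segments draw the start indices at random.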
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def rand_spec_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = (
+ math.log(float(max_timescale) / float(min_timescale)) /
+ (num_timescales - 1))
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
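+# WaveNet-style gated activation, fused for TorchScript: tanh over the first n_channels of (a + b)
+# multiplied by sigmoid over the remaining channels.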
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
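+# Boolean mask of shape [batch, max_length] that is True inside each sequence's valid length.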
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+ """
+ duration: [b, 1, t_x]
+ mask: [b, 1, t_y, t_x]
+ """
+ device = duration.device
+
+ b, _, t_y, t_x = mask.shape
+ cum_duration = torch.cumsum(duration, -1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+ path = path.view(b, t_x, t_y)
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+ path = path.unsqueeze(1).transpose(2,3) * mask
+ return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1. / norm_type)
+ return total_norm
diff --git a/configs/freevc-24.json b/configs/freevc-24.json
new file mode 100644
index 0000000000000000000000000000000000000000..8454a2500a815ae82ea2db4de1ccd54d8c101887
--- /dev/null
+++ b/configs/freevc-24.json
@@ -0,0 +1,54 @@
+{
+ "train": {
+ "log_interval": 200,
+ "eval_interval": 10000,
+ "seed": 1234,
+ "epochs": 10000,
+ "learning_rate": 2e-4,
+ "betas": [0.8, 0.99],
+ "eps": 1e-9,
+ "batch_size": 64,
+ "fp16_run": false,
+ "lr_decay": 0.999875,
+ "segment_size": 8640,
+ "init_lr_ratio": 1,
+ "warmup_epochs": 0,
+ "c_mel": 45,
+ "c_kl": 1.0,
+ "use_sr": true,
+ "max_speclen": 128,
+ "port": "8008"
+ },
+ "data": {
+ "training_files":"filelists/train.txt",
+ "validation_files":"filelists/val.txt",
+ "max_wav_value": 32768.0,
+ "sampling_rate": 16000,
+ "filter_length": 1280,
+ "hop_length": 320,
+ "win_length": 1280,
+ "n_mel_channels": 80,
+ "mel_fmin": 0.0,
+ "mel_fmax": null
+ },
+ "model": {
+ "inter_channels": 192,
+ "hidden_channels": 192,
+ "filter_channels": 768,
+ "n_heads": 2,
+ "n_layers": 6,
+ "kernel_size": 3,
+ "p_dropout": 0.1,
+ "resblock": "1",
+ "resblock_kernel_sizes": [3,7,11],
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+ "upsample_rates": [10,6,4,2],
+ "upsample_initial_channel": 512,
+ "upsample_kernel_sizes": [16,16,4,4],
+ "n_layers_q": 3,
+ "use_spectral_norm": false,
+ "gin_channels": 256,
+ "ssl_dim": 1024,
+ "use_spk": true
+ }
+}
diff --git a/mel_processing.py b/mel_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..99c5b35beb83f3b288af0fac5b49ebf2c69f062c
--- /dev/null
+++ b/mel_processing.py
@@ -0,0 +1,112 @@
+import math
+import os
+import random
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.utils.data
+import numpy as np
+import librosa
+import librosa.util as librosa_util
+from librosa.util import normalize, pad_center, tiny
+from scipy.signal import get_window
+from scipy.io.wavfile import read
+from librosa.filters import mel as librosa_mel_fn
+
+MAX_WAV_VALUE = 32768.0
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
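+# Linear/mel spectrogram helpers; Hann windows and mel filterbanks are cached per (dtype, device)
+# in the dicts above so repeated calls do not rebuild them.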
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+ global hann_window
+ dtype_device = str(y.dtype) + '_' + str(y.device)
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+ return spec
+
+
+def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+ global mel_basis
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
+ if fmax_dtype_device not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+ return spec
+
+
+def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+ global mel_basis, hann_window
+ dtype_device = str(y.dtype) + '_' + str(y.device)
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
+ if fmax_dtype_device not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..46b8aacb1bef18f6fad4c20c968b19125626799c
--- /dev/null
+++ b/models.py
@@ -0,0 +1,351 @@
+import copy
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import commons
+import modules
+
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from commons import init_weights, get_padding
+
+
+class ResidualCouplingBlock(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0):
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
+ self.flows.append(modules.Flip())
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ if not reverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ else:
+ for flow in reversed(self.flows):
+ x = flow(x, x_mask, g=g, reverse=reverse)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths, g=None):
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+ x = self.pre(x) * x_mask
+ x = self.enc(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+ return z, m, logs, x_mask
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
+ super(Generator, self).__init__()
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(weight_norm(
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
+ k, u, padding=(k-u)//2)))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel//(2**(i+1))
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(resblock(ch, k, d))
+
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+ self.ups.apply(init_weights)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+ def forward(self, x, g=None):
+ x = self.conv_pre(x)
+ if g is not None:
+ x = x + self.cond(g)
+
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i*self.num_kernels+j](x)
+ else:
+ xs += self.resblocks[i*self.num_kernels+j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ self.use_spectral_norm = use_spectral_norm
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class DiscriminatorS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ])
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(MultiPeriodDiscriminator, self).__init__()
+ periods = [2,3,5,7,11]
+
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
+ self.discriminators = nn.ModuleList(discs)
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ y_d_gs.append(y_d_g)
+ fmap_rs.append(fmap_r)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
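+# LSTM speaker encoder: embed_utterance averages the L2-normalized embeddings of overlapping partial
+# mel windows into a single fixed-size speaker embedding per utterance.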
+class SpeakerEncoder(torch.nn.Module):
+ def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
+ super(SpeakerEncoder, self).__init__()
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
+ self.relu = nn.ReLU()
+
+ def forward(self, mels):
+ self.lstm.flatten_parameters()
+ _, (hidden, _) = self.lstm(mels)
+ embeds_raw = self.relu(self.linear(hidden[-1]))
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
+
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
+ mel_slices = []
+ for i in range(0, total_frames-partial_frames, partial_hop):
+ mel_range = torch.arange(i, i+partial_frames)
+ mel_slices.append(mel_range)
+
+ return mel_slices
+
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
+ mel_len = mel.size(1)
+ last_mel = mel[:,-partial_frames:]
+
+ if mel_len > partial_frames:
+ mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
+ mels = list(mel[:,s] for s in mel_slices)
+ mels.append(last_mel)
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
+
+ with torch.no_grad():
+ partial_embeds = self(mels)
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
+ #embed = embed / torch.linalg.norm(embed, 2)
+ else:
+ with torch.no_grad():
+ embed = self(last_mel)
+
+ return embed
+
+
+class SynthesizerTrn(nn.Module):
+ """
+ Synthesizer for Training
+ """
+
+ def __init__(self,
+ spec_channels,
+ segment_size,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels,
+ ssl_dim,
+ use_spk,
+ **kwargs):
+
+ super().__init__()
+ self.spec_channels = spec_channels
+ self.inter_channels = inter_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.resblock = resblock
+ self.resblock_kernel_sizes = resblock_kernel_sizes
+ self.resblock_dilation_sizes = resblock_dilation_sizes
+ self.upsample_rates = upsample_rates
+ self.upsample_initial_channel = upsample_initial_channel
+ self.upsample_kernel_sizes = upsample_kernel_sizes
+ self.segment_size = segment_size
+ self.gin_channels = gin_channels
+ self.ssl_dim = ssl_dim
+ self.use_spk = use_spk
+
+ self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16)
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
+
+ if not self.use_spk:
+ self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels)
+
+ def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
+        if c_lengths is None:
+            c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
+        if spec_lengths is None:
+ spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
+
+ if not self.use_spk:
+ g = self.enc_spk(mel.transpose(1,2))
+ g = g.unsqueeze(-1)
+
+ _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
+ z_p = self.flow(z, spec_mask, g=g)
+
+ z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
+ o = self.dec(z_slice, g=g)
+
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
+
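+    # Voice-conversion inference: encode the content features into the prior, run the flow in reverse
+    # conditioned on the speaker embedding g (or one derived from mel), then decode the waveform.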
+ def infer(self, c, g=None, mel=None, c_lengths=None):
+        if c_lengths is None:
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
+ if not self.use_spk:
+ g = self.enc_spk.embed_utterance(mel.transpose(1,2))
+ g = g.unsqueeze(-1)
+
+ z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
+ o = self.dec(z * c_mask, g=g)
+
+ return o
diff --git a/modules.py b/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..52ee14e41a5b6d67d875d1b694aecd2a51244897
--- /dev/null
+++ b/modules.py
@@ -0,0 +1,342 @@
+import copy
+import math
+import numpy as np
+import scipy
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+import commons
+from commons import init_weights, get_padding
+
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 1."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(
+ nn.ReLU(),
+ nn.Dropout(p_dropout))
+ for _ in range(n_layers-1):
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class DDSConv(nn.Module):
+ """
+    Dilated and Depth-Separable Convolution
+ """
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
+ super().__init__()
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+
+ self.drop = nn.Dropout(p_dropout)
+ self.convs_sep = nn.ModuleList()
+ self.convs_1x1 = nn.ModuleList()
+ self.norms_1 = nn.ModuleList()
+ self.norms_2 = nn.ModuleList()
+ for i in range(n_layers):
+ dilation = kernel_size ** i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
+ groups=channels, dilation=dilation, padding=padding
+ ))
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+ self.norms_1.append(LayerNorm(channels))
+ self.norms_2.append(LayerNorm(channels))
+
+ def forward(self, x, x_mask, g=None):
+ if g is not None:
+ x = x + g
+ for i in range(self.n_layers):
+ y = self.convs_sep[i](x * x_mask)
+ y = self.norms_1[i](y)
+ y = F.gelu(y)
+ y = self.convs_1x1[i](y)
+ y = self.norms_2[i](y)
+ y = F.gelu(y)
+ y = self.drop(y)
+ x = x + y
+ return x * x_mask
+
+
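+# Non-causal WaveNet block with gated convolutions and optional global conditioning (gin_channels);
+# used by the content and posterior encoders in models.py.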
+class WN(torch.nn.Module):
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
+ super(WN, self).__init__()
+ assert(kernel_size % 2 == 1)
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if gin_channels != 0:
+ cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = dilation_rate ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, x_mask, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = commons.fused_add_tanh_sigmoid_multiply(
+ x_in,
+ g_l,
+ n_channels_tensor)
+ acts = self.drop(acts)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ res_acts = res_skip_acts[:,:self.hidden_channels,:]
+ x = (x + res_acts) * x_mask
+ output = output + res_skip_acts[:,self.hidden_channels:,:]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ if self.gin_channels != 0:
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
+ for l in self.in_layers:
+ torch.nn.utils.remove_weight_norm(l)
+ for l in self.res_skip_layers:
+ torch.nn.utils.remove_weight_norm(l)
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c2(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.convs = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1])))
+ ])
+ self.convs.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Log(nn.Module):
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ if not reverse:
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+ else:
+ return x
+
+
+class ElementwiseAffine(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.channels = channels
+ self.m = nn.Parameter(torch.zeros(channels,1))
+ self.logs = nn.Parameter(torch.zeros(channels,1))
+
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1,2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels]*2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels]*2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1,2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ return x
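+
+
+# Minimal self-check sketch (not part of the upstream VITS-style modules; it assumes the
+# torch/nn/commons imports at the top of this file): a ResidualCouplingLayer is an
+# invertible flow step, so running it forward and then with reverse=True should recover
+# the input up to numerical error.
+if __name__ == "__main__":
+    batch, channels, frames = 2, 192, 50           # arbitrary toy sizes
+    layer = ResidualCouplingLayer(channels, 96, 5, 1, 4, mean_only=True)
+    layer.eval()
+    x = torch.randn(batch, channels, frames)
+    x_mask = torch.ones(batch, 1, frames)
+    y, logdet = layer(x, x_mask)                   # forward returns (y, logdet)
+    x_rec = layer(y, x_mask, reverse=True)         # reverse returns only x
+    print("max reconstruction error:", (x - x_rec).abs().max().item())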
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3101f5ec0d503c773ae2fcc863e1594cb689fc69
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,2 @@
+ffmpeg
+libsndfile1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c368705dc5e9f407c87457935c299b480498f01a..7e64293d8fe3ab234e1b1c5b09a22f0f63e692d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,35 @@
+torch
+torchvision
+torchaudio
+numpy==1.22.0
+face_alignment==1.3.0
+imageio==2.19.3
+imageio-ffmpeg==0.4.7
+librosa==0.8.1
+numba
+resampy==0.3.1
+pydub==0.25.1
+scipy
+kornia==0.6.8
+tqdm
+yacs==0.1.8
+pyyaml
+joblib==1.1.0
+scikit-image==0.19.3
+basicsr==1.4.2
+facexlib==0.3.0
+dlib-bin
+gfpgan
+av
+safetensors
+transformers
+webrtcvad==2.0.10
protobuf
-transformers==4.30.2
cpm_kernels
-torch>=2.0
-# gradio
mdtex2html
sentencepiece
accelerate
-loguru
\ No newline at end of file
+loguru
+edge_tts
+altair
+gradio==3.36.1
\ No newline at end of file
diff --git a/speaker_encoder/__init__.py b/speaker_encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/speaker_encoder/__init__.py
@@ -0,0 +1 @@
+
diff --git a/speaker_encoder/audio.py b/speaker_encoder/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcb77ad1d3a85f523e24f84691886736a5686cb
--- /dev/null
+++ b/speaker_encoder/audio.py
@@ -0,0 +1,107 @@
+from scipy.ndimage import binary_dilation
+from speaker_encoder.params_data import *
+from pathlib import Path
+from typing import Optional, Union
+import numpy as np
+import webrtcvad
+import librosa
+import struct
+
+int16_max = (2 ** 15) - 1
+
+
+def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+ source_sr: Optional[int] = None):
+ """
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+ just .wav), or the waveform as a numpy array of floats.
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+ this argument will be ignored.
+ """
+ # Load the wav from disk if needed
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+ wav, source_sr = librosa.load(fpath_or_wav, sr=None)
+ else:
+ wav = fpath_or_wav
+
+ # Resample the wav if needed
+ if source_sr is not None and source_sr != sampling_rate:
+ wav = librosa.resample(wav, source_sr, sampling_rate)
+
+ # Apply the preprocessing: normalize volume and shorten long silences
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+ wav = trim_long_silences(wav)
+
+ return wav
+
+
+def wav_to_mel_spectrogram(wav):
+ """
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+ Note: this is not a log-mel spectrogram.
+ """
+ frames = librosa.feature.melspectrogram(
+ y=wav,
+ sr=sampling_rate,
+ n_fft=int(sampling_rate * mel_window_length / 1000),
+ hop_length=int(sampling_rate * mel_window_step / 1000),
+ n_mels=mel_n_channels
+ )
+ return frames.astype(np.float32).T
+
+
+def trim_long_silences(wav):
+ """
+ Ensures that segments without voice in the waveform remain no longer than a
+ threshold determined by the VAD parameters in params_data.py.
+
+ :param wav: the raw waveform as a numpy array of floats
+ :return: the same waveform with silences trimmed away (length <= original wav length)
+ """
+ # Compute the voice detection window size
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+ # Trim the end of the audio to have a multiple of the window size
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+ # Convert the float waveform to 16-bit mono PCM
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+ # Perform voice activation detection
+ voice_flags = []
+ vad = webrtcvad.Vad(mode=3)
+ for window_start in range(0, len(wav), samples_per_window):
+ window_end = window_start + samples_per_window
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+ sample_rate=sampling_rate))
+ voice_flags = np.array(voice_flags)
+
+ # Smooth the voice detection with a moving average
+ def moving_average(array, width):
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+ ret = np.cumsum(array_padded, dtype=float)
+ ret[width:] = ret[width:] - ret[:-width]
+ return ret[width - 1:] / width
+
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
+ audio_mask = np.round(audio_mask).astype(bool)
+
+ # Dilate the voiced regions
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+ audio_mask = np.repeat(audio_mask, samples_per_window)
+
+ return wav[audio_mask == True]
+
+
+def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+ if increase_only and decrease_only:
+ raise ValueError("Both increase only and decrease only are set")
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+ return wav
+ return wav * (10 ** (dBFS_change / 20))
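+
+
+# Minimal usage sketch (the file name below is only a placeholder, not part of this repo):
+# load an arbitrary recording, normalize its volume, trim long silences with the VAD, and
+# convert it to the (n_frames, mel_n_channels) mel frames the speaker encoder expects.
+if __name__ == "__main__":
+    wav = preprocess_wav("my_recording.wav")
+    mel = wav_to_mel_spectrogram(wav)
+    print("samples:", len(wav), "mel frames:", mel.shape)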
diff --git a/speaker_encoder/ckpt/__init__.py b/speaker_encoder/ckpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/speaker_encoder/ckpt/__init__.py
@@ -0,0 +1 @@
+
diff --git a/speaker_encoder/ckpt/pretrained_bak_5805000.pt b/speaker_encoder/ckpt/pretrained_bak_5805000.pt
new file mode 100644
index 0000000000000000000000000000000000000000..662d22b686114b4b6124330a688007d9495d22c8
--- /dev/null
+++ b/speaker_encoder/ckpt/pretrained_bak_5805000.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
+size 17090379
diff --git a/speaker_encoder/compute_embed.py b/speaker_encoder/compute_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fee33db0168f40efc42145c06fa62016e3e008e
--- /dev/null
+++ b/speaker_encoder/compute_embed.py
@@ -0,0 +1,40 @@
+from speaker_encoder import inference as encoder
+from multiprocessing.pool import Pool
+from functools import partial
+from pathlib import Path
+from tqdm import tqdm
+import numpy as np
+# from utils import logmmse
+# import librosa
+
+
+def embed_utterance(fpaths, encoder_model_fpath):
+ if not encoder.is_loaded():
+ encoder.load_model(encoder_model_fpath)
+
+ # Compute the speaker embedding of the utterance
+ wav_fpath, embed_fpath = fpaths
+ wav = np.load(wav_fpath)
+ wav = encoder.preprocess_wav(wav)
+ embed = encoder.embed_utterance(wav)
+ np.save(embed_fpath, embed, allow_pickle=False)
+
+
+def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
+
+ wav_dir = outdir_root.joinpath("audio")
+ metadata_fpath = outdir_root.joinpath("train.txt")
+ assert wav_dir.exists() and metadata_fpath.exists()
+ embed_dir = outdir_root.joinpath("embeds")
+ embed_dir.mkdir(exist_ok=True)
+
+ # Gather the input wave filepath and the target output embed filepath
+ with metadata_fpath.open("r") as metadata_file:
+ metadata = [line.split("|") for line in metadata_file]
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
+
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
+ # Embed the utterances in separate threads
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
+ job = Pool(n_processes).imap(func, fpaths)
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
\ No newline at end of file
diff --git a/speaker_encoder/config.py b/speaker_encoder/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c21312f3de971bfa008254c6035cebc09f05e4c
--- /dev/null
+++ b/speaker_encoder/config.py
@@ -0,0 +1,45 @@
+librispeech_datasets = {
+ "train": {
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
+ "other": ["LibriSpeech/train-other-500"]
+ },
+ "test": {
+ "clean": ["LibriSpeech/test-clean"],
+ "other": ["LibriSpeech/test-other"]
+ },
+ "dev": {
+ "clean": ["LibriSpeech/dev-clean"],
+ "other": ["LibriSpeech/dev-other"]
+ },
+}
+libritts_datasets = {
+ "train": {
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
+ "other": ["LibriTTS/train-other-500"]
+ },
+ "test": {
+ "clean": ["LibriTTS/test-clean"],
+ "other": ["LibriTTS/test-other"]
+ },
+ "dev": {
+ "clean": ["LibriTTS/dev-clean"],
+ "other": ["LibriTTS/dev-other"]
+ },
+}
+voxceleb_datasets = {
+ "voxceleb1" : {
+ "train": ["VoxCeleb1/wav"],
+ "test": ["VoxCeleb1/test_wav"]
+ },
+ "voxceleb2" : {
+ "train": ["VoxCeleb2/dev/aac"],
+ "test": ["VoxCeleb2/test_wav"]
+ }
+}
+
+other_datasets = [
+ "LJSpeech-1.1",
+ "VCTK-Corpus/wav48",
+]
+
+anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
diff --git a/speaker_encoder/data_objects/__init__.py b/speaker_encoder/data_objects/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..030317a1d9a328d452bf29bc7a802e29629b1a42
--- /dev/null
+++ b/speaker_encoder/data_objects/__init__.py
@@ -0,0 +1,2 @@
+from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
diff --git a/speaker_encoder/data_objects/random_cycler.py b/speaker_encoder/data_objects/random_cycler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c405db6b27f46d874d8feb37e3f9c1e12c251109
--- /dev/null
+++ b/speaker_encoder/data_objects/random_cycler.py
@@ -0,0 +1,37 @@
+import random
+
+class RandomCycler:
+ """
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
+ order. For a source sequence of n items and one or several consecutive queries of a total
+ of m items, the following guarantees hold (one implies the other):
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
+ """
+
+ def __init__(self, source):
+ if len(source) == 0:
+ raise Exception("Can't create RandomCycler from an empty collection")
+ self.all_items = list(source)
+ self.next_items = []
+
+ def sample(self, count: int):
+ shuffle = lambda l: random.sample(l, len(l))
+
+ out = []
+ while count > 0:
+ if count >= len(self.all_items):
+ out.extend(shuffle(list(self.all_items)))
+ count -= len(self.all_items)
+ continue
+ n = min(count, len(self.next_items))
+ out.extend(self.next_items[:n])
+ count -= n
+ self.next_items = self.next_items[n:]
+ if len(self.next_items) == 0:
+ self.next_items = shuffle(list(self.all_items))
+ return out
+
+ def __next__(self):
+ return self.sample(1)[0]
+
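+# Usage sketch illustrating the guarantee documented above: drawing m = 6 items from an
+# n = 3 item source yields every item exactly m // n = 2 times (in a randomized order).
+if __name__ == "__main__":
+    cycler = RandomCycler("abc")
+    draws = cycler.sample(6)
+    assert all(draws.count(item) == 2 for item in "abc")
+    print(draws)
+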
diff --git a/speaker_encoder/data_objects/speaker.py b/speaker_encoder/data_objects/speaker.py
new file mode 100644
index 0000000000000000000000000000000000000000..07379847a854d85623db02ce5e5409c1566eb80c
--- /dev/null
+++ b/speaker_encoder/data_objects/speaker.py
@@ -0,0 +1,40 @@
+from speaker_encoder.data_objects.random_cycler import RandomCycler
+from speaker_encoder.data_objects.utterance import Utterance
+from pathlib import Path
+
+# Contains the set of utterances of a single speaker
+class Speaker:
+ def __init__(self, root: Path):
+ self.root = root
+ self.name = root.name
+ self.utterances = None
+ self.utterance_cycler = None
+
+ def _load_utterances(self):
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
+ sources = [l.split(",") for l in sources_file]
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
+ self.utterance_cycler = RandomCycler(self.utterances)
+
+ def random_partial(self, count, n_frames):
+ """
+ Samples a batch of unique partial utterances from the disk in a way that all
+ utterances come up at least once every two cycles and in a random order every time.
+
+ :param count: The number of partial utterances to sample from the set of utterances from
+ that speaker. Utterances are guaranteed not to be repeated if count is not larger than
+ the number of utterances available.
+ :param n_frames: The number of frames in the partial utterance.
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
+ frames are the frames of the partial utterances and range is the range of the partial
+ utterance with regard to the complete utterance.
+ """
+ if self.utterances is None:
+ self._load_utterances()
+
+ utterances = self.utterance_cycler.sample(count)
+
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
+
+ return a
diff --git a/speaker_encoder/data_objects/speaker_batch.py b/speaker_encoder/data_objects/speaker_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4485605e3ece5b491d1e7d0f223c543b6c91eb96
--- /dev/null
+++ b/speaker_encoder/data_objects/speaker_batch.py
@@ -0,0 +1,12 @@
+import numpy as np
+from typing import List
+from speaker_encoder.data_objects.speaker import Speaker
+
+class SpeakerBatch:
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
+ self.speakers = speakers
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
+
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
diff --git a/speaker_encoder/data_objects/speaker_verification_dataset.py b/speaker_encoder/data_objects/speaker_verification_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cecd8ed8ac100b80d5087fa47f22f92c84fea032
--- /dev/null
+++ b/speaker_encoder/data_objects/speaker_verification_dataset.py
@@ -0,0 +1,56 @@
+from speaker_encoder.data_objects.random_cycler import RandomCycler
+from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
+from speaker_encoder.data_objects.speaker import Speaker
+from speaker_encoder.params_data import partials_n_frames
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+
+# TODO: improve with a pool of speakers for data efficiency
+
+class SpeakerVerificationDataset(Dataset):
+ def __init__(self, datasets_root: Path):
+ self.root = datasets_root
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+ if len(speaker_dirs) == 0:
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
+ "containing all preprocessed speaker directories.")
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
+ self.speaker_cycler = RandomCycler(self.speakers)
+
+ def __len__(self):
+ return int(1e10)
+
+ def __getitem__(self, index):
+ return next(self.speaker_cycler)
+
+ def get_logs(self):
+ log_string = ""
+ for log_fpath in self.root.glob("*.txt"):
+ with log_fpath.open("r") as log_file:
+ log_string += "".join(log_file.readlines())
+ return log_string
+
+
+class SpeakerVerificationDataLoader(DataLoader):
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
+ worker_init_fn=None):
+ self.utterances_per_speaker = utterances_per_speaker
+
+ super().__init__(
+ dataset=dataset,
+ batch_size=speakers_per_batch,
+ shuffle=False,
+ sampler=sampler,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ collate_fn=self.collate,
+ pin_memory=pin_memory,
+ drop_last=False,
+ timeout=timeout,
+ worker_init_fn=worker_init_fn
+ )
+
+ def collate(self, speakers):
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
+
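+# Usage sketch (assumes datasets_root already contains per-speaker folders produced by
+# speaker_encoder/preprocess.py; "encoder_preprocessed" below is a placeholder path):
+if __name__ == "__main__":
+    dataset = SpeakerVerificationDataset(Path("encoder_preprocessed"))
+    loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=4, utterances_per_speaker=5, num_workers=0)
+    batch = next(iter(loader))    # a SpeakerBatch
+    print(batch.data.shape)       # (4 * 5, partials_n_frames, mel_n_channels)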
\ No newline at end of file
diff --git a/speaker_encoder/data_objects/utterance.py b/speaker_encoder/data_objects/utterance.py
new file mode 100644
index 0000000000000000000000000000000000000000..0768c3420f422a7464f305b4c1fb6752c57ceda7
--- /dev/null
+++ b/speaker_encoder/data_objects/utterance.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+
+class Utterance:
+ def __init__(self, frames_fpath, wave_fpath):
+ self.frames_fpath = frames_fpath
+ self.wave_fpath = wave_fpath
+
+ def get_frames(self):
+ return np.load(self.frames_fpath)
+
+ def random_partial(self, n_frames):
+ """
+ Crops the frames into a partial utterance of n_frames
+
+ :param n_frames: The number of frames of the partial utterance
+ :return: the partial utterance frames and a tuple indicating the start and end of the
+ partial utterance in the complete utterance.
+ """
+ frames = self.get_frames()
+ if frames.shape[0] == n_frames:
+ start = 0
+ else:
+ start = np.random.randint(0, frames.shape[0] - n_frames)
+ end = start + n_frames
+ return frames[start:end], (start, end)
\ No newline at end of file
diff --git a/speaker_encoder/hparams.py b/speaker_encoder/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a8c16471903b0c92253b1d70fcd6a61d10e085f
--- /dev/null
+++ b/speaker_encoder/hparams.py
@@ -0,0 +1,31 @@
+## Mel-filterbank
+mel_window_length = 25 # In milliseconds
+mel_window_step = 10 # In milliseconds
+mel_n_channels = 40
+
+
+## Audio
+sampling_rate = 16000
+# Number of spectrogram frames in a partial utterance
+partials_n_frames = 160 # 1600 ms
+
+
+## Voice Activation Detection
+# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+# This sets the granularity of the VAD. Should not need to be changed.
+vad_window_length = 30 # In milliseconds
+# Number of frames to average together when performing the moving average smoothing.
+# The larger this value, the larger the VAD variations must be to not get smoothed out.
+vad_moving_average_width = 8
+# Maximum number of consecutive silent frames a segment can have.
+vad_max_silence_length = 6
+
+
+## Audio volume normalization
+audio_norm_target_dBFS = -30
+
+
+## Model parameters
+model_hidden_size = 256
+model_embedding_size = 256
+model_num_layers = 3
\ No newline at end of file
diff --git a/speaker_encoder/inference.py b/speaker_encoder/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..15e6bf16ba9e551473cd6b179bb518f0704ac33d
--- /dev/null
+++ b/speaker_encoder/inference.py
@@ -0,0 +1,177 @@
+from speaker_encoder.params_data import *
+from speaker_encoder.model import SpeakerEncoder
+from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
+from matplotlib import cm
+from speaker_encoder import audio
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+_model = None # type: SpeakerEncoder
+_device = None # type: torch.device
+
+
+def load_model(weights_fpath: Path, device=None):
+ """
+ Loads the model in memory. If this function is not explicitly called, it will be run on the
+ first call to embed_frames() with the default weights file.
+
+ :param weights_fpath: the path to saved model weights.
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
+ If None, will default to your GPU if it's available, otherwise your CPU.
+ """
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
+ # was saved on. Worth investigating.
+ global _model, _device
+ if device is None:
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ elif isinstance(device, str):
+ _device = torch.device(device)
+ else:
+ _device = device
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
+ checkpoint = torch.load(weights_fpath)
+ _model.load_state_dict(checkpoint["model_state"])
+ _model.eval()
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+
+
+def is_loaded():
+ return _model is not None
+
+
+def embed_frames_batch(frames_batch):
+ """
+ Computes embeddings for a batch of mel spectrograms.
+
+ :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
+ (batch_size, n_frames, n_channels)
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
+ """
+ if _model is None:
+ raise Exception("Model was not loaded. Call load_model() before inference.")
+
+ frames = torch.from_numpy(frames_batch).to(_device)
+ embed = _model.forward(frames).detach().cpu().numpy()
+ return embed
+
+
+def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
+ min_pad_coverage=0.75, overlap=0.5):
+ """
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
+ partial utterances of partial_utterance_n_frames frames each. Both the waveform and the mel
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
+ defined in params_data.py.
+
+ The returned ranges may index past the end of the waveform. It is
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
+
+ :param n_samples: the number of samples in the waveform
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+ utterance
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+ enough frames. If at least min_pad_coverage of partial_utterance_n_frames frames are present,
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
+ utterances are entirely disjoint.
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
+ utterances.
+ """
+ assert 0 <= overlap < 1
+ assert 0 < min_pad_coverage <= 1
+
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+ # Compute the slices
+ wav_slices, mel_slices = [], []
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
+ for i in range(0, steps, frame_step):
+ mel_range = np.array([i, i + partial_utterance_n_frames])
+ wav_range = mel_range * samples_per_frame
+ mel_slices.append(slice(*mel_range))
+ wav_slices.append(slice(*wav_range))
+
+ # Evaluate whether extra padding is warranted or not
+ last_wav_range = wav_slices[-1]
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
+ mel_slices = mel_slices[:-1]
+ wav_slices = wav_slices[:-1]
+
+ return wav_slices, mel_slices
+
+
+def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
+ """
+ Computes an embedding for a single utterance.
+
+ # TODO: handle multiple wavs to benefit from batching on GPU
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
+ :param using_partials: if True, then the utterance is split into partial utterances of
+ partial_utterance_n_frames frames and the utterance embedding is computed from their
+ normalized average. If False, the embedding is instead computed by feeding the entire
+ spectrogram to the network.
+ :param return_partials: if True, the partial embeddings will also be returned along with the
+ wav slices that correspond to the partial embeddings.
+ :param kwargs: additional arguments to compute_partial_slices()
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+ return_partials is True, the partial embeddings as a numpy array of float32 of shape
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+ returned. If using_partials is simultaneously set to False, both these values will be None
+ instead.
+ """
+ # Process the entire utterance if not using partials
+ if not using_partials:
+ frames = audio.wav_to_mel_spectrogram(wav)
+ embed = embed_frames_batch(frames[None, ...])[0]
+ if return_partials:
+ return embed, None, None
+ return embed
+
+ # Compute where to split the utterance into partials and pad if necessary
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
+ max_wave_length = wave_slices[-1].stop
+ if max_wave_length >= len(wav):
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+ # Split the utterance into partials
+ frames = audio.wav_to_mel_spectrogram(wav)
+ frames_batch = np.array([frames[s] for s in mel_slices])
+ partial_embeds = embed_frames_batch(frames_batch)
+
+ # Compute the utterance embedding from the partial embeddings
+ raw_embed = np.mean(partial_embeds, axis=0)
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
+
+ if return_partials:
+ return embed, partial_embeds, wave_slices
+ return embed
+
+
+def embed_speaker(wavs, **kwargs):
+ raise NotImplementedError()
+
+
+def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
+ if ax is None:
+ ax = plt.gca()
+
+ if shape is None:
+ height = int(np.sqrt(len(embed)))
+ shape = (height, -1)
+ embed = embed.reshape(shape)
+
+ cmap = cm.get_cmap()
+ mappable = ax.imshow(embed, cmap=cmap)
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
+ mappable.set_clim(*color_range)
+
+ ax.set_xticks([]), ax.set_yticks([])
+ ax.set_title(title)
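+
+
+# End-to-end usage sketch. The checkpoint path matches the pretrained weights shipped in
+# this PR under speaker_encoder/ckpt; the input wav path is only a placeholder.
+if __name__ == "__main__":
+    load_model(Path("speaker_encoder/ckpt/pretrained_bak_5805000.pt"))
+    wav = preprocess_wav("some_speaker.wav")
+    embed = embed_utterance(wav)
+    print(embed.shape)    # (model_embedding_size,), i.e. (256,)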
diff --git a/speaker_encoder/model.py b/speaker_encoder/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..aefe6c84cd0de2031daf6b69a942e406594ad187
--- /dev/null
+++ b/speaker_encoder/model.py
@@ -0,0 +1,135 @@
+from speaker_encoder.params_model import *
+from speaker_encoder.params_data import *
+from scipy.interpolate import interp1d
+from sklearn.metrics import roc_curve
+from torch.nn.utils import clip_grad_norm_
+from scipy.optimize import brentq
+from torch import nn
+import numpy as np
+import torch
+
+
+class SpeakerEncoder(nn.Module):
+ def __init__(self, device, loss_device):
+ super().__init__()
+ self.loss_device = loss_device
+
+ # Network definition
+ self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
+ hidden_size=model_hidden_size, # 256
+ num_layers=model_num_layers, # 3
+ batch_first=True).to(device)
+ self.linear = nn.Linear(in_features=model_hidden_size,
+ out_features=model_embedding_size).to(device)
+ self.relu = torch.nn.ReLU().to(device)
+
+ # Cosine similarity scaling (with fixed initial parameter values)
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
+
+ # Loss
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
+
+ def do_gradient_ops(self):
+ # Gradient scale
+ self.similarity_weight.grad *= 0.01
+ self.similarity_bias.grad *= 0.01
+
+ # Gradient clipping
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
+
+ def forward(self, utterances, hidden_init=None):
+ """
+ Computes the embeddings of a batch of utterance spectrograms.
+
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
+ (batch_size, n_frames, n_channels)
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
+ """
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
+ # and the final cell state.
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
+
+ # We take only the hidden state of the last layer
+ embeds_raw = self.relu(self.linear(hidden[-1]))
+
+ # L2-normalize it
+ embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
+
+ return embeds
+
+ def similarity_matrix(self, embeds):
+ """
+ Computes the similarity matrix according to section 2.1 of GE2E.
+
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+ utterances_per_speaker, embedding_size)
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
+ utterances_per_speaker, speakers_per_batch)
+ """
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
+ centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
+
+ # Exclusive centroids (1 per utterance)
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
+ centroids_excl /= (utterances_per_speaker - 1)
+ centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
+
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
+ # We vectorize the computation for efficiency.
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
+ speakers_per_batch).to(self.loss_device)
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
+ for j in range(speakers_per_batch):
+ mask = np.where(mask_matrix[j])[0]
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
+
+ ## Even more vectorized version (slower maybe because of transpose)
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
+ # ).to(self.loss_device)
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
+ # mask = np.where(1 - eye)
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
+ # mask = np.where(eye)
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
+
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
+ return sim_matrix
+
+ def loss(self, embeds):
+ """
+ Computes the softmax loss according to section 2.1 of GE2E.
+
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+ utterances_per_speaker, embedding_size)
+ :return: the loss and the EER for this batch of embeddings.
+ """
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+ # Loss
+ sim_matrix = self.similarity_matrix(embeds)
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
+ speakers_per_batch))
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
+ loss = self.loss_fn(sim_matrix, target)
+
+ # EER (not backpropagated)
+ with torch.no_grad():
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
+ labels = np.array([inv_argmax(i) for i in ground_truth])
+ preds = sim_matrix.detach().cpu().numpy()
+
+ # Snippet from https://yangcha.github.io/EER-ROC/
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
+
+ return loss, eer
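+
+
+# Shape sanity-check sketch (not part of the upstream model): with random L2-normalized
+# embeddings, the GE2E similarity matrix and loss above have the documented shapes.
+if __name__ == "__main__":
+    cpu = torch.device("cpu")
+    model = SpeakerEncoder(cpu, cpu)
+    embeds = torch.randn(4, 5, model_embedding_size)            # (speakers, utterances, emb)
+    embeds = embeds / torch.norm(embeds, dim=2, keepdim=True)   # GE2E works on unit vectors
+    sim = model.similarity_matrix(embeds)
+    loss, eer = model.loss(embeds)
+    print(sim.shape, float(loss), eer)                          # torch.Size([4, 5, 4]) ...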
diff --git a/speaker_encoder/params_data.py b/speaker_encoder/params_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdb1716ed45617f2b127a7fb8885afe6cc74fb71
--- /dev/null
+++ b/speaker_encoder/params_data.py
@@ -0,0 +1,29 @@
+
+## Mel-filterbank
+mel_window_length = 25 # In milliseconds
+mel_window_step = 10 # In milliseconds
+mel_n_channels = 40
+
+
+## Audio
+sampling_rate = 16000
+# Number of spectrogram frames in a partial utterance
+partials_n_frames = 160 # 1600 ms
+# Number of spectrogram frames at inference
+inference_n_frames = 80 # 800 ms
+
+
+## Voice Activation Detection
+# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+# This sets the granularity of the VAD. Should not need to be changed.
+vad_window_length = 30 # In milliseconds
+# Number of frames to average together when performing the moving average smoothing.
+# The larger this value, the larger the VAD variations must be to not get smoothed out.
+vad_moving_average_width = 8
+# Maximum number of consecutive silent frames a segment can have.
+vad_max_silence_length = 6
+
+
+## Audio volume normalization
+audio_norm_target_dBFS = -30
+
diff --git a/speaker_encoder/params_model.py b/speaker_encoder/params_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e356472fb5a27f370cb3920976a11d12a76c1b7
--- /dev/null
+++ b/speaker_encoder/params_model.py
@@ -0,0 +1,11 @@
+
+## Model parameters
+model_hidden_size = 256
+model_embedding_size = 256
+model_num_layers = 3
+
+
+## Training parameters
+learning_rate_init = 1e-4
+speakers_per_batch = 64
+utterances_per_speaker = 10
diff --git a/speaker_encoder/preprocess.py b/speaker_encoder/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5ab25ef7cb4adeb76cad11962f179d6a38edcc
--- /dev/null
+++ b/speaker_encoder/preprocess.py
@@ -0,0 +1,285 @@
+from multiprocessing.pool import ThreadPool
+from speaker_encoder.params_data import *
+from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
+from datetime import datetime
+from speaker_encoder import audio
+from pathlib import Path
+from tqdm import tqdm
+import numpy as np
+
+
+class DatasetLog:
+ """
+ Registers metadata about the dataset in a text file.
+ """
+ def __init__(self, root, name):
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
+ self.sample_data = dict()
+
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
+ self.write_line("-----")
+ self._log_params()
+
+ def _log_params(self):
+ from speaker_encoder import params_data
+ self.write_line("Parameter values:")
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+ value = getattr(params_data, param_name)
+ self.write_line("\t%s: %s" % (param_name, value))
+ self.write_line("-----")
+
+ def write_line(self, line):
+ self.text_file.write("%s\n" % line)
+
+ def add_sample(self, **kwargs):
+ for param_name, value in kwargs.items():
+ if not param_name in self.sample_data:
+ self.sample_data[param_name] = []
+ self.sample_data[param_name].append(value)
+
+ def finalize(self):
+ self.write_line("Statistics:")
+ for param_name, values in self.sample_data.items():
+ self.write_line("\t%s:" % param_name)
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
+ self.write_line("-----")
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+ self.write_line("Finished on %s" % end_time)
+ self.text_file.close()
+
+
+def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
+ dataset_root = datasets_root.joinpath(dataset_name)
+ if not dataset_root.exists():
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
+ return None, None
+ return dataset_root, DatasetLog(out_dir, dataset_name)
+
+
+def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
+ skip_existing, logger):
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
+
+ # Function to preprocess utterances for one speaker
+ def preprocess_speaker(speaker_dir: Path):
+ # Give a name to the speaker that includes its dataset
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
+
+ # Create an output directory with that name, as well as a txt file containing a
+ # reference to each source file.
+ speaker_out_dir = out_dir.joinpath(speaker_name)
+ speaker_out_dir.mkdir(exist_ok=True)
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
+
+ # There's a possibility that the preprocessing was interrupted earlier, check if
+ # there already is a sources file.
+ if sources_fpath.exists():
+ try:
+ with sources_fpath.open("r") as sources_file:
+ existing_fnames = {line.split(",")[0] for line in sources_file}
+ except:
+ existing_fnames = {}
+ else:
+ existing_fnames = {}
+
+ # Gather all audio files for that speaker recursively
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
+ # Check if the target output file already exists
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
+ if skip_existing and out_fname in existing_fnames:
+ continue
+
+ # Load and preprocess the waveform
+ wav = audio.preprocess_wav(in_fpath)
+ if len(wav) == 0:
+ continue
+
+ # Create the mel spectrogram, discard those that are too short
+ frames = audio.wav_to_mel_spectrogram(wav)
+ if len(frames) < partials_n_frames:
+ continue
+
+ out_fpath = speaker_out_dir.joinpath(out_fname)
+ np.save(out_fpath, frames)
+ logger.add_sample(duration=len(wav) / sampling_rate)
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
+
+ sources_file.close()
+
+ # Process the utterances for each speaker
+ with ThreadPool(8) as pool:
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
+ unit="speakers"))
+ logger.finalize()
+ print("Done preprocessing %s.\n" % dataset_name)
+
+
+# Function to preprocess utterances for one speaker
+def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
+ # Give a name to the speaker that includes its dataset
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
+
+ # Create an output directory with that name, as well as a txt file containing a
+ # reference to each source file.
+ speaker_out_dir = out_dir.joinpath(speaker_name)
+ speaker_out_dir.mkdir(exist_ok=True)
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
+
+ # There's a possibility that the preprocessing was interrupted earlier, check if
+ # there already is a sources file.
+ # if sources_fpath.exists():
+ # try:
+ # with sources_fpath.open("r") as sources_file:
+ # existing_fnames = {line.split(",")[0] for line in sources_file}
+ # except:
+ # existing_fnames = {}
+ # else:
+ # existing_fnames = {}
+ existing_fnames = {}
+ # Gather all audio files for that speaker recursively
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
+
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
+ # Check if the target output file already exists
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
+ if skip_existing and out_fname in existing_fnames:
+ continue
+
+ # Load and preprocess the waveform
+ wav = audio.preprocess_wav(in_fpath)
+ if len(wav) == 0:
+ continue
+
+ # Create the mel spectrogram, discard those that are too short
+ frames = audio.wav_to_mel_spectrogram(wav)
+ if len(frames) < partials_n_frames:
+ continue
+
+ out_fpath = speaker_out_dir.joinpath(out_fname)
+ np.save(out_fpath, frames)
+ # logger.add_sample(duration=len(wav) / sampling_rate)
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
+
+ sources_file.close()
+ return len(wav)
+
+def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
+ skip_existing, logger):
+ # from multiprocessing import Pool, cpu_count
+ from pathos.multiprocessing import ProcessingPool as Pool
+ # Function to preprocess utterances for one speaker
+ def __preprocess_speaker(speaker_dir: Path):
+ # Give a name to the speaker that includes its dataset
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
+
+ # Create an output directory with that name, as well as a txt file containing a
+ # reference to each source file.
+ speaker_out_dir = out_dir.joinpath(speaker_name)
+ speaker_out_dir.mkdir(exist_ok=True)
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
+
+ existing_fnames = {}
+ # Gather all audio files for that speaker recursively
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
+ wav_lens = []
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
+ # Check if the target output file already exists
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
+ if skip_existing and out_fname in existing_fnames:
+ continue
+
+ # Load and preprocess the waveform
+ wav = audio.preprocess_wav(in_fpath)
+ if len(wav) == 0:
+ continue
+
+ # Create the mel spectrogram, discard those that are too short
+ frames = audio.wav_to_mel_spectrogram(wav)
+ if len(frames) < partials_n_frames:
+ continue
+
+ out_fpath = speaker_out_dir.joinpath(out_fname)
+ np.save(out_fpath, frames)
+ # logger.add_sample(duration=len(wav) / sampling_rate)
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
+ wav_lens.append(len(wav))
+ sources_file.close()
+ return wav_lens
+
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
+ # Process the utterances for each speaker
+ # with ThreadPool(8) as pool:
+ # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
+ # unit="speakers"))
+ pool = Pool(processes=20)
+ for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
+ for wav_len in wav_lens:
+ logger.add_sample(duration=wav_len / sampling_rate)
+ print(f'{i}/{len(speaker_dirs)} \r')
+
+ logger.finalize()
+ print("Done preprocessing %s.\n" % dataset_name)
+
+
+def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
+ for dataset_name in librispeech_datasets["train"]["other"]:
+ # Initialize the preprocessing
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+ if not dataset_root:
+ return
+
+ # Preprocess all speakers
+ speaker_dirs = list(dataset_root.glob("*"))
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
+ skip_existing, logger)
+
+
+def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
+ # Initialize the preprocessing
+ dataset_name = "VoxCeleb1"
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+ if not dataset_root:
+ return
+
+ # Get the contents of the meta file
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
+ metadata = [line.split("\t") for line in metafile][1:]
+
+ # Select the ID and the nationality (the anglophone-only filter below is disabled)
+ nationalities = {line[0]: line[3] for line in metadata}
+ # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
+ # nationality.lower() in anglophone_nationalites]
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
+ (len(keep_speaker_ids), len(nationalities)))
+
+ # Get the speaker directories for the kept speakers
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
+ speaker_dir.name in keep_speaker_ids]
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
+
+ # Preprocess all speakers
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
+ skip_existing, logger)
+
+
+def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
+ # Initialize the preprocessing
+ dataset_name = "VoxCeleb2"
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+ if not dataset_root:
+ return
+
+ # Get the speaker directories
+ # Preprocess all speakers
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
+ _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
+ skip_existing, logger)
diff --git a/speaker_encoder/train.py b/speaker_encoder/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e9485afbeead6a063b5ef69a85f05757d6c91ff
--- /dev/null
+++ b/speaker_encoder/train.py
@@ -0,0 +1,125 @@
+from speaker_encoder.visualizations import Visualizations
+from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
+from speaker_encoder.params_model import *
+from speaker_encoder.model import SpeakerEncoder
+from utils.profiler import Profiler
+from pathlib import Path
+import torch
+
+def sync(device: torch.device):
+ # FIXME
+ return
+ # For correct profiling (cuda operations are async)
+ if device.type == "cuda":
+ torch.cuda.synchronize(device)
+
+def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
+ no_visdom: bool):
+ # Create a dataset and a dataloader
+ dataset = SpeakerVerificationDataset(clean_data_root)
+ loader = SpeakerVerificationDataLoader(
+ dataset,
+ speakers_per_batch, # 64
+ utterances_per_speaker, # 10
+ num_workers=8,
+ )
+
+ # Setup the device on which to run the forward pass and the loss. These can be different,
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
+ # hyperparameters) faster on the CPU.
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # FIXME: currently, the gradient is None if loss_device is cuda
+ loss_device = torch.device("cpu")
+
+ # Create the model and the optimizer
+ model = SpeakerEncoder(device, loss_device)
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
+ init_step = 1
+
+ # Configure file path for the model
+ state_fpath = models_dir.joinpath(run_id + ".pt")
+ backup_dir = models_dir.joinpath(run_id + "_backups")
+
+ # Load any existing model
+ if not force_restart:
+ if state_fpath.exists():
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
+ checkpoint = torch.load(state_fpath)
+ init_step = checkpoint["step"]
+ model.load_state_dict(checkpoint["model_state"])
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
+ optimizer.param_groups[0]["lr"] = learning_rate_init
+ else:
+ print("No model \"%s\" found, starting training from scratch." % run_id)
+ else:
+ print("Starting the training from scratch.")
+ model.train()
+
+ # Initialize the visualization environment
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
+ vis.log_dataset(dataset)
+ vis.log_params()
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
+ vis.log_implementation({"Device": device_name})
+
+ # Training loop
+ profiler = Profiler(summarize_every=10, disabled=False)
+ for step, speaker_batch in enumerate(loader, init_step):
+ profiler.tick("Blocking, waiting for batch (threaded)")
+
+ # Forward pass
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
+ sync(device)
+ profiler.tick("Data to %s" % device)
+ embeds = model(inputs)
+ sync(device)
+ profiler.tick("Forward pass")
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
+ loss, eer = model.loss(embeds_loss)
+ sync(loss_device)
+ profiler.tick("Loss")
+
+ # Backward pass
+ model.zero_grad()
+ loss.backward()
+ profiler.tick("Backward pass")
+ model.do_gradient_ops()
+ optimizer.step()
+ profiler.tick("Parameter update")
+
+ # Update visualizations
+ # learning_rate = optimizer.param_groups[0]["lr"]
+ vis.update(loss.item(), eer, step)
+
+ # Draw projections and save them to the backup folder
+ if umap_every != 0 and step % umap_every == 0:
+ print("Drawing and saving projections (step %d)" % step)
+ backup_dir.mkdir(exist_ok=True)
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
+ embeds = embeds.detach().cpu().numpy()
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
+ vis.save()
+
+ # Overwrite the latest version of the model
+ if save_every != 0 and step % save_every == 0:
+ print("Saving the model (step %d)" % step)
+ torch.save({
+ "step": step + 1,
+ "model_state": model.state_dict(),
+ "optimizer_state": optimizer.state_dict(),
+ }, state_fpath)
+
+ # Make a backup
+ if backup_every != 0 and step % backup_every == 0:
+ print("Making a backup (step %d)" % step)
+ backup_dir.mkdir(exist_ok=True)
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
+ torch.save({
+ "step": step + 1,
+ "model_state": model.state_dict(),
+ "optimizer_state": optimizer.state_dict(),
+ }, backup_fpath)
+
+ profiler.tick("Extras (visualizations, saving)")
+
diff --git a/speaker_encoder/visualizations.py b/speaker_encoder/visualizations.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec00fc64d6e9fda2bb8e613531066ac824df1451
--- /dev/null
+++ b/speaker_encoder/visualizations.py
@@ -0,0 +1,178 @@
+from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+from datetime import datetime
+from time import perf_counter as timer
+import matplotlib.pyplot as plt
+import numpy as np
+# import webbrowser
+import visdom
+import umap
+
+colormap = np.array([
+ [76, 255, 0],
+ [0, 127, 70],
+ [255, 0, 0],
+ [255, 217, 38],
+ [0, 135, 255],
+ [165, 0, 165],
+ [255, 167, 255],
+ [0, 255, 255],
+ [255, 96, 38],
+ [142, 76, 0],
+ [33, 0, 127],
+ [0, 0, 0],
+ [183, 183, 183],
+], dtype=float) / 255
+
+
+class Visualizations:
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
+ # Tracking data
+ self.last_update_timestamp = timer()
+ self.update_every = update_every
+ self.step_times = []
+ self.losses = []
+ self.eers = []
+ print("Updating the visualizations every %d steps." % update_every)
+
+ # If visdom is disabled TODO: use a better paradigm for that
+ self.disabled = disabled
+ if self.disabled:
+ return
+
+ # Set the environment name
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
+ if env_name is None:
+ self.env_name = now
+ else:
+ self.env_name = "%s (%s)" % (env_name, now)
+
+ # Connect to visdom and open the corresponding window in the browser
+ try:
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
+ except ConnectionError:
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
+ "start it.")
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
+
+ # Create the windows
+ self.loss_win = None
+ self.eer_win = None
+ # self.lr_win = None
+ self.implementation_win = None
+ self.projection_win = None
+ self.implementation_string = ""
+
+ def log_params(self):
+ if self.disabled:
+ return
+ from speaker_encoder import params_data
+ from speaker_encoder import params_model
+ param_string = "Model parameters:
"
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
+ value = getattr(params_model, param_name)
+ param_string += "\t%s: %s
" % (param_name, value)
+ param_string += "Data parameters:
"
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+ value = getattr(params_data, param_name)
+ param_string += "\t%s: %s
" % (param_name, value)
+ self.vis.text(param_string, opts={"title": "Parameters"})
+
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
+ if self.disabled:
+ return
+ dataset_string = ""
+ dataset_string += "Speakers: %s\n" % len(dataset.speakers)
+ dataset_string += "\n" + dataset.get_logs()
+        dataset_string = dataset_string.replace("\n", "<br>")
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
+
+ def log_implementation(self, params):
+ if self.disabled:
+ return
+ implementation_string = ""
+ for param, value in params.items():
+ implementation_string += "%s: %s\n" % (param, value)
+        implementation_string = implementation_string.replace("\n", "<br>")
+ self.implementation_string = implementation_string
+ self.implementation_win = self.vis.text(
+ implementation_string,
+ opts={"title": "Training implementation"}
+ )
+
+ def update(self, loss, eer, step):
+ # Update the tracking data
+ now = timer()
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
+ self.last_update_timestamp = now
+ self.losses.append(loss)
+ self.eers.append(eer)
+ print(".", end="")
+
+        # Update the plots every <update_every> steps
+ if step % self.update_every != 0:
+ return
+ time_string = "Step time: mean: %5dms std: %5dms" % \
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
+ if not self.disabled:
+ self.loss_win = self.vis.line(
+ [np.mean(self.losses)],
+ [step],
+ win=self.loss_win,
+ update="append" if self.loss_win else None,
+ opts=dict(
+ legend=["Avg. loss"],
+ xlabel="Step",
+ ylabel="Loss",
+ title="Loss",
+ )
+ )
+ self.eer_win = self.vis.line(
+ [np.mean(self.eers)],
+ [step],
+ win=self.eer_win,
+ update="append" if self.eer_win else None,
+ opts=dict(
+ legend=["Avg. EER"],
+ xlabel="Step",
+ ylabel="EER",
+ title="Equal error rate"
+ )
+ )
+ if self.implementation_win is not None:
+ self.vis.text(
+                    self.implementation_string + ("<b>%s</b>" % time_string),
+ win=self.implementation_win,
+ opts={"title": "Training implementation"},
+ )
+
+ # Reset the tracking
+ self.losses.clear()
+ self.eers.clear()
+ self.step_times.clear()
+
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
+ max_speakers=10):
+ max_speakers = min(max_speakers, len(colormap))
+ embeds = embeds[:max_speakers * utterances_per_speaker]
+
+ n_speakers = len(embeds) // utterances_per_speaker
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
+ colors = [colormap[i] for i in ground_truth]
+
+ reducer = umap.UMAP()
+ projected = reducer.fit_transform(embeds)
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
+ plt.gca().set_aspect("equal", "datalim")
+ plt.title("UMAP projection (step %d)" % step)
+ if not self.disabled:
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
+ if out_fpath is not None:
+ plt.savefig(out_fpath)
+ plt.clf()
+
+ def save(self):
+ if not self.disabled:
+ self.vis.save([self.env_name])
+
\ No newline at end of file
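The Visualizations class above can be exercised without a running visdom server by passing disabled=True. A minimal sketch (assuming the speaker_encoder package and its visdom/umap/matplotlib dependencies are importable, since the module pulls them in at load time):

```python
import numpy as np
from speaker_encoder.visualizations import Visualizations

# disabled=True skips all visdom calls but keeps the console logging in update()
vis = Visualizations(env_name="debug", update_every=10, disabled=True)
for step in range(1, 101):
    loss, eer = float(np.random.rand()), float(np.random.rand())
    vis.update(loss, eer, step)   # prints averaged loss/EER every 10 steps
vis.save()                        # no-op when disabled
# vis.draw_projections(embeds, utterances_per_speaker, step, "umap.png")  # needs real embeddings
```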
diff --git a/speaker_encoder/voice_encoder.py b/speaker_encoder/voice_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..88cdee2de76b72db58c5dd19a888597e0fe12fbb
--- /dev/null
+++ b/speaker_encoder/voice_encoder.py
@@ -0,0 +1,173 @@
+from speaker_encoder.hparams import *
+from speaker_encoder import audio
+from pathlib import Path
+from typing import Union, List
+from torch import nn
+from time import perf_counter as timer
+import numpy as np
+import torch
+
+
+class SpeakerEncoder(nn.Module):
+ def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
+ """
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
+ If None, defaults to cuda if it is available on your machine, otherwise the model will
+ run on cpu. Outputs are always returned on the cpu, as numpy arrays.
+ """
+ super().__init__()
+
+ # Define the network
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
+ self.relu = nn.ReLU()
+
+ # Get the target device
+ if device is None:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ elif isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+
+        # Load the pretrained model's weights
+ # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
+ # if not weights_fpath.exists():
+ # raise Exception("Couldn't find the voice encoder pretrained model at %s." %
+ # weights_fpath)
+
+ start = timer()
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
+
+ self.load_state_dict(checkpoint["model_state"], strict=False)
+ self.to(device)
+
+ if verbose:
+ print("Loaded the voice encoder model on %s in %.2f seconds." %
+ (device.type, timer() - start))
+
+ def forward(self, mels: torch.FloatTensor):
+ """
+ Computes the embeddings of a batch of utterance spectrograms.
+ :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
+ (batch_size, n_frames, n_channels)
+        :return: the embeddings as a float32 tensor of shape (batch_size, embedding_size).
+        Embeddings are positive and L2-normed, thus they lie in the range [0, 1].
+ """
+ # Pass the input through the LSTM layers and retrieve the final hidden state of the last
+ # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
+ _, (hidden, _) = self.lstm(mels)
+ embeds_raw = self.relu(self.linear(hidden[-1]))
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
+
+ @staticmethod
+ def compute_partial_slices(n_samples: int, rate, min_coverage):
+ """
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to
+        obtain partial utterances of <partials_n_frames> frames each. Both the waveform and the
+ mel spectrogram slices are returned, so as to make each partial utterance waveform
+ correspond to its spectrogram.
+
+ The returned ranges may be indexing further than the length of the waveform. It is
+ recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
+
+ :param n_samples: the number of samples in the waveform
+ :param rate: how many partial utterances should occur per second. Partial utterances must
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
+ the minimum rate is thus 0.625.
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
+        enough frames. If at least <min_coverage> of <partials_n_frames> are present,
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
+ it will be discarded. If there aren't enough frames for one partial utterance,
+ this parameter is ignored so that the function always returns at least one slice.
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
+ utterances.
+ """
+ assert 0 < min_coverage <= 1
+
+ # Compute how many frames separate two partial utterances
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+ frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
+ assert 0 < frame_step, "The rate is too high"
+ assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
+ (sampling_rate / (samples_per_frame * partials_n_frames))
+
+ # Compute the slices
+ wav_slices, mel_slices = [], []
+ steps = max(1, n_frames - partials_n_frames + frame_step + 1)
+ for i in range(0, steps, frame_step):
+ mel_range = np.array([i, i + partials_n_frames])
+ wav_range = mel_range * samples_per_frame
+ mel_slices.append(slice(*mel_range))
+ wav_slices.append(slice(*wav_range))
+
+ # Evaluate whether extra padding is warranted or not
+ last_wav_range = wav_slices[-1]
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+ if coverage < min_coverage and len(mel_slices) > 1:
+ mel_slices = mel_slices[:-1]
+ wav_slices = wav_slices[:-1]
+
+ return wav_slices, mel_slices
+
+ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
+ """
+ Computes an embedding for a single utterance. The utterance is divided in partial
+ utterances and an embedding is computed for each. The complete utterance embedding is the
+ L2-normed average embedding of the partial utterances.
+
+ TODO: independent batched version of this function
+
+ :param wav: a preprocessed utterance waveform as a numpy array of float32
+ :param return_partials: if True, the partial embeddings will also be returned along with
+ the wav slices corresponding to each partial utterance.
+ :param rate: how many partial utterances should occur per second. Partial utterances must
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
+ the minimum rate is thus 0.625.
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
+        enough frames. If at least <min_coverage> of <partials_n_frames> are present,
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
+ it will be discarded. If there aren't enough frames for one partial utterance,
+ this parameter is ignored so that the function always returns at least one slice.
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+        <return_partials> is True, the partial utterances as a numpy array of float32 of shape
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+ returned.
+ """
+ # Compute where to split the utterance into partials and pad the waveform with zeros if
+ # the partial utterances cover a larger range.
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
+ max_wave_length = wav_slices[-1].stop
+ if max_wave_length >= len(wav):
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+ # Split the utterance into partials and forward them through the model
+ mel = audio.wav_to_mel_spectrogram(wav)
+ mels = np.array([mel[s] for s in mel_slices])
+ with torch.no_grad():
+ mels = torch.from_numpy(mels).to(self.device)
+ partial_embeds = self(mels).cpu().numpy()
+
+ # Compute the utterance embedding from the partial embeddings
+ raw_embed = np.mean(partial_embeds, axis=0)
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
+
+ if return_partials:
+ return embed, partial_embeds, wav_slices
+ return embed
+
+ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
+ """
+ Compute the embedding of a collection of wavs (presumably from the same speaker) by
+ averaging their embedding and L2-normalizing it.
+
+        :param wavs: list of wavs as numpy arrays of float32.
+ :param kwargs: extra arguments to embed_utterance()
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
+ """
+ raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
+ for wav in wavs], axis=0)
+ return raw_embed / np.linalg.norm(raw_embed, 2)
\ No newline at end of file
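Because compute_partial_slices is a staticmethod, the partial-utterance windowing described in its docstring can be inspected without loading any weights. A small sketch (it assumes speaker_encoder/hparams.py defines sampling_rate, mel_window_step and partials_n_frames, as the star-import above implies):

```python
from speaker_encoder.voice_encoder import SpeakerEncoder
from speaker_encoder.hparams import sampling_rate

# slice a 3-second utterance into ~1.6 s partials at 1.3 partials per second
wav_slices, mel_slices = SpeakerEncoder.compute_partial_slices(
    n_samples=3 * sampling_rate, rate=1.3, min_coverage=0.75)
print(len(wav_slices))               # number of partial utterances
print(wav_slices[0], mel_slices[0])  # matching sample/frame ranges for the first partial
# embed_utterance() uses exactly these slices, zero-padding the wav up to wav_slices[-1].stop
```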
diff --git a/src/audio2exp_models/audio2exp.py b/src/audio2exp_models/audio2exp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e79a929560592687a505e13188796e2b0ca8772
--- /dev/null
+++ b/src/audio2exp_models/audio2exp.py
@@ -0,0 +1,41 @@
+from tqdm import tqdm
+import torch
+from torch import nn
+
+
+class Audio2Exp(nn.Module):
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
+ super(Audio2Exp, self).__init__()
+ self.cfg = cfg
+ self.device = device
+ self.netG = netG.to(device)
+
+ def test(self, batch):
+
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
+ bs = mel_input.shape[0]
+ T = mel_input.shape[1]
+
+ exp_coeff_pred = []
+
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
+
+ current_mel_input = mel_input[:,i:i+10]
+
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
+ ref = batch['ref'][:, :, :64][:, i:i+10]
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
+
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
+
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
+
+ exp_coeff_pred += [curr_exp_coeff_pred]
+
+ # BS x T x 64
+ results_dict = {
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
+ }
+ return results_dict
+
+
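A minimal sketch of the chunked inference above: drive Audio2Exp.test with a stub generator so the shape bookkeeping can be checked without any checkpoints. The stub only mimics the SimpleWrapperV2 interface from src/audio2exp_models/networks.py below; cfg is unused by test() and is passed as None, and the batch size is kept at 1 as in the inference pipeline.

```python
import torch
from torch import nn
from src.audio2exp_models.audio2exp import Audio2Exp

class StubNetG(nn.Module):
    """Stands in for SimpleWrapperV2: (bs*T, 1, 80, 16), (bs, T, 64), (bs, T) -> (bs, T, 64)."""
    def forward(self, audiox, ref, ratio):
        return torch.zeros(ref.shape[0], ref.shape[1], 64)

bs, T = 1, 25   # 25 frames are processed in chunks of 10, 10 and 5
batch = {
    "indiv_mels": torch.randn(bs, T, 1, 80, 16),
    "ref":        torch.randn(bs, T, 70),   # only the first 64 expression coefficients are used
    "ratio_gt":   torch.rand(bs, T),        # per-frame ratio signal
}
model = Audio2Exp(StubNetG(), cfg=None, device="cpu")
out = model.test(batch)
print(out["exp_coeff_pred"].shape)          # torch.Size([1, 25, 64])
```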
diff --git a/src/audio2exp_models/networks.py b/src/audio2exp_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f052e18101f5446a527ae354b3621e7d0d4991cc
--- /dev/null
+++ b/src/audio2exp_models/networks.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+ self.use_act = use_act
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+
+ if self.use_act:
+ return self.act(out)
+ else:
+ return out
+
+class SimpleWrapperV2(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
+ )
+
+ #### load the pre-trained audio_encoder
+ #self.audio_encoder = self.audio_encoder.to(device)
+ '''
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
+ state_dict = self.audio_encoder.state_dict()
+
+ for k,v in wav2lip_state_dict.items():
+ if 'audio_encoder' in k:
+ print('init:', k)
+ state_dict[k.replace('module.audio_encoder.', '')] = v
+ self.audio_encoder.load_state_dict(state_dict)
+ '''
+
+ self.mapping1 = nn.Linear(512+64+1, 64)
+ #self.mapping2 = nn.Linear(30, 64)
+ #nn.init.constant_(self.mapping1.weight, 0.)
+ nn.init.constant_(self.mapping1.bias, 0.)
+
+ def forward(self, x, ref, ratio):
+ x = self.audio_encoder(x).view(x.size(0), -1)
+ ref_reshape = ref.reshape(x.size(0), -1)
+ ratio = ratio.reshape(x.size(0), -1)
+
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
+        out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
+ return out
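A quick shape check for SimpleWrapperV2 (a sketch, assuming the repository root is on PYTHONPATH): the mel encoder produces a 512-d feature per frame, which is concatenated with the 64-d reference coefficients and the scalar ratio into the 512+64+1 = 577-d input of mapping1.

```python
import torch
from src.audio2exp_models.networks import SimpleWrapperV2

net = SimpleWrapperV2()
bs, T = 2, 10
audiox = torch.randn(bs * T, 1, 80, 16)   # flattened mel chunks
ref    = torch.randn(bs, T, 64)           # reference expression coefficients
ratio  = torch.rand(bs, T)                # per-frame ratio signal
out = net(audiox, ref, ratio)
print(out.shape)                          # torch.Size([2, 10, 64])
```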
diff --git a/src/audio2pose_models/audio2pose.py b/src/audio2pose_models/audio2pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8cd1427038460a7679260a424d2f01d2bcf2c5
--- /dev/null
+++ b/src/audio2pose_models/audio2pose.py
@@ -0,0 +1,94 @@
+import torch
+from torch import nn
+from src.audio2pose_models.cvae import CVAE
+from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
+from src.audio2pose_models.audio_encoder import AudioEncoder
+
+class Audio2Pose(nn.Module):
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
+ super().__init__()
+ self.cfg = cfg
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
+ self.device = device
+
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
+ self.audio_encoder.eval()
+ for param in self.audio_encoder.parameters():
+ param.requires_grad = False
+
+ self.netG = CVAE(cfg)
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
+
+
+ def forward(self, x):
+
+ batch = {}
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
+
+ # forward
+ audio_emb_list = []
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
+ batch['audio_emb'] = audio_emb
+ batch = self.netG(batch)
+
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
+
+ batch['pose_pred'] = pose_pred
+ batch['pose_gt'] = pose_gt
+
+ return batch
+
+ def test(self, x):
+
+ batch = {}
+ ref = x['ref'] #bs 1 70
+ batch['ref'] = x['ref'][:,0,-6:]
+ batch['class'] = x['class']
+ bs = ref.shape[0]
+
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
+ num_frames = x['num_frames']
+ num_frames = int(num_frames) - 1
+
+ #
+ div = num_frames//self.seq_len
+ re = num_frames%self.seq_len
+ audio_emb_list = []
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
+ device=batch['ref'].device)]
+
+ for i in range(div):
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
+ batch['z'] = z
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
+ batch['audio_emb'] = audio_emb
+ batch = self.netG.test(batch)
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
+
+ if re != 0:
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
+ batch['z'] = z
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
+ if audio_emb.shape[1] != self.seq_len:
+ pad_dim = self.seq_len-audio_emb.shape[1]
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
+ batch['audio_emb'] = audio_emb
+ batch = self.netG.test(batch)
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
+
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
+ batch['pose_motion_pred'] = pose_motion_pred
+
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
+
+ batch['pose_pred'] = pose_pred
+ return batch
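The test-time loop above processes the audio in seq_len windows and handles the remainder with one extra window taken from the end of the sequence; the prediction list also starts with a zero motion for the reference frame, so the final pose tensor has num_frames entries. A small sketch of that bookkeeping (pure Python, no model needed):

```python
seq_len = 32
num_frames = 101                      # includes the reference frame
usable = num_frames - 1               # the reference frame is dropped from the mels
div, re = divmod(usable, seq_len)     # 3 full windows, remainder 4

windows = [(i * seq_len, (i + 1) * seq_len) for i in range(div)]
if re != 0:
    # the last window re-uses the final seq_len frames; only its trailing `re`
    # predictions are kept (and it is front-padded when usable < seq_len)
    windows.append((usable - seq_len, usable))
print(div, re, windows)

total = 1 + div * seq_len + re        # zero motion for the reference frame + predictions
assert total == num_frames
```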
diff --git a/src/audio2pose_models/audio_encoder.py b/src/audio2pose_models/audio_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6279d2014a2e786a6c549f084339e18d00e50331
--- /dev/null
+++ b/src/audio2pose_models/audio_encoder.py
@@ -0,0 +1,64 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+ return self.act(out)
+
+class AudioEncoder(nn.Module):
+ def __init__(self, wav2lip_checkpoint, device):
+ super(AudioEncoder, self).__init__()
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+ # state_dict = self.audio_encoder.state_dict()
+
+ # for k,v in wav2lip_state_dict.items():
+ # if 'audio_encoder' in k:
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
+ # self.audio_encoder.load_state_dict(state_dict)
+
+
+ def forward(self, audio_sequences):
+ # audio_sequences = (B, T, 1, 80, 16)
+ B = audio_sequences.size(0)
+
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+ dim = audio_embedding.shape[1]
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
+
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
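Since the wav2lip weight loading above is commented out, the encoder can be shape-checked directly (a sketch; the checkpoint and device arguments are simply ignored here):

```python
import torch
from src.audio2pose_models.audio_encoder import AudioEncoder

enc = AudioEncoder(wav2lip_checkpoint=None, device="cpu")
mels = torch.randn(2, 33, 1, 80, 16)    # B, seq_len+1, 1, 80, 16
emb = enc(mels)
print(emb.shape)                         # torch.Size([2, 33, 512])
```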
diff --git a/src/audio2pose_models/cvae.py b/src/audio2pose_models/cvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..d017ce865a03bae40dfe066dbcd82e29839d89dc
--- /dev/null
+++ b/src/audio2pose_models/cvae.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from src.audio2pose_models.res_unet import ResUnet
+
+def class2onehot(idx, class_num):
+
+ assert torch.max(idx).item() < class_num
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
+ onehot.scatter_(1, idx, 1)
+ return onehot
+
+class CVAE(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
+ num_classes = cfg.DATASET.NUM_CLASSES
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
+
+ self.latent_size = latent_size
+
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len)
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len)
+ def reparameterize(self, mu, logvar):
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return mu + eps * std
+
+ def forward(self, batch):
+ batch = self.encoder(batch)
+ mu = batch['mu']
+ logvar = batch['logvar']
+ z = self.reparameterize(mu, logvar)
+ batch['z'] = z
+ return self.decoder(batch)
+
+ def test(self, batch):
+ '''
+ class_id = batch['class']
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
+ batch['z'] = z
+ '''
+ return self.decoder(batch)
+
+class ENCODER(nn.Module):
+ def __init__(self, layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len):
+ super().__init__()
+
+ self.resunet = ResUnet()
+ self.num_classes = num_classes
+ self.seq_len = seq_len
+
+ self.MLP = nn.Sequential()
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
+ self.MLP.add_module(
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+ def forward(self, batch):
+ class_id = batch['class']
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
+ ref = batch['ref'] #bs 6
+ bs = pose_motion_gt.shape[0]
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+
+ #pose encode
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
+
+ #audio mapping
+        # print(audio_in.shape)  # leftover debug print, kept commented out
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+ audio_out = audio_out.reshape(bs, -1)
+
+ class_bias = self.classbias[class_id] #bs latent_size
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
+ x_out = self.MLP(x_in)
+
+ mu = self.linear_means(x_out)
+        logvar = self.linear_logvar(x_out)  # bs latent_size (use the dedicated logvar head, not linear_means)
+
+ batch.update({'mu':mu, 'logvar':logvar})
+ return batch
+
+class DECODER(nn.Module):
+ def __init__(self, layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len):
+ super().__init__()
+
+ self.resunet = ResUnet()
+ self.num_classes = num_classes
+ self.seq_len = seq_len
+
+ self.MLP = nn.Sequential()
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
+ self.MLP.add_module(
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+ if i+1 < len(layer_sizes):
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+ else:
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
+
+ self.pose_linear = nn.Linear(6, 6)
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+ def forward(self, batch):
+
+ z = batch['z'] #bs latent_size
+ bs = z.shape[0]
+ class_id = batch['class']
+ ref = batch['ref'] #bs 6
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+ #print('audio_in: ', audio_in[:, :, :10])
+
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+ #print('audio_out: ', audio_out[:, :, :10])
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
+ class_bias = self.classbias[class_id] #bs latent_size
+
+ z = z + class_bias
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
+ x_out = x_out.reshape((bs, self.seq_len, -1))
+
+ #print('x_out: ', x_out)
+
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
+
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
+
+ batch.update({'pose_motion_pred':pose_motion_pred})
+ return batch
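A minimal sketch of a training-style forward pass through the CVAE, using the hyper-parameters from src/config/auido2pose.yaml (further below) and a SimpleNamespace standing in for the config object:

```python
import torch
from types import SimpleNamespace as NS
from src.audio2pose_models.cvae import CVAE

cfg = NS(MODEL=NS(CVAE=NS(ENCODER_LAYER_SIZES=[192, 128], DECODER_LAYER_SIZES=[128, 192],
                          LATENT_SIZE=64, AUDIO_EMB_IN_SIZE=512, AUDIO_EMB_OUT_SIZE=6,
                          SEQ_LEN=32)),
         DATASET=NS(NUM_CLASSES=46))
net = CVAE(cfg)

bs = 4
batch = {
    "class":          torch.randint(0, 46, (bs,)),  # speaker/style class id
    "ref":            torch.randn(bs, 6),           # first-frame pose coefficients
    "pose_motion_gt": torch.randn(bs, 32, 6),       # ground-truth pose deltas
    "audio_emb":      torch.randn(bs, 32, 512),     # wav2lip audio features
}
out = net(batch)
print(out["pose_motion_pred"].shape, out["mu"].shape, out["logvar"].shape)
# torch.Size([4, 32, 6]) torch.Size([4, 64]) torch.Size([4, 64])
```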
diff --git a/src/audio2pose_models/discriminator.py b/src/audio2pose_models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..339c38e4812ff38a810f0f3a1c01812f6d5d78db
--- /dev/null
+++ b/src/audio2pose_models/discriminator.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class ConvNormRelu(nn.Module):
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
+ super().__init__()
+ if kernel_size is None:
+ if downsample:
+ kernel_size, stride, padding = 4, 2, 1
+ else:
+ kernel_size, stride, padding = 3, 1, 1
+
+ if conv_type == '2d':
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=False,
+ )
+ if norm == 'BN':
+ self.norm = nn.BatchNorm2d(out_channels)
+ elif norm == 'IN':
+ self.norm = nn.InstanceNorm2d(out_channels)
+ else:
+ raise NotImplementedError
+ elif conv_type == '1d':
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=False,
+ )
+ if norm == 'BN':
+ self.norm = nn.BatchNorm1d(out_channels)
+ elif norm == 'IN':
+ self.norm = nn.InstanceNorm1d(out_channels)
+ else:
+ raise NotImplementedError
+ nn.init.kaiming_normal_(self.conv.weight)
+
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ if isinstance(self.norm, nn.InstanceNorm1d):
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
+ else:
+ x = self.norm(x)
+ x = self.act(x)
+ return x
+
+
+class PoseSequenceDiscriminator(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
+
+ self.seq = nn.Sequential(
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
+ )
+
+ def forward(self, x):
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
+ x = self.seq(x)
+ x = x.squeeze(1)
+ return x
\ No newline at end of file
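Shape sketch for PoseSequenceDiscriminator (again with a SimpleNamespace standing in for the config): a (B, T, 6) pose-motion sequence is transposed to (B, 6, T) and downsampled twice, so the output carries one score per 4 frames.

```python
import torch
from types import SimpleNamespace as NS
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator

cfg = NS(MODEL=NS(DISCRIMINATOR=NS(LEAKY_RELU=False, INPUT_CHANNELS=6)))
disc = PoseSequenceDiscriminator(cfg)
scores = disc(torch.randn(2, 32, 6))   # B=2, T=32, 6 pose coefficients
print(scores.shape)                    # torch.Size([2, 8])
```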
diff --git a/src/audio2pose_models/networks.py b/src/audio2pose_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aa0b1390e7b4bb0e16057ac94d2fe84f48421af
--- /dev/null
+++ b/src/audio2pose_models/networks.py
@@ -0,0 +1,140 @@
+import torch.nn as nn
+import torch
+
+
+class ResidualConv(nn.Module):
+ def __init__(self, input_dim, output_dim, stride, padding):
+ super(ResidualConv, self).__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.BatchNorm2d(input_dim),
+ nn.ReLU(),
+ nn.Conv2d(
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
+ ),
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+ self.conv_skip = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
+ nn.BatchNorm2d(output_dim),
+ )
+
+ def forward(self, x):
+
+ return self.conv_block(x) + self.conv_skip(x)
+
+
+class Upsample(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel, stride):
+ super(Upsample, self).__init__()
+
+ self.upsample = nn.ConvTranspose2d(
+ input_dim, output_dim, kernel_size=kernel, stride=stride
+ )
+
+ def forward(self, x):
+ return self.upsample(x)
+
+
+class Squeeze_Excite_Block(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(Squeeze_Excite_Block, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel, bias=False),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y.expand_as(x)
+
+
+class ASPP(nn.Module):
+ def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
+ super(ASPP, self).__init__()
+
+ self.aspp_block1 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+ self.aspp_block2 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+ self.aspp_block3 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+
+ self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
+ self._init_weights()
+
+ def forward(self, x):
+ x1 = self.aspp_block1(x)
+ x2 = self.aspp_block2(x)
+ x3 = self.aspp_block3(x)
+ out = torch.cat([x1, x2, x3], dim=1)
+ return self.output(out)
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+
+class Upsample_(nn.Module):
+ def __init__(self, scale=2):
+ super(Upsample_, self).__init__()
+
+ self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
+
+ def forward(self, x):
+ return self.upsample(x)
+
+
+class AttentionBlock(nn.Module):
+ def __init__(self, input_encoder, input_decoder, output_dim):
+ super(AttentionBlock, self).__init__()
+
+ self.conv_encoder = nn.Sequential(
+ nn.BatchNorm2d(input_encoder),
+ nn.ReLU(),
+ nn.Conv2d(input_encoder, output_dim, 3, padding=1),
+ nn.MaxPool2d(2, 2),
+ )
+
+ self.conv_decoder = nn.Sequential(
+ nn.BatchNorm2d(input_decoder),
+ nn.ReLU(),
+ nn.Conv2d(input_decoder, output_dim, 3, padding=1),
+ )
+
+ self.conv_attn = nn.Sequential(
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, 1, 1),
+ )
+
+ def forward(self, x1, x2):
+ out = self.conv_encoder(x1) + self.conv_decoder(x2)
+ out = self.conv_attn(out)
+ return out * x2
\ No newline at end of file
diff --git a/src/audio2pose_models/res_unet.py b/src/audio2pose_models/res_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2611e1d1a9bf233507427b34928fca60e094224
--- /dev/null
+++ b/src/audio2pose_models/res_unet.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from src.audio2pose_models.networks import ResidualConv, Upsample
+
+
+class ResUnet(nn.Module):
+ def __init__(self, channel=1, filters=[32, 64, 128, 256]):
+ super(ResUnet, self).__init__()
+
+ self.input_layer = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+ nn.BatchNorm2d(filters[0]),
+ nn.ReLU(),
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+ )
+ self.input_skip = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+ )
+
+ self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
+
+ self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
+
+ self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
+
+ self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
+
+ self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
+
+ self.output_layer = nn.Sequential(
+ nn.Conv2d(filters[0], 1, 1, 1),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ # Encode
+ x1 = self.input_layer(x) + self.input_skip(x)
+ x2 = self.residual_conv_1(x1)
+ x3 = self.residual_conv_2(x2)
+ # Bridge
+ x4 = self.bridge(x3)
+
+ # Decode
+ x4 = self.upsample_1(x4)
+ x5 = torch.cat([x4, x3], dim=1)
+
+ x6 = self.up_residual_conv1(x5)
+
+ x6 = self.upsample_2(x6)
+ x7 = torch.cat([x6, x2], dim=1)
+
+ x8 = self.up_residual_conv2(x7)
+
+ x8 = self.upsample_3(x8)
+ x9 = torch.cat([x8, x1], dim=1)
+
+ x10 = self.up_residual_conv3(x9)
+
+ output = self.output_layer(x10)
+
+ return output
\ No newline at end of file
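ResUnet only downsamples along the time axis (stride (2,1)), so a (B, 1, seq_len, 6) pose-motion tensor comes back at its original resolution; seq_len should be a multiple of 8 (e.g. the 32 used in the configs) for the skip connections to line up. A quick check:

```python
import torch
from src.audio2pose_models.res_unet import ResUnet

net = ResUnet(channel=1)
x = torch.randn(2, 1, 32, 6)   # B, 1, seq_len, 6
print(net(x).shape)            # torch.Size([2, 1, 32, 6])
```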
diff --git a/src/config/auido2exp.yaml b/src/config/auido2exp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7369dbf350476e14a1d600507f1f8b7d8aa6ecd3
--- /dev/null
+++ b/src/config/auido2exp.yaml
@@ -0,0 +1,58 @@
+DATASET:
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
+ TRAIN_BATCH_SIZE: 32
+ EVAL_BATCH_SIZE: 32
+ EXP: True
+ EXP_DIM: 64
+ FRAME_LEN: 32
+ COEFF_LEN: 73
+ NUM_CLASSES: 46
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+ DEBUG: True
+ NUM_REPEATS: 2
+ T: 40
+
+
+MODEL:
+ FRAMEWORK: V2
+ AUDIOENCODER:
+ LEAKY_RELU: True
+ NORM: 'IN'
+ DISCRIMINATOR:
+ LEAKY_RELU: False
+ INPUT_CHANNELS: 6
+ CVAE:
+ AUDIO_EMB_IN_SIZE: 512
+ AUDIO_EMB_OUT_SIZE: 128
+ SEQ_LEN: 32
+ LATENT_SIZE: 256
+ ENCODER_LAYER_SIZES: [192, 1024]
+ DECODER_LAYER_SIZES: [1024, 192]
+
+
+TRAIN:
+ MAX_EPOCH: 300
+ GENERATOR:
+ LR: 2.0e-5
+ DISCRIMINATOR:
+ LR: 1.0e-5
+ LOSS:
+ W_FEAT: 0
+ W_COEFF_EXP: 2
+ W_LM: 1.0e-2
+ W_LM_MOUTH: 0
+ W_REG: 0
+ W_SYNC: 0
+ W_COLOR: 0
+ W_EXPRESSION: 0
+ W_LIPREADING: 0.01
+ W_LIPREADING_VV: 0
+ W_EYE_BLINK: 4
+
+TAG:
+ NAME: small_dataset
+
+
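A sketch of how such a config file might be consumed: load it with PyYAML and wrap the nested dicts so keys are attribute-accessible, as the model code (cfg.MODEL.CVAE.SEQ_LEN, ...) expects. The repository may well use a dedicated CfgNode-style helper instead; the AttrDict below is only an illustration, and the path assumes you run from the repository root.

```python
import yaml

class AttrDict(dict):
    """dict with attribute access, so cfg.MODEL.CVAE.SEQ_LEN works."""
    __getattr__ = dict.__getitem__

def to_attrdict(d):
    return AttrDict({k: to_attrdict(v) if isinstance(v, dict) else v for k, v in d.items()})

with open("src/config/auido2exp.yaml") as f:
    cfg = to_attrdict(yaml.safe_load(f))

print(cfg.MODEL.CVAE.SEQ_LEN, cfg.DATASET.NUM_CLASSES, cfg.TRAIN.GENERATOR.LR)  # 32 46 2e-05
```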
diff --git a/src/config/auido2pose.yaml b/src/config/auido2pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bc61f94d12f406f2d8d02545e55b61075051484d
--- /dev/null
+++ b/src/config/auido2pose.yaml
@@ -0,0 +1,49 @@
+DATASET:
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
+ TRAIN_BATCH_SIZE: 64
+ EVAL_BATCH_SIZE: 1
+ EXP: True
+ EXP_DIM: 64
+ FRAME_LEN: 32
+ COEFF_LEN: 73
+ NUM_CLASSES: 46
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+ DEBUG: True
+
+
+MODEL:
+ AUDIOENCODER:
+ LEAKY_RELU: True
+ NORM: 'IN'
+ DISCRIMINATOR:
+ LEAKY_RELU: False
+ INPUT_CHANNELS: 6
+ CVAE:
+ AUDIO_EMB_IN_SIZE: 512
+ AUDIO_EMB_OUT_SIZE: 6
+ SEQ_LEN: 32
+ LATENT_SIZE: 64
+ ENCODER_LAYER_SIZES: [192, 128]
+ DECODER_LAYER_SIZES: [128, 192]
+
+
+TRAIN:
+ MAX_EPOCH: 150
+ GENERATOR:
+ LR: 1.0e-4
+ DISCRIMINATOR:
+ LR: 1.0e-4
+ LOSS:
+ LAMBDA_REG: 1
+ LAMBDA_LANDMARKS: 0
+ LAMBDA_VERTICES: 0
+ LAMBDA_GAN_MOTION: 0.7
+ LAMBDA_GAN_COEFF: 0
+ LAMBDA_KL: 1
+
+TAG:
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
+
+
diff --git a/src/config/facerender.yaml b/src/config/facerender.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9494ef82dfa16b16b7aa0b848ebdd6b23e739e2a
--- /dev/null
+++ b/src/config/facerender.yaml
@@ -0,0 +1,45 @@
+model_params:
+ common_params:
+ num_kp: 15
+ image_channel: 3
+ feature_channel: 32
+ estimate_jacobian: False # True
+ kp_detector_params:
+ temperature: 0.1
+ block_expansion: 32
+ max_features: 1024
+ scale_factor: 0.25 # 0.25
+ num_blocks: 5
+ reshape_channel: 16384 # 16384 = 1024 * 16
+ reshape_depth: 16
+ he_estimator_params:
+ block_expansion: 64
+ max_features: 2048
+ num_bins: 66
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 2
+ reshape_channel: 32
+ reshape_depth: 16 # 512 = 32 * 16
+ num_resblocks: 6
+ estimate_occlusion_map: True
+ dense_motion_params:
+ block_expansion: 32
+ max_features: 1024
+ num_blocks: 5
+ reshape_depth: 16
+ compress: 4
+ discriminator_params:
+ scales: [1]
+ block_expansion: 32
+ max_features: 512
+ num_blocks: 4
+ sn: True
+ mapping_params:
+ coeff_nc: 70
+ descriptor_nc: 1024
+ layer: 3
+ num_kp: 15
+ num_bins: 66
+
diff --git a/src/config/facerender_pirender.yaml b/src/config/facerender_pirender.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e4f1da2908b46f06a17822d12ba97a5cc5c5f369
--- /dev/null
+++ b/src/config/facerender_pirender.yaml
@@ -0,0 +1,83 @@
+# How often do you want to log the training stats.
+# network_list:
+# gen: gen_optimizer
+# dis: dis_optimizer
+
+distributed: False
+image_to_tensorboard: True
+snapshot_save_iter: 40000
+snapshot_save_epoch: 20
+snapshot_save_start_iter: 20000
+snapshot_save_start_epoch: 10
+image_save_iter: 1000
+max_epoch: 200
+logging_iter: 100
+results_dir: ./eval_results
+
+gen_optimizer:
+ type: adam
+ lr: 0.0001
+ adam_beta1: 0.5
+ adam_beta2: 0.999
+ lr_policy:
+ iteration_mode: True
+ type: step
+ step_size: 300000
+ gamma: 0.2
+
+trainer:
+ type: trainers.face_trainer::FaceTrainer
+ pretrain_warp_iteration: 200000
+ loss_weight:
+ weight_perceptual_warp: 2.5
+ weight_perceptual_final: 4
+ vgg_param_warp:
+ network: vgg19
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+ use_style_loss: False
+ num_scales: 4
+ vgg_param_final:
+ network: vgg19
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+ use_style_loss: True
+ num_scales: 4
+ style_to_perceptual: 250
+ init:
+ type: 'normal'
+ gain: 0.02
+gen:
+ type: generators.face_model::FaceGenerator
+ param:
+ mapping_net:
+ coeff_nc: 73
+ descriptor_nc: 256
+ layer: 3
+ warpping_net:
+ encoder_layer: 5
+ decoder_layer: 3
+ base_nc: 32
+ editing_net:
+ layer: 3
+ num_res_blocks: 2
+ base_nc: 64
+ common:
+ image_nc: 3
+ descriptor_nc: 256
+ max_nc: 256
+ use_spect: False
+
+
+# Data options.
+data:
+ type: data.vox_dataset::VoxDataset
+ path: ./dataset/vox_lmdb
+ resolution: 256
+ semantic_radius: 13
+ train:
+ batch_size: 5
+ distributed: True
+ val:
+ batch_size: 8
+ distributed: True
+
+
diff --git a/src/config/facerender_still.yaml b/src/config/facerender_still.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6b4d66dade3e655ac4cfc25a994ca28e53821d80
--- /dev/null
+++ b/src/config/facerender_still.yaml
@@ -0,0 +1,45 @@
+model_params:
+ common_params:
+ num_kp: 15
+ image_channel: 3
+ feature_channel: 32
+ estimate_jacobian: False # True
+ kp_detector_params:
+ temperature: 0.1
+ block_expansion: 32
+ max_features: 1024
+ scale_factor: 0.25 # 0.25
+ num_blocks: 5
+ reshape_channel: 16384 # 16384 = 1024 * 16
+ reshape_depth: 16
+ he_estimator_params:
+ block_expansion: 64
+ max_features: 2048
+ num_bins: 66
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 2
+ reshape_channel: 32
+ reshape_depth: 16 # 512 = 32 * 16
+ num_resblocks: 6
+ estimate_occlusion_map: True
+ dense_motion_params:
+ block_expansion: 32
+ max_features: 1024
+ num_blocks: 5
+ reshape_depth: 16
+ compress: 4
+ discriminator_params:
+ scales: [1]
+ block_expansion: 32
+ max_features: 512
+ num_blocks: 4
+ sn: True
+ mapping_params:
+ coeff_nc: 73
+ descriptor_nc: 1024
+ layer: 3
+ num_kp: 15
+ num_bins: 66
+
diff --git a/src/config/similarity_Lm3D_all.mat b/src/config/similarity_Lm3D_all.mat
new file mode 100644
index 0000000000000000000000000000000000000000..a0e23588302bc71fc899eef53ff06df5f4df4c1d
Binary files /dev/null and b/src/config/similarity_Lm3D_all.mat differ
diff --git a/src/face3d/data/__init__.py b/src/face3d/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9761c518a1b07c5996165869742af0a52c82bc
--- /dev/null
+++ b/src/face3d/data/__init__.py
@@ -0,0 +1,116 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point from data loader.
+    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import numpy as np
+import importlib
+import torch.utils.data
+from face3d.data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+ """Import the module "data/[dataset_name]_dataset.py".
+
+ In the file, the class called DatasetNameDataset() will
+ be instantiated. It has to be a subclass of BaseDataset,
+ and it is case-insensitive.
+ """
+ dataset_filename = "data." + dataset_name + "_dataset"
+ datasetlib = importlib.import_module(dataset_filename)
+
+ dataset = None
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+ for name, cls in datasetlib.__dict__.items():
+ if name.lower() == target_dataset_name.lower() \
+ and issubclass(cls, BaseDataset):
+ dataset = cls
+
+ if dataset is None:
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+ return dataset
+
+
+def get_option_setter(dataset_name):
+ """Return the static method of the dataset class."""
+ dataset_class = find_dataset_using_name(dataset_name)
+ return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt, rank=0):
+ """Create a dataset given the option.
+
+ This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from data import create_dataset
+ >>> dataset = create_dataset(opt)
+ """
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
+ dataset = data_loader.load_data()
+ return dataset
+
+class CustomDatasetDataLoader():
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+ def __init__(self, opt, rank=0):
+ """Initialize this class
+
+ Step 1: create a dataset instance given the name [dataset_mode]
+ Step 2: create a multi-threaded data loader.
+ """
+ self.opt = opt
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
+ self.dataset = dataset_class(opt)
+ self.sampler = None
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
+ if opt.use_ddp and opt.isTrain:
+ world_size = opt.world_size
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
+ self.dataset,
+ num_replicas=world_size,
+ rank=rank,
+ shuffle=not opt.serial_batches
+ )
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ sampler=self.sampler,
+ num_workers=int(opt.num_threads / world_size),
+ batch_size=int(opt.batch_size / world_size),
+ drop_last=True)
+ else:
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batch_size,
+ shuffle=(not opt.serial_batches) and opt.isTrain,
+ num_workers=int(opt.num_threads),
+ drop_last=True
+ )
+
+ def set_epoch(self, epoch):
+ self.dataset.current_epoch = epoch
+ if self.sampler is not None:
+ self.sampler.set_epoch(epoch)
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ """Return the number of data in the dataset"""
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ """Return a batch of data"""
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/src/face3d/data/base_dataset.py b/src/face3d/data/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd57d082d519f512d7114b4f867b6695fb7de06
--- /dev/null
+++ b/src/face3d/data/base_dataset.py
@@ -0,0 +1,125 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+ """This class is an abstract base class (ABC) for datasets.
+
+ To create a subclass, you need to implement the following four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point.
+    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the class; save the options in the class
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ self.opt = opt
+ # self.root = opt.dataroot
+ self.current_epoch = 0
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def __len__(self):
+ """Return the total number of images in the dataset."""
+ return 0
+
+ @abstractmethod
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index - - a random integer for data indexing
+
+ Returns:
+            a dictionary of data with their names. It usually contains the data itself and its metadata information.
+ """
+ pass
+
+
+def get_transform(grayscale=False):
+ transform_list = []
+ if grayscale:
+ transform_list.append(transforms.Grayscale(1))
+ transform_list += [transforms.ToTensor()]
+ return transforms.Compose(transform_list)
+
+def get_affine_mat(opt, size):
+    shift_x, shift_y, scale, rot_angle, rot_rad, flip = 0., 0., 1., 0., 0., False  # rot_rad must default to 0 when 'rot' is not in opt.preprocess
+ w, h = size
+
+ if 'shift' in opt.preprocess:
+ shift_pixs = int(opt.shift_pixs)
+ shift_x = random.randint(-shift_pixs, shift_pixs)
+ shift_y = random.randint(-shift_pixs, shift_pixs)
+ if 'scale' in opt.preprocess:
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
+ if 'rot' in opt.preprocess:
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
+ rot_rad = -rot_angle * np.pi/180
+ if 'flip' in opt.preprocess:
+ flip = random.random() > 0.5
+
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
+
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
+ affine_inv = np.linalg.inv(affine)
+ return affine, affine_inv, flip
+
+def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
+
+def apply_lm_affine(landmark, affine, flip, size):
+ _, h = size
+ lm = landmark.copy()
+ lm[:, 1] = h - 1 - lm[:, 1]
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
+ lm = lm @ np.transpose(affine)
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
+ lm = lm[:, :2]
+ lm[:, 1] = h - 1 - lm[:, 1]
+ if flip:
+ lm_ = lm.copy()
+ lm_[:17] = lm[16::-1]
+ lm_[17:22] = lm[26:21:-1]
+ lm_[22:27] = lm[21:16:-1]
+ lm_[31:36] = lm[35:30:-1]
+ lm_[36:40] = lm[45:41:-1]
+ lm_[40:42] = lm[47:45:-1]
+ lm_[42:46] = lm[39:35:-1]
+ lm_[46:48] = lm[41:39:-1]
+ lm_[48:55] = lm[54:47:-1]
+ lm_[55:60] = lm[59:54:-1]
+ lm_[60:65] = lm[64:59:-1]
+ lm_[65:68] = lm[67:64:-1]
+ lm = lm_
+ return lm
diff --git a/src/face3d/data/flist_dataset.py b/src/face3d/data/flist_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0b6945c80aa756074a5d3c02b9443b15ddcfc57
--- /dev/null
+++ b/src/face3d/data/flist_dataset.py
@@ -0,0 +1,125 @@
+"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
+"""
+
+import os.path
+from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+import numpy as np
+import json
+import torch
+from scipy.io import loadmat, savemat
+import pickle
+from util.preprocess import align_img, estimate_norm
+from util.load_mats import load_lm3d
+
+
+def default_flist_reader(flist):
+ """
+    flist format: impath label\nimpath label\n ...(same as caffe's filelist)
+ """
+ imlist = []
+ with open(flist, 'r') as rf:
+ for line in rf.readlines():
+ impath = line.strip()
+ imlist.append(impath)
+
+ return imlist
+
+def jason_flist_reader(flist):
+ with open(flist, 'r') as fp:
+ info = json.load(fp)
+ return info
+
+def parse_label(label):
+ return torch.tensor(np.array(label).astype(np.float32))
+
+
+class FlistDataset(BaseDataset):
+ """
+    It requires one directory to host training images '/path/to/data/train'.
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
+
+ msk_names = default_flist_reader(opt.flist)
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
+
+ self.size = len(self.msk_paths)
+ self.opt = opt
+
+ self.name = 'train' if opt.isTrain else 'val'
+ if '_' in opt.flist:
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
+
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains A, B, A_paths and B_paths
+ img (tensor) -- an image in the input domain
+ msk (tensor) -- its corresponding attention mask
+ lm (tensor) -- its corresponding 3d landmarks
+ im_paths (str) -- image paths
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
+ """
+        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
+ img_path = msk_path.replace('mask/', '')
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
+
+ raw_img = Image.open(img_path).convert('RGB')
+ raw_msk = Image.open(msk_path).convert('RGB')
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
+
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
+
+ aug_flag = self.opt.use_aug and self.opt.isTrain
+ if aug_flag:
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
+
+ _, H = img.size
+ M = estimate_norm(lm, H)
+ transform = get_transform()
+ img_tensor = transform(img)
+ msk_tensor = transform(msk)[:1, ...]
+ lm_tensor = parse_label(lm)
+ M_tensor = parse_label(M)
+
+
+ return {'imgs': img_tensor,
+ 'lms': lm_tensor,
+ 'msks': msk_tensor,
+ 'M': M_tensor,
+ 'im_paths': img_path,
+ 'aug_flag': aug_flag,
+ 'dataset': self.name}
+
+ def _augmentation(self, img, lm, opt, msk=None):
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
+ img = apply_img_affine(img, affine_inv)
+ lm = apply_lm_affine(lm, affine, flip, img.size)
+ if msk is not None:
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
+ return img, lm, msk
+
+
+
+
+ def __len__(self):
+ """Return the total number of images in the dataset.
+ """
+ return self.size
diff --git a/src/face3d/data/image_folder.py b/src/face3d/data/image_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..efadc2ecbe2fb4b53b78230aba25ec505eff0e55
--- /dev/null
+++ b/src/face3d/data/image_folder.py
@@ -0,0 +1,66 @@
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
+"""
+import numpy as np
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ '.tif', '.TIF', '.tiff', '.TIFF',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf")):
+ images = []
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+ return images[:min(max_dataset_size, len(images))]
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
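+
+
+if __name__ == '__main__':
+    # Hedged usage sketch added for illustration (not part of the original module):
+    # it assumes a local folder './images' containing a few files with the
+    # extensions listed in IMG_EXTENSIONS.
+    from torchvision import transforms
+
+    preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
+    dataset = ImageFolder('./images', transform=preprocess, return_paths=True)
+    for img, path in dataset:
+        print(path, tuple(img.shape))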
diff --git a/src/face3d/data/template_dataset.py b/src/face3d/data/template_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfdf16be2a8a834b204c45d88c86857b37b9bd25
--- /dev/null
+++ b/src/face3d/data/template_dataset.py
@@ -0,0 +1,75 @@
+"""Dataset class template
+
+This module provides a template for users to implement custom datasets.
+You can specify '--dataset_mode template' to use this dataset.
+The class name should be consistent with both the filename and its dataset_mode option.
+The filename should be <dataset_mode>_dataset.py
+The class name should be <DatasetMode>Dataset.
+You need to implement the following functions:
+    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
+ -- <__init__>: Initialize this dataset class.
+ -- <__getitem__>: Return a data point and its metadata information.
+ -- <__len__>: Return the number of images.
+"""
+from data.base_dataset import BaseDataset, get_transform
+# from data.image_folder import make_dataset
+# from PIL import Image
+
+
+class TemplateDataset(BaseDataset):
+ """A template dataset class for you to implement custom datasets."""
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ A few things can be done here.
+ - save the options (have been done in BaseDataset)
+ - get image paths and meta information of the dataset.
+ - define the image transformation.
+ """
+ # save the option and dataset root
+ BaseDataset.__init__(self, opt)
+ # get the image paths of your dataset;
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
+        # define the default transform function. You can use <base_dataset.get_transform>; you can also define your custom transform function
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index -- a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
+
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
+        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
+ Step 4: return a data point as a dictionary.
+ """
+ path = 'temp' # needs to be a string
+ data_A = None # needs to be a tensor
+ data_B = None # needs to be a tensor
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
+
+ def __len__(self):
+ """Return the total number of images."""
+ return len(self.image_paths)
diff --git a/src/face3d/extract_kp_videos.py b/src/face3d/extract_kp_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..21616a3b4b5077ffdce99621395237b4edcff58c
--- /dev/null
+++ b/src/face3d/extract_kp_videos.py
@@ -0,0 +1,108 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import face_alignment
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from itertools import cycle
+
+from torch.multiprocessing import Pool, Process, set_start_method
+
+class KeypointExtractor():
+ def __init__(self, device):
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
+ device=device)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images,desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleep for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+    kp_extractor = KeypointExtractor(device='cuda')  # CUDA_VISIBLE_DEVICES above already selects the GPU
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+    for ext in extensions:
+        print(f'{opt.input_dir}/*.{ext}')
+        # accumulate matches for every extension instead of overwriting the list
+        filenames.extend(sorted(glob.glob(f'{opt.input_dir}/*.{ext}')))
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+        pass
diff --git a/src/face3d/extract_kp_videos_safe.py b/src/face3d/extract_kp_videos_safe.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba3830b84bee98e02a7d0681803cc4b1719787c2
--- /dev/null
+++ b/src/face3d/extract_kp_videos_safe.py
@@ -0,0 +1,151 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+from torch.multiprocessing import Pool, Process, set_start_method
+
+from facexlib.alignment import landmark_98_to_68
+from facexlib.detection import init_detection_model
+
+from facexlib.utils import load_file_from_url
+from facexlib.alignment.awing_arch import FAN
+
+def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
+ if model_name == 'awing_fan':
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
+
+
+class KeypointExtractor():
+ def __init__(self, device='cuda'):
+
+ ### gfpgan/weights
+ try:
+ import webui # in webui
+ root_path = 'extensions/SadTalker/gfpgan/weights'
+
+ except:
+ root_path = 'gfpgan/weights'
+
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images,desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ # current_kp = self.detector.get_landmarks(np.array(image))
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ with torch.no_grad():
+ # face detection -> face alignment.
+ img = np.array(images)
+ bboxes = self.det_net.detect_faces(images, 0.97)
+
+ bboxes = bboxes[0]
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
+
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
+
+ #### keypoints to the original location
+ keypoints[:,0] += int(bboxes[0])
+ keypoints[:,1] += int(bboxes[1])
+
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleep for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+ kp_extractor = KeypointExtractor()
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+    for ext in extensions:
+        print(f'{opt.input_dir}/*.{ext}')
+        # accumulate matches for every extension instead of overwriting the list
+        filenames.extend(sorted(glob.glob(f'{opt.input_dir}/*.{ext}')))
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+        pass
diff --git a/src/face3d/models/__init__.py b/src/face3d/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a7986c7ad2ec48f404adf81fea5aa06aaf1eeb4
--- /dev/null
+++ b/src/face3d/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+    -- <set_input>: unpack data from dataset and apply preprocessing.
+    -- <forward>: produce intermediate results.
+    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from src.face3d.models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called DatasetNameModel() will
+ be instantiated. It has to be a subclass of BaseModel,
+ and it is case-insensitive.
+ """
+ model_filename = "face3d.models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit(0)
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+    This function wraps the model class found by find_model_using_name.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/src/face3d/models/arcface_torch/README.md b/src/face3d/models/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ee63a861229b68873561fa39bfa7c9a8b53b947
--- /dev/null
+++ b/src/face3d/models/arcface_torch/README.md
@@ -0,0 +1,164 @@
+# Distributed Arcface Training in Pytorch
+
+This is a deep learning library that makes face recognition training efficient and effective, and that can train tens of
+millions of identities on a single server.
+
+## Requirements
+
+- Install [PyTorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
+- `pip install -r requirements.txt`.
+- Download the dataset
+  from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
+
+## How to Train
+
+To train a model, run `train.py` with the path to the configs:
+
+### 1. Single node, 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+### 2. Multiple nodes, each node 8 GPUs:
+
+Node 0:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+Node 1:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+### 3. Training resnet2060 with 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
+```
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.
+- All models can be found here:
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
+
+The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly
+available face recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
+As a result, we can fairly evaluate the performance of different algorithms.
+
+For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
+globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
+
+For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
+The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
+There are 13,928 positive pairs and 96,983,824 negative pairs in total.
+
+| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
+| :---: | :--- | :--- | :--- |:--- |:--- |
+| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
+| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
+| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
+| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
+| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
+| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
+| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
+| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
+| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
+| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
+
+### Performance on IJB-C and Verification Datasets
+
+| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
+| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
+| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
+| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
+| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
+| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
+| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
+| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
+| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
+| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
+| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
+
+[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
+
+
+## [Speed Benchmark](docs/speed_benchmark.md)
+
+**Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
+classes in the training set is greater than 300K and training is sufficient, the partial FC sampling strategy reaches the same
+accuracy with several times faster training and a smaller GPU memory footprint.
+Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
+sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
+sparse part of the parameters is updated, which saves a large amount of GPU memory and computation. With Partial FC,
+we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
+training and mixed-precision training.
+
+![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
+
+For more details, see [speed_benchmark.md](docs/speed_benchmark.md) in docs.
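+
+A minimal, illustrative sketch of the sampling idea described above (this is not the library implementation; the tensor
+names and sizes are invented for the example): per batch, keep the class centers of the labels that appear in the batch,
+pad with random negative centers up to the sampling budget, and compute logits only against that subset.
+
+```python
+import torch
+
+num_classes, emb_size, sample_rate, batch = 10000, 512, 0.1, 128
+weight = torch.randn(num_classes, emb_size)             # full class-center matrix
+labels = torch.randint(0, num_classes, (batch,))        # labels of the current batch
+embeddings = torch.nn.functional.normalize(torch.randn(batch, emb_size), dim=1)
+
+budget = int(sample_rate * num_classes)
+positive = labels.unique()                              # always keep the positive centers
+mask = torch.zeros(num_classes, dtype=torch.bool)
+mask[positive] = True
+perm = torch.randperm(num_classes)
+negatives = perm[~mask[perm]]                           # shuffled candidate negative centers
+sampled = torch.cat([positive, negatives[:max(budget - positive.numel(), 0)]])
+
+sub_weight = torch.nn.functional.normalize(weight[sampled], dim=1)
+logits = embeddings @ sub_weight.t()                    # shape (batch, |sampled|), not (batch, num_classes)
+```
+
+Only the sampled columns of the classifier (and their gradients) are touched in a step, which is where the memory and
+compute savings in the tables below come from.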
+
+### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
+
+`-` means training failed because of gpu memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|1400000 | **1672** | 3043 | 4738 |
+|5500000 | **-** | **1389** | 3975 |
+|8000000 | **-** | **-** | 3565 |
+|16000000 | **-** | **-** | 2679 |
+|29000000 | **-** | **-** | **1855** |
+
+### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|1400000 | 32252 | 11178 | 6056 |
+|5500000 | **-** | 32188 | 9854 |
+|8000000 | **-** | **-** | 12310 |
+|16000000 | **-** | **-** | 19950 |
+|29000000 | **-** | **-** | 32324 |
+
+## Evaluation ICCV2021-MFR and IJB-C
+
+For more details, see [eval.md](docs/eval.md) in docs.
+
+## Test
+
+We tested many versions of PyTorch. Please create an issue if you are having trouble.
+
+- [x] torch 1.6.0
+- [x] torch 1.7.1
+- [x] torch 1.8.0
+- [x] torch 1.9.0
+
+## Citation
+
+```
+@inproceedings{deng2019arcface,
+ title={Arcface: Additive angular margin loss for deep face recognition},
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={4690--4699},
+ year={2019}
+}
+@inproceedings{an2020partical_fc,
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
+          Zhang, Debing and Fu, Ying},
+ booktitle={Arxiv 2010.05222},
+ year={2020}
+}
+```
diff --git a/src/face3d/models/arcface_torch/backbones/__init__.py b/src/face3d/models/arcface_torch/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55bd4c5d1889a1a998b52eb56793bbc1eef1b691
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/__init__.py
@@ -0,0 +1,25 @@
+from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
+from .mobilefacenet import get_mbf
+
+
+def get_model(name, **kwargs):
+ # resnet
+ if name == "r18":
+ return iresnet18(False, **kwargs)
+ elif name == "r34":
+ return iresnet34(False, **kwargs)
+ elif name == "r50":
+ return iresnet50(False, **kwargs)
+ elif name == "r100":
+ return iresnet100(False, **kwargs)
+ elif name == "r200":
+ return iresnet200(False, **kwargs)
+ elif name == "r2060":
+ from .iresnet2060 import iresnet2060
+ return iresnet2060(False, **kwargs)
+ elif name == "mbf":
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf(fp16=fp16, num_features=num_features)
+ else:
+        raise ValueError(f"unknown backbone name: {name}")
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/backbones/iresnet.py b/src/face3d/models/arcface_torch/backbones/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6d3b9c240c24687d432197f976ee01fbf423216
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/iresnet.py
@@ -0,0 +1,187 @@
+import torch
+from torch import nn
+
+__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet18(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
+ progress, **kwargs)
+
+
+def iresnet34(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet50(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet100(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet200(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
+ progress, **kwargs)
+
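+
+if __name__ == '__main__':
+    # Hedged smoke test added for illustration (not part of the original file):
+    # the backbones expect 112x112 RGB crops and return num_features-dim embeddings.
+    net = iresnet50(num_features=512)
+    net.eval()
+    with torch.no_grad():
+        emb = net(torch.randn(2, 3, 112, 112))
+    print(emb.shape)  # expected: torch.Size([2, 512])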
diff --git a/src/face3d/models/arcface_torch/backbones/iresnet2060.py b/src/face3d/models/arcface_torch/backbones/iresnet2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..21d1122144d207637d2444cba1f68fe630c89f31
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/iresnet2060.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+
+assert torch.__version__ >= "1.8.1"
+from torch.utils.checkpoint import checkpoint_sequential
+
+__all__ = ['iresnet2060']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def checkpoint(self, func, num_seg, x):
+ if self.training:
+ return checkpoint_sequential(func, num_seg, x)
+ else:
+ return func(x)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.checkpoint(self.layer2, 20, x)
+ x = self.checkpoint(self.layer3, 100, x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet2060(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
diff --git a/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..87731491d76f9ff61cc70e57bb3f18c54fae308c
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
@@ -0,0 +1,130 @@
+'''
+Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
+Original author cavalleria
+'''
+
+import torch.nn as nn
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
+import torch
+
+
+class Flatten(Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+class ConvBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(ConvBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
+ BatchNorm2d(num_features=out_c),
+ PReLU(num_parameters=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class LinearBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(LinearBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
+ BatchNorm2d(num_features=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DepthWise(Module):
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
+ super(DepthWise, self).__init__()
+ self.residual = residual
+ self.layers = nn.Sequential(
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
+ )
+
+ def forward(self, x):
+ short_cut = None
+ if self.residual:
+ short_cut = x
+ x = self.layers(x)
+ if self.residual:
+ output = short_cut + x
+ else:
+ output = x
+ return output
+
+
+class Residual(Module):
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
+ super(Residual, self).__init__()
+ modules = []
+ for _ in range(num_block):
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
+ self.layers = Sequential(*modules)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class GDC(Module):
+ def __init__(self, embedding_size):
+ super(GDC, self).__init__()
+ self.layers = nn.Sequential(
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
+ Flatten(),
+ Linear(512, embedding_size, bias=False),
+ BatchNorm1d(embedding_size))
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class MobileFaceNet(Module):
+ def __init__(self, fp16=False, num_features=512):
+ super(MobileFaceNet, self).__init__()
+ scale = 2
+ self.fp16 = fp16
+ self.layers = nn.Sequential(
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ )
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
+ self.features = GDC(num_features)
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.layers(x)
+ x = self.conv_sep(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def get_mbf(fp16, num_features):
+ return MobileFaceNet(fp16, num_features)
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/configs/3millions.py b/src/face3d/models/arcface_torch/configs/3millions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9edc2f1414e35f93abfd3dfe11a61f1f406580e
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/3millions.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/src/face3d/models/arcface_torch/configs/3millions_pfc.py b/src/face3d/models/arcface_torch/configs/3millions_pfc.py
new file mode 100644
index 0000000000000000000000000000000000000000..77caafdbb300d8109d5bfdb844f131710ef81f20
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/3millions_pfc.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/src/face3d/models/arcface_torch/configs/__init__.py b/src/face3d/models/arcface_torch/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/configs/base.py b/src/face3d/models/arcface_torch/configs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..78e4b36a9142b649ec39a8c59331bb2557f2ad57
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/base.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = "ms1mv3_arcface_r50"
+
+config.dataset = "ms1m-retinaface-t1"
+config.embedding_size = 512
+config.sample_rate = 1
+config.fp16 = False
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+if config.dataset == "emore":
+ config.rec = "/train_tmp/faces_emore"
+ config.num_classes = 85742
+ config.num_image = 5822653
+ config.num_epoch = 16
+ config.warmup_epoch = -1
+ config.decay_epoch = [8, 14, ]
+ config.val_targets = ["lfw", ]
+
+elif config.dataset == "ms1m-retinaface-t1":
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
+ config.num_classes = 93431
+ config.num_image = 5179510
+ config.num_epoch = 25
+ config.warmup_epoch = -1
+ config.decay_epoch = [11, 17, 22]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "glint360k":
+ config.rec = "/train_tmp/glint360k"
+ config.num_classes = 360232
+ config.num_image = 17091657
+ config.num_epoch = 20
+ config.warmup_epoch = -1
+ config.decay_epoch = [8, 12, 15, 18]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "webface":
+ config.rec = "/train_tmp/faces_webface_112x112"
+ config.num_classes = 10572
+ config.num_image = "forget"
+ config.num_epoch = 34
+ config.warmup_epoch = -1
+ config.decay_epoch = [20, 28, 32]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
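+
+
+if __name__ == '__main__':
+    # Illustrative sanity check added for this doc (not part of the original
+    # config): print the branch selected by config.dataset above.
+    print(config.dataset, config.network, config.num_classes, config.lr)
+    print('decay_epoch:', config.decay_epoch, 'val_targets:', config.val_targets)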
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_mbf.py b/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ae777cc97af41a531cba4e5d1ff31f2efcb468
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r100.py b/src/face3d/models/arcface_torch/configs/glint360k_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..93d0701c0094517cec147c382b005e8063938548
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r100.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r18.py b/src/face3d/models/arcface_torch/configs/glint360k_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a8db34cd547e8e667103c93585296e47a894e97
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r34.py b/src/face3d/models/arcface_torch/configs/glint360k_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..fda2701758a839a7161d09c25f0ca3d26033baff
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r50.py b/src/face3d/models/arcface_torch/configs/glint360k_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e7922f1f63284e356dcc45a5f979f9c105f25e
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8a00d6305eeda5a94788017afc1cda0d4a4cd2a
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 20, 25]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb4e0d31f1aedf4590628d394e1606920fefb5c9
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..23ad81e082c4b6390b67b164d0ceb84bb0635684
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r2060"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 64
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f78337a3d1f9eb6e9145eb5093618796c6842d2
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..08ba55dbbea6df0afffddbb3d1ed173efad99604
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/speed.py b/src/face3d/models/arcface_torch/configs/speed.py
new file mode 100644
index 0000000000000000000000000000000000000000..45e95237da65e44f35a172c25ac6dc4e313e4eae
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/speed.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 100 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/src/face3d/models/arcface_torch/dataset.py b/src/face3d/models/arcface_torch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..96bbb8bb6da99122f350bc8e1a6390245840e32b
--- /dev/null
+++ b/src/face3d/models/arcface_torch/dataset.py
@@ -0,0 +1,124 @@
+import numbers
+import os
+import queue as Queue
+import threading
+
+import mxnet as mx
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, local_rank, max_prefetch=6):
+ super(BackgroundGenerator, self).__init__()
+ self.queue = Queue.Queue(max_prefetch)
+ self.generator = generator
+ self.local_rank = local_rank
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ torch.cuda.set_device(self.local_rank)
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
+
+
+class DataLoaderX(DataLoader):
+
+ def __init__(self, local_rank, **kwargs):
+ super(DataLoaderX, self).__init__(**kwargs)
+ self.stream = torch.cuda.Stream(local_rank)
+ self.local_rank = local_rank
+
+ def __iter__(self):
+ self.iter = super(DataLoaderX, self).__iter__()
+ self.iter = BackgroundGenerator(self.iter, self.local_rank)
+ self.preload()
+ return self
+
+ def preload(self):
+ self.batch = next(self.iter, None)
+ if self.batch is None:
+ return None
+ with torch.cuda.stream(self.stream):
+ for k in range(len(self.batch)):
+ self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+ def __next__(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is None:
+ raise StopIteration
+ self.preload()
+ return batch
+
+
+class MXFaceDataset(Dataset):
+ def __init__(self, root_dir, local_rank):
+ super(MXFaceDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [transforms.ToPILImage(),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ self.root_dir = root_dir
+ self.local_rank = local_rank
+ path_imgrec = os.path.join(root_dir, 'train.rec')
+ path_imgidx = os.path.join(root_dir, 'train.idx')
+ self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+ s = self.imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ if header.flag > 0:
+ self.header0 = (int(header.label[0]), int(header.label[1]))
+ self.imgidx = np.array(range(1, int(header.label[0])))
+ else:
+ self.imgidx = np.array(list(self.imgrec.keys))
+
+ def __getitem__(self, index):
+ idx = self.imgidx[index]
+ s = self.imgrec.read_idx(idx)
+ header, img = mx.recordio.unpack(s)
+ label = header.label
+ if not isinstance(label, numbers.Number):
+ label = label[0]
+ label = torch.tensor(label, dtype=torch.long)
+ sample = mx.image.imdecode(img).asnumpy()
+ if self.transform is not None:
+ sample = self.transform(sample)
+ return sample, label
+
+ def __len__(self):
+ return len(self.imgidx)
+
+
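+# Dataset that returns the same randomly generated image for every index; used by the
+# synthetic speed-benchmark configs (config.rec = "synthetic").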
+class SyntheticDataset(Dataset):
+ def __init__(self, local_rank):
+ super(SyntheticDataset, self).__init__()
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).squeeze(0).float()
+ img = ((img / 255) - 0.5) / 0.5
+ self.img = img
+ self.label = 1
+
+ def __getitem__(self, index):
+ return self.img, self.label
+
+ def __len__(self):
+ return 1000000
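+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only; assumes a CUDA device is available and
+    # that local_rank 0 maps to it): iterate one prefetched batch of synthetic data.
+    loader = DataLoaderX(local_rank=0, dataset=SyntheticDataset(local_rank=0),
+                         batch_size=8, num_workers=0)
+    for img, label in loader:
+        print(img.shape, label)
+        break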
diff --git a/src/face3d/models/arcface_torch/docs/eval.md b/src/face3d/models/arcface_torch/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..dd1d9e257367b6422680966198646c45e5a2671d
--- /dev/null
+++ b/src/face3d/models/arcface_torch/docs/eval.md
@@ -0,0 +1,31 @@
+## Eval on ICCV2021-MFR
+
+Coming soon.
+
+
+## Eval IJBC
+You can evaluate IJB-C with either PyTorch or ONNX.
+
+
+1. Eval IJB-C with ONNX
+```shell
+CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
+```
+
+2. Eval IJB-C with PyTorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
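+
+Under the hood this is roughly what `inference.py` does (sketch only; the image path is a placeholder):
+
+```python
+import cv2
+import numpy as np
+import torch
+from backbones import get_model
+
+net = get_model("r50", fp16=False)
+net.load_state_dict(torch.load("ms1mv3_arcface_r50/backbone.pth"))
+net.eval()
+
+img = cv2.cvtColor(cv2.resize(cv2.imread("face.jpg"), (112, 112)), cv2.COLOR_BGR2RGB)
+img = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0).float()
+img.div_(255).sub_(0.5).div_(0.5)  # normalize to [-1, 1]
+with torch.no_grad():
+    embedding = net(img).numpy()   # 512-D face embedding
+```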
diff --git a/src/face3d/models/arcface_torch/docs/install.md b/src/face3d/models/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..6314a40441285e9236438e468caf8b71a407531a
--- /dev/null
+++ b/src/face3d/models/arcface_torch/docs/install.md
@@ -0,0 +1,51 @@
+## v1.8.0
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
+
+# CPU only
+pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+```
+
+
+## v1.7.1
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
+
+# CUDA 10.1
+pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+
+## v1.6.0
+
+### Linux and Windows
+```shell
+# CUDA 10.2
+pip install torch==1.6.0 torchvision==0.7.0
+
+# CUDA 10.1
+pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+```
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/docs/modelzoo.md b/src/face3d/models/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/docs/speed_benchmark.md b/src/face3d/models/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..055aee0defe2c43a523ced48260242f0f99b7cea
--- /dev/null
+++ b/src/face3d/models/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+Use the following two commands to test Partial FC training performance.
+The number of identities is **3 million** (synthetic data), mixed-precision training is enabled, the backbone is ResNet-50,
+and the batch size is 1024.
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │·······················
+[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │·······················
+[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │·······················
+[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │·······················
+[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │·······················
+[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │·······················
+[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │·······················
+[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │·······················
+```
+
+- Training Speed
+
+```python
+# (Model Parallel) training.log
+Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+
+# (Partial FC 0.1) training.log
+Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+```
+
+In this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory of model parallel,
+and trains roughly 2.3 times faster (about 5300 vs. 2270 samples/sec in the logs above).
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|250000 | 4047 | 4521 | 4976 |
+|500000 | 3087 | 4013 | 4900 |
+|1000000 | 2090 | 3449 | 4803 |
+|1400000 | 1672 | 3043 | 4738 |
+|2000000 | - | 2593 | 4626 |
+|4000000 | - | 1748 | 4208 |
+|5500000 | - | 1389 | 3975 |
+|8000000 | - | - | 3565 |
+|16000000 | - | - | 2679 |
+|29000000 | - | - | 1855 |
+
+2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|250000 | 9940 | 5826 | 5004 |
+|500000 | 14220 | 7114 | 5202 |
+|1000000 | 23708 | 9966 | 5620 |
+|1400000 | 32252 | 11178 | 6056 |
+|2000000 | - | 13978 | 6472 |
+|4000000 | - | 23238 | 8284 |
+|5500000 | - | 32188 | 9854 |
+|8000000 | - | - | 12310 |
+|16000000 | - | - | 19950 |
+|29000000 | - | - | 32324 |
diff --git a/src/face3d/models/arcface_torch/eval/__init__.py b/src/face3d/models/arcface_torch/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/eval/verification.py b/src/face3d/models/arcface_torch/eval/verification.py
new file mode 100644
index 0000000000000000000000000000000000000000..253343b83dbf9d1bd154d14ec068e098bf0968db
--- /dev/null
+++ b/src/face3d/models/arcface_torch/eval/verification.py
@@ -0,0 +1,407 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset
+"""
+
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+ def __init__(self, n_splits=2, shuffle=False):
+ self.n_splits = n_splits
+ if self.n_splits > 1:
+ self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+ def split(self, indices):
+ if self.n_splits > 1:
+ return self.k_fold.split(indices)
+ else:
+ return [(indices, indices)]
+
+
+def calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ nrof_folds=10,
+ pca=0):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ tprs = np.zeros((nrof_folds, nrof_thresholds))
+ fprs = np.zeros((nrof_folds, nrof_thresholds))
+ accuracy = np.zeros((nrof_folds))
+ indices = np.arange(nrof_pairs)
+
+ if pca == 0:
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+ if pca > 0:
+ print('doing pca on', fold_idx)
+ embed1_train = embeddings1[train_set]
+ embed2_train = embeddings2[train_set]
+ _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+ pca_model = PCA(n_components=pca)
+ pca_model.fit(_embed_train)
+ embed1 = pca_model.transform(embeddings1)
+ embed2 = pca_model.transform(embeddings2)
+ embed1 = sklearn.preprocessing.normalize(embed1)
+ embed2 = sklearn.preprocessing.normalize(embed2)
+ diff = np.subtract(embed1, embed2)
+ dist = np.sum(np.square(diff), 1)
+
+ # Find the best threshold for the fold
+ acc_train = np.zeros((nrof_thresholds))
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, _, acc_train[threshold_idx] = calculate_accuracy(
+ threshold, dist[train_set], actual_issame[train_set])
+ best_threshold_index = np.argmax(acc_train)
+ for threshold_idx, threshold in enumerate(thresholds):
+ tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+ threshold, dist[test_set],
+ actual_issame[test_set])
+ _, _, accuracy[fold_idx] = calculate_accuracy(
+ thresholds[best_threshold_index], dist[test_set],
+ actual_issame[test_set])
+
+ tpr = np.mean(tprs, 0)
+ fpr = np.mean(fprs, 0)
+ return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ tp = np.sum(np.logical_and(predict_issame, actual_issame))
+ fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ tn = np.sum(
+ np.logical_and(np.logical_not(predict_issame),
+ np.logical_not(actual_issame)))
+ fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+ tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+ fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+ acc = float(tp + tn) / dist.size
+ return tpr, fpr, acc
+
+
+def calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ far_target,
+ nrof_folds=10):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ val = np.zeros(nrof_folds)
+ far = np.zeros(nrof_folds)
+
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+ indices = np.arange(nrof_pairs)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+ # Find the threshold that gives FAR = far_target
+ far_train = np.zeros(nrof_thresholds)
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, far_train[threshold_idx] = calculate_val_far(
+ threshold, dist[train_set], actual_issame[train_set])
+ if np.max(far_train) >= far_target:
+ f = interpolate.interp1d(far_train, thresholds, kind='slinear')
+ threshold = f(far_target)
+ else:
+ threshold = 0.0
+
+ val[fold_idx], far[fold_idx] = calculate_val_far(
+ threshold, dist[test_set], actual_issame[test_set])
+
+ val_mean = np.mean(val)
+ far_mean = np.mean(far)
+ val_std = np.std(val)
+ return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+ false_accept = np.sum(
+ np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ n_same = np.sum(actual_issame)
+ n_diff = np.sum(np.logical_not(actual_issame))
+ # print(true_accept, false_accept)
+ # print(n_same, n_diff)
+ val = float(true_accept) / float(n_same)
+ far = float(false_accept) / float(n_diff)
+ return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+ # Calculate evaluation metrics
+ thresholds = np.arange(0, 4, 0.01)
+ embeddings1 = embeddings[0::2]
+ embeddings2 = embeddings[1::2]
+ tpr, fpr, accuracy = calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ nrof_folds=nrof_folds,
+ pca=pca)
+ thresholds = np.arange(0, 4, 0.001)
+ val, val_std, far = calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ 1e-3,
+ nrof_folds=nrof_folds)
+ return tpr, fpr, accuracy, val, val_std, far
+
+@torch.no_grad()
+def load_bin(path, image_size):
+ try:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f) # py2
+ except UnicodeDecodeError as e:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f, encoding='bytes') # py3
+ data_list = []
+ for flip in [0, 1]:
+ data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+ data_list.append(data)
+ for idx in range(len(issame_list) * 2):
+ _bin = bins[idx]
+ img = mx.image.imdecode(_bin)
+ if img.shape[1] != image_size[0]:
+ img = mx.image.resize_short(img, image_size[0])
+ img = nd.transpose(img, axes=(2, 0, 1))
+ for flip in [0, 1]:
+ if flip == 1:
+ img = mx.ndarray.flip(data=img, axis=2)
+ data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+ if idx % 1000 == 0:
+ print('loading bin', idx)
+ print(data_list[0].shape)
+ return data_list, issame_list
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10):
+ print('testing verification..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+ _data = data[bb - batch_size: bb]
+ time0 = datetime.datetime.now()
+ img = ((_data / 255) - 0.5) / 0.5
+ net_out: torch.Tensor = backbone(img)
+ _embeddings = net_out.detach().cpu().numpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+
+ _xnorm = 0.0
+ _xnorm_cnt = 0
+ for embed in embeddings_list:
+ for i in range(embed.shape[0]):
+ _em = embed[i]
+ _norm = np.linalg.norm(_em)
+ _xnorm += _norm
+ _xnorm_cnt += 1
+ _xnorm /= _xnorm_cnt
+
+ acc1 = 0.0
+ std1 = 0.0
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ print(embeddings.shape)
+ print('infer time', time_consumed)
+ _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+ acc2, std2 = np.mean(accuracy), np.std(accuracy)
+ return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
+def dumpR(data_set,
+ backbone,
+ batch_size,
+ name='',
+ data_extra=None,
+ label_shape=None):
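+    # NOTE: kept from the original MXNet-based evaluation script; `_label`, `_data_extra`
+    # and `model` below are expected to be provided by that script and are not defined in this module.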
+ print('dump verification embedding..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+
+ _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+ time0 = datetime.datetime.now()
+ if data_extra is None:
+ db = mx.io.DataBatch(data=(_data,), label=(_label,))
+ else:
+ db = mx.io.DataBatch(data=(_data, _data_extra),
+ label=(_label,))
+ model.forward(db, is_train=False)
+ net_out = model.get_outputs()
+ _embeddings = net_out[0].asnumpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ actual_issame = np.asarray(issame_list)
+ outname = os.path.join('temp.bin')
+ with open(outname, 'wb') as f:
+ pickle.dump((embeddings, issame_list),
+ f,
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+# parser = argparse.ArgumentParser(description='do verification')
+# # general
+# parser.add_argument('--data-dir', default='', help='')
+# parser.add_argument('--model',
+# default='../model/softmax,50',
+# help='path to load model.')
+# parser.add_argument('--target',
+# default='lfw,cfp_ff,cfp_fp,agedb_30',
+# help='test targets.')
+# parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+# parser.add_argument('--batch-size', default=32, type=int, help='')
+# parser.add_argument('--max', default='', type=str, help='')
+# parser.add_argument('--mode', default=0, type=int, help='')
+# parser.add_argument('--nfolds', default=10, type=int, help='')
+# args = parser.parse_args()
+# image_size = [112, 112]
+# print('image_size', image_size)
+# ctx = mx.gpu(args.gpu)
+# nets = []
+# vec = args.model.split(',')
+# prefix = args.model.split(',')[0]
+# epochs = []
+# if len(vec) == 1:
+# pdir = os.path.dirname(prefix)
+# for fname in os.listdir(pdir):
+# if not fname.endswith('.params'):
+# continue
+# _file = os.path.join(pdir, fname)
+# if _file.startswith(prefix):
+# epoch = int(fname.split('.')[0].split('-')[1])
+# epochs.append(epoch)
+# epochs = sorted(epochs, reverse=True)
+# if len(args.max) > 0:
+# _max = [int(x) for x in args.max.split(',')]
+# assert len(_max) == 2
+# if len(epochs) > _max[1]:
+# epochs = epochs[_max[0]:_max[1]]
+#
+# else:
+# epochs = [int(x) for x in vec[1].split('|')]
+# print('model number', len(epochs))
+# time0 = datetime.datetime.now()
+# for epoch in epochs:
+# print('loading', prefix, epoch)
+# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+# all_layers = sym.get_internals()
+# sym = all_layers['fc1_output']
+# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+# image_size[1]))])
+# model.set_params(arg_params, aux_params)
+# nets.append(model)
+# time_now = datetime.datetime.now()
+# diff = time_now - time0
+# print('model loading time', diff.total_seconds())
+#
+# ver_list = []
+# ver_name_list = []
+# for name in args.target.split(','):
+# path = os.path.join(args.data_dir, name + ".bin")
+# if os.path.exists(path):
+# print('loading.. ', name)
+# data_set = load_bin(path, image_size)
+# ver_list.append(data_set)
+# ver_name_list.append(name)
+#
+# if args.mode == 0:
+# for i in range(len(ver_list)):
+# results = []
+# for model in nets:
+# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+# ver_list[i], model, args.batch_size, args.nfolds)
+# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+# results.append(acc2)
+# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+# elif args.mode == 1:
+# raise ValueError
+# else:
+# model = nets[0]
+# dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/src/face3d/models/arcface_torch/eval_ijbc.py b/src/face3d/models/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c5a650d486d18eb02d6f60d448fc3b315261f5d
--- /dev/null
+++ b/src/face3d/models/arcface_torch/eval_ijbc.py
@@ -0,0 +1,483 @@
+# coding: utf-8
+
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description='do ijb test')
+# general
+parser.add_argument('--model-prefix', default='', help='path to load model.')
+parser.add_argument('--image-path', default='', type=str, help='')
+parser.add_argument('--result-dir', default='.', type=str, help='')
+parser.add_argument('--batch-size', default=128, type=int, help='')
+parser.add_argument('--network', default='iresnet50', type=str, help='')
+parser.add_argument('--job', default='insightface', type=str, help='job name')
+parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True # if True, TestMode(N1)
+use_detector_score = True # if True, TestMode(D1)
+use_flip_test = True # if True, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+ def __init__(self, prefix, data_shape, batch_size=1):
+ image_size = (112, 112)
+ self.image_size = image_size
+ weight = torch.load(prefix)
+ resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+ resnet.load_state_dict(weight)
+ model = torch.nn.DataParallel(resnet)
+ self.model = model
+ self.model.eval()
+ src = np.array([
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]], dtype=np.float32)
+ src[:, 0] += 8.0
+ self.src = src
+ self.batch_size = batch_size
+ self.data_shape = data_shape
+
+ def get(self, rimg, landmark):
+
+ assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+ assert landmark.shape[1] == 2
+ if landmark.shape[0] == 68:
+ landmark5 = np.zeros((5, 2), dtype=np.float32)
+ landmark5[0] = (landmark[36] + landmark[39]) / 2
+ landmark5[1] = (landmark[42] + landmark[45]) / 2
+ landmark5[2] = landmark[30]
+ landmark5[3] = landmark[48]
+ landmark5[4] = landmark[54]
+ else:
+ landmark5 = landmark
+ tform = trans.SimilarityTransform()
+ tform.estimate(landmark5, self.src)
+ M = tform.params[0:2, :]
+ img = cv2.warpAffine(rimg,
+ M, (self.image_size[1], self.image_size[0]),
+ borderValue=0.0)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_flip = np.fliplr(img)
+ img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
+ img_flip = np.transpose(img_flip, (2, 0, 1))
+ input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+ input_blob[0] = img
+ input_blob[1] = img_flip
+ return input_blob
+
+ @torch.no_grad()
+ def forward_db(self, batch_data):
+ imgs = torch.Tensor(batch_data).cuda()
+ imgs.div_(255).sub_(0.5).div_(0.5)
+ feat = self.model(imgs)
+ feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+ return feat.cpu().numpy()
+
+
+# Split a list into n chunks as evenly as possible; if n is greater than len(listTemp), the extra chunks are empty lists.
+def divideIntoNstrand(listTemp, n):
+ twoList = [[] for i in range(n)]
+ for i, e in enumerate(listTemp):
+ twoList[i % n].append(e)
+ return twoList
+
+
+def read_template_media_list(path):
+ # ijb_meta = np.loadtxt(path, dtype=str)
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+    templates = ijb_meta[:, 1].astype(int)
+    medias = ijb_meta[:, 2].astype(int)
+ return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+ # pairs = np.loadtxt(path, dtype=str)
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ # print(pairs.shape)
+ # print(pairs[:, 0].astype(np.int))
+    t1 = pairs[:, 0].astype(int)
+    t2 = pairs[:, 1].astype(int)
+    label = pairs[:, 2].astype(int)
+ return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+ batch_size = args.batch_size
+ data_shape = (3, 112, 112)
+
+ files = files_list
+ print('files:', len(files))
+ rare_size = len(files) % batch_size
+ faceness_scores = []
+ batch = 0
+ img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+ batch_data = np.empty((2 * batch_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, batch_size)
+ for img_index, each_line in enumerate(files[:len(files) - rare_size]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+
+ batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
+ batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
+ if (img_index + 1) % batch_size == 0:
+ print('batch', batch)
+ img_feats[batch * batch_size:batch * batch_size +
+ batch_size][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+
+ batch_data = np.empty((2 * rare_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, rare_size)
+ for img_index, each_line in enumerate(files[len(files) - rare_size:]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+ batch_data[2 * img_index][:] = input_blob[0]
+ batch_data[2 * img_index + 1][:] = input_blob[1]
+ if (img_index + 1) % rare_size == 0:
+ print('batch', batch)
+ img_feats[len(files) -
+ rare_size:][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
+ # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
+ return img_feats, faceness_scores
+
+
+# In[ ]:
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ # ==========================================================
+ # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
+ # 2. compute media feature.
+ # 3. compute template feature.
+ # ==========================================================
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+
+ for count_template, uqt in enumerate(unique_templates):
+
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias,
+ return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [
+ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
+ ]
+ media_norm_feats = np.array(media_norm_feats)
+ # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
+ template_norm_feats = sklearn.preprocessing.normalize(template_feats)
+ # print(template_norm_feats.shape)
+ return template_norm_feats, unique_templates
+
+
+# In[ ]:
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ # ==========================================================
+ # Compute set-to-set Similarity Score.
+ # ==========================================================
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+
+ total_pairs = np.array(range(len(p1)))
+    batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limitation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+# In[ ]:
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+    batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limitation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def read_score(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# # Step 1: Load Meta Data
+
+# In[ ]:
+
+assert target == 'IJBC' or target == 'IJBB'
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id, mid --> media id
+# format:
+# image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_face_tid_mid.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id, label : 1/0
+# format:
+# tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 2: Get Image Features
+
+# In[ ]:
+
+# =============================================================
+# load image features
+# format:
+# img_feats: [image_num x feats_dim] (227630, 512)
+# =============================================================
+start = timeit.default_timer()
+img_path = '%s/loose_crop' % image_path
+img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
+img_list = open(img_list_path)
+files = img_list.readlines()
+# files_list = divideIntoNstrand(files, rank_size)
+files_list = files
+
+# img_feats
+# for i in range(rank_size):
+img_feats, faceness_scores = get_image_feature(img_path, files_list,
+ model_path, 0, gpu_id)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
+ img_feats.shape[1]))
+
+# # Step 3: Get Template Features
+
+# In[ ]:
+
+# =============================================================
+# compute template features from image features.
+# =============================================================
+start = timeit.default_timer()
+# ==========================================================
+# Norm feature before aggregation into template feature?
+# The feature norm from the embedding network and the faceness score can both down-weight noisy (non-face) samples.
+# ==========================================================
+# 1. FaceScore (Feature Norm)
+# 2. FaceScore (Detector)
+
+if use_flip_test:
+ # concat --- F1
+ # img_input_feats = img_feats
+ # add --- F2
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] //
+ 2] + img_feats[:, img_feats.shape[1] // 2:]
+else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+if use_norm_score:
+ img_input_feats = img_input_feats
+else:
+ # normalise features to remove norm information
+ img_input_feats = img_input_feats / np.sqrt(
+ np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+else:
+ img_input_feats = img_input_feats
+
+template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 4: Get Template Similarity Scores
+
+# In[ ]:
+
+# =============================================================
+# compute verification scores between template pairs.
+# =============================================================
+start = timeit.default_timer()
+score = verification(template_norm_feats, unique_templates, p1, p2)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+save_path = os.path.join(result_dir, args.job)
+# save_path = result_dir + '/%s_result' % target
+
+if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
+np.save(score_save_file, score)
+
+# # Step 5: Get ROC Curves and TPR@FPR Table
+
+# In[ ]:
+
+files = [score_save_file]
+methods = []
+scores = []
+for file in files:
+ methods.append(Path(file).stem)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
+print(tpr_fpr_table)
diff --git a/src/face3d/models/arcface_torch/inference.py b/src/face3d/models/arcface_torch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e5156e8d649954837e397c2ff15ec29995e7502
--- /dev/null
+++ b/src/face3d/models/arcface_torch/inference.py
@@ -0,0 +1,35 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+
+from backbones import get_model
+
+
+@torch.no_grad()
+def inference(weight, name, img):
+ if img is None:
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
+ else:
+ img = cv2.imread(img)
+ img = cv2.resize(img, (112, 112))
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+ img.div_(255).sub_(0.5).div_(0.5)
+ net = get_model(name, fp16=False)
+ net.load_state_dict(torch.load(weight))
+ net.eval()
+ feat = net(img).numpy()
+ print(feat)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+ parser.add_argument('--network', type=str, default='r50', help='backbone network')
+ parser.add_argument('--weight', type=str, default='')
+ parser.add_argument('--img', type=str, default=None)
+ args = parser.parse_args()
+ inference(args.weight, args.network, args.img)
diff --git a/src/face3d/models/arcface_torch/losses.py b/src/face3d/models/arcface_torch/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..87aeaa107af4d53f5a6132b3739d5cafdcded7fc
--- /dev/null
+++ b/src/face3d/models/arcface_torch/losses.py
@@ -0,0 +1,42 @@
+import torch
+from torch import nn
+
+
+def get_loss(name):
+ if name == "cosface":
+ return CosFace()
+ elif name == "arcface":
+ return ArcFace()
+ else:
+ raise ValueError()
+
+
+class CosFace(nn.Module):
+ def __init__(self, s=64.0, m=0.40):
+ super(CosFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, cosine, label):
+ index = torch.where(label != -1)[0]
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+ m_hot.scatter_(1, label[index, None], self.m)
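+        # subtract the margin m from the target-class cosine, then scale all logits by s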
+ cosine[index] -= m_hot
+ ret = cosine * self.s
+ return ret
+
+
+class ArcFace(nn.Module):
+ def __init__(self, s=64.0, m=0.5):
+ super(ArcFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, cosine: torch.Tensor, label):
+ index = torch.where(label != -1)[0]
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+ m_hot.scatter_(1, label[index, None], self.m)
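+        # work in angle space: add the margin m to the target-class angle, i.e. cos(theta + m), then scale by s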
+ cosine.acos_()
+ cosine[index] += m_hot
+ cosine.cos_().mul_(self.s)
+ return cosine
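+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative shapes only): apply the ArcFace margin to
+    # random cosine logits and feed the result to cross-entropy, as a trainer would.
+    cosine = torch.randn(4, 10).clamp(-1, 1)
+    label = torch.randint(0, 10, (4,))
+    logits = get_loss("arcface")(cosine, label)
+    print(nn.functional.cross_entropy(logits, label))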
diff --git a/src/face3d/models/arcface_torch/onnx_helper.py b/src/face3d/models/arcface_torch/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca922ca6d410655029e459cf8fd1c323d276c34c
--- /dev/null
+++ b/src/face3d/models/arcface_torch/onnx_helper.py
@@ -0,0 +1,250 @@
+from __future__ import division
+import datetime
+import os
+import os.path as osp
+import glob
+import numpy as np
+import cv2
+import sys
+import onnxruntime
+import onnx
+import argparse
+from onnx import numpy_helper
+from insightface.data import get_image
+
+class ArcFaceORT:
+ def __init__(self, model_path, cpu=False):
+ self.model_path = model_path
+ # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
+ self.providers = ['CPUExecutionProvider'] if cpu else None
+
+    # input_size is (w, h); returns an error message string on failure, or None on success
+ def check(self, track='cfat', test_img = None):
+ #default is cfat
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=15
+ if track.startswith('ms1m'):
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=10
+ elif track.startswith('glint'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=20
+ elif track.startswith('cfat'):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ elif track.startswith('unconstrained'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=30
+ else:
+ return "track not found"
+
+ if not os.path.exists(self.model_path):
+ return "model_path not exists"
+ if not os.path.isdir(self.model_path):
+ return "model_path should be directory"
+ onnx_files = []
+ for _file in os.listdir(self.model_path):
+ if _file.endswith('.onnx'):
+ onnx_files.append(osp.join(self.model_path, _file))
+ if len(onnx_files)==0:
+ return "do not have onnx files"
+ self.model_file = sorted(onnx_files)[-1]
+ print('use onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('input-shape:', input_shape)
+ if len(input_shape)!=4:
+ return "length of input_shape should be 4"
+ if not isinstance(input_shape[0], str):
+ #return "input_shape[0] should be str to support batch-inference"
+ print('reset input-shape[0] to None')
+ model = onnx.load(self.model_file)
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')
+ onnx.save(model, new_model_file)
+ self.model_file = new_model_file
+ print('use new onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('new-input-shape:', input_shape)
+
+ self.image_size = tuple(input_shape[2:4][::-1])
+ #print('image_size:', self.image_size)
+ input_name = input_cfg.name
+ outputs = session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ #print(o.name, o.shape)
+ if len(output_names)!=1:
+ return "number of output nodes should be 1"
+ self.session = session
+ self.input_name = input_name
+ self.output_names = output_names
+ #print(self.output_names)
+ model = onnx.load(self.model_file)
+ graph = model.graph
+ if len(graph.node)<8:
+ return "too small onnx graph"
+
+ input_size = (112,112)
+ self.crop = None
+ if track=='cfat':
+ crop_file = osp.join(self.model_path, 'crop.txt')
+ if osp.exists(crop_file):
+ lines = open(crop_file,'r').readlines()
+ if len(lines)!=6:
+ return "crop.txt should contain 6 lines"
+ lines = [int(x) for x in lines]
+ self.crop = lines[:4]
+ input_size = tuple(lines[4:6])
+ if input_size!=self.image_size:
+ return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size)
+
+ self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
+ if self.model_size_mb > max_model_size_mb:
+ return "max model size exceed, given %.3f-MB"%self.model_size_mb
+
+ input_mean = None
+ input_std = None
+ if track=='cfat':
+ pn_file = osp.join(self.model_path, 'pixel_norm.txt')
+ if osp.exists(pn_file):
+ lines = open(pn_file,'r').readlines()
+ if len(lines)!=2:
+ return "pixel_norm.txt should contain 2 lines"
+ input_mean = float(lines[0])
+ input_std = float(lines[1])
+ if input_mean is not None or input_std is not None:
+ if input_mean is None or input_std is None:
+ return "please set input_mean and input_std simultaneously"
+ else:
+ find_sub = False
+ find_mul = False
+ for nid, node in enumerate(graph.node[:8]):
+ print(nid, node.name)
+ if node.name.startswith('Sub') or node.name.startswith('_minus'):
+ find_sub = True
+ if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
+ find_mul = True
+ if find_sub and find_mul:
+ print("find sub and mul")
+ #mxnet arcface model
+ input_mean = 0.0
+ input_std = 1.0
+ else:
+ input_mean = 127.5
+ input_std = 127.5
+ self.input_mean = input_mean
+ self.input_std = input_std
+ for initn in graph.initializer:
+ weight_array = numpy_helper.to_array(initn)
+ dt = weight_array.dtype
+ if dt.itemsize<4:
+ return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
+ if test_img is None:
+ test_img = get_image('Tom_Hanks_54745')
+ test_img = cv2.resize(test_img, self.image_size)
+ else:
+ test_img = cv2.resize(test_img, self.image_size)
+ feat, cost = self.benchmark(test_img)
+ batch_result = self.check_batch(test_img)
+ batch_result_sum = float(np.sum(batch_result))
+ if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
+ print(batch_result)
+ print(batch_result_sum)
+ return "batch result output contains NaN!"
+
+ if len(feat.shape) < 2:
+ return "the shape of the feature must be two, but get {}".format(str(feat.shape))
+
+ if feat.shape[1] > max_feat_dim:
+ return "max feat dim exceed, given %d"%feat.shape[1]
+ self.feat_dim = feat.shape[1]
+ cost_ms = cost*1000
+ if cost_ms>max_time_cost:
+ return "max time cost exceed, given %.4f"%cost_ms
+ self.cost_ms = cost_ms
+ print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))
+ return None
+
+ def check_batch(self, img):
+ if not isinstance(img, list):
+ imgs = [img, ] * 32
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :]
+ if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
+ nimg = cv2.resize(nimg, self.image_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,
+ mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+
+ def meta_info(self):
+ return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}
+
+
+ def forward(self, imgs):
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ input_size = self.image_size
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ return net_out
+
+ def benchmark(self, img):
+ input_size = self.image_size
+ if self.crop is not None:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ img = nimg
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ costs = []
+ for _ in range(50):
+ ta = datetime.datetime.now()
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ tb = datetime.datetime.now()
+ cost = (tb-ta).total_seconds()
+ costs.append(cost)
+ costs = sorted(costs)
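+        # take a near-fastest run (6th of 50 sorted timings) as the representative latency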
+ cost = costs[5]
+ return net_out, cost
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='')
+ # general
+ parser.add_argument('workdir', help='submitted work dir', type=str)
+ parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')
+ args = parser.parse_args()
+ handler = ArcFaceORT(args.workdir)
+ err = handler.check(args.track)
+ print('err:', err)
diff --git a/src/face3d/models/arcface_torch/onnx_ijbc.py b/src/face3d/models/arcface_torch/onnx_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..05b50bfad4b4cf38903b89f596263a8e29a50d3e
--- /dev/null
+++ b/src/face3d/models/arcface_torch/onnx_ijbc.py
@@ -0,0 +1,267 @@
+import argparse
+import os
+import pickle
+import timeit
+
+import cv2
+import mxnet as mx
+import numpy as np
+import pandas as pd
+import prettytable
+import skimage.transform
+from sklearn.metrics import roc_curve
+from sklearn.preprocessing import normalize
+
+from onnx_helper import ArcFaceORT
+
+SRC = np.array(
+ [
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]]
+ , dtype=np.float32)
+SRC[:, 0] += 8.0
+
+
+class AlignedDataSet(mx.gluon.data.Dataset):
+ def __init__(self, root, lines, align=True):
+ self.lines = lines
+ self.root = root
+ self.align = align
+
+ def __len__(self):
+ return len(self.lines)
+
+ def __getitem__(self, idx):
+ each_line = self.lines[idx]
+ name_lmk_score = each_line.strip().split(' ')
+ name = os.path.join(self.root, name_lmk_score[0])
+ img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
+ landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
+ st = skimage.transform.SimilarityTransform()
+ st.estimate(landmark5, SRC)
+ img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
+ img_1 = np.expand_dims(img, 0)
+ img_2 = np.expand_dims(np.fliplr(img), 0)
+ output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
+ output = np.transpose(output, (0, 3, 1, 2))
+ output = mx.nd.array(output)
+ return output
+
+
+def extract(model_root, dataset):
+ model = ArcFaceORT(model_path=model_root)
+ model.check()
+ feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
+
+ def batchify_fn(data):
+ return mx.nd.concat(*data, dim=0)
+
+ data_loader = mx.gluon.data.DataLoader(
+ dataset, 128, last_batch='keep', num_workers=4,
+ thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
+ num_iter = 0
+ for batch in data_loader:
+ batch = batch.asnumpy()
+ batch = (batch - model.input_mean) / model.input_std
+ feat = model.session.run(model.output_names, {model.input_name: batch})[0]
+ feat = np.reshape(feat, (-1, model.feat_dim * 2))
+ feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
+ num_iter += 1
+ if num_iter % 50 == 0:
+ print(num_iter)
+ return feat_mat
+
+
+def read_template_media_list(path):
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+    templates = ijb_meta[:, 1].astype(int)
+    medias = ijb_meta[:, 2].astype(int)
+ return templates, medias
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+    t1 = pairs[:, 0].astype(int)
+    t2 = pairs[:, 1].astype(int)
+    label = pairs[:, 2].astype(int)
+ return t1, t2, label
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+def image2template_feature(img_feats=None,
+ templates=None,
+ medias=None):
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+ for count_template, uqt in enumerate(unique_templates):
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
+ media_norm_feats = np.array(media_norm_feats)
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ template_norm_feats = normalize(template_feats)
+ return template_norm_feats, unique_templates
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),))
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000
+ sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limitation
+ sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def main(args):
+ use_norm_score = True # if True, TestMode(N1)
+ use_detector_score = True # if True, TestMode(D1)
+ use_flip_test = True # if True, TestMode(F1)
+ assert args.target == 'IJBC' or args.target == 'IJBB'
+
+ start = timeit.default_timer()
+ templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % args.image_path,
+ '%s_template_pair_label.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ img_path = '%s/loose_crop' % args.image_path
+ img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
+ img_list = open(img_list_path)
+ files = img_list.readlines()
+ dataset = AlignedDataSet(root=img_path, lines=files, align=True)
+ img_feats = extract(args.model_root, dataset)
+
+ faceness_scores = []
+ for each_line in files:
+ name_lmk_score = each_line.split()
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
+ start = timeit.default_timer()
+
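+ # flip test: fuse the embeddings of the original and horizontally flipped crops by summation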
+ if use_flip_test:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
+ else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
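+ # keep the raw feature norm when use_norm_score is set; otherwise L2-normalize the features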
+ if use_norm_score:
+ img_input_feats = img_input_feats
+ else:
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+ if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+ else:
+ img_input_feats = img_input_feats
+
+ template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ score = verification(template_norm_feats, unique_templates, p1, p2)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
+ np.save(score_save_file, score)
+ files = [score_save_file]
+ methods = []
+ scores = []
+ for file in files:
+ methods.append(os.path.basename(file))
+ scores.append(np.load(file))
+ methods = np.array(methods)
+ scores = dict(zip(methods, scores))
+ x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+ tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
+ for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr)
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, args.target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+ print(tpr_fpr_table)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='do ijb test')
+ # general
+ parser.add_argument('--model-root', default='', help='path to load model.')
+ parser.add_argument('--image-path', default='', type=str, help='')
+ parser.add_argument('--result-dir', default='.', type=str, help='')
+ parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+ main(parser.parse_args())
diff --git a/src/face3d/models/arcface_torch/partial_fc.py b/src/face3d/models/arcface_torch/partial_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..17e2d25715d10ba446c957e1d2528b0687ed71d5
--- /dev/null
+++ b/src/face3d/models/arcface_torch/partial_fc.py
@@ -0,0 +1,222 @@
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from torch.nn import Module
+from torch.nn.functional import normalize, linear
+from torch.nn.parameter import Parameter
+
+
+class PartialFC(Module):
+ """
+ Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
+ Partial FC: Training 10 Million Identities on a Single Machine
+ See the original paper:
+ https://arxiv.org/abs/2010.05222
+ """
+
+ @torch.no_grad()
+ def __init__(self, rank, local_rank, world_size, batch_size, resume,
+ margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
+ """
+ rank: int
+ Unique process(GPU) ID from 0 to world_size - 1.
+ local_rank: int
+ Unique process(GPU) ID within the server from 0 to 7.
+ world_size: int
+ Number of GPU.
+ batch_size: int
+ Batch size on current rank(GPU).
+ resume: bool
+ Select whether to restore the weight of softmax.
+ margin_softmax: callable
+ A function of margin softmax, eg: cosface, arcface.
+ num_classes: int
+ The number of class centers stored on the current rank (CPU/GPU), usually total_classes // world_size.
+ Required.
+ sample_rate: float
+ The partial fc sampling rate. When the number of classes grows beyond 2 million, sampling
+ can greatly speed up training and save a large amount of GPU memory. Default is 1.0.
+ embedding_size: int
+ The feature dimension, default is 512.
+ prefix: str
+ Path for save checkpoint, default is './'.
+ """
+ super(PartialFC, self).__init__()
+ #
+ self.num_classes: int = num_classes
+ self.rank: int = rank
+ self.local_rank: int = local_rank
+ self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
+ self.world_size: int = world_size
+ self.batch_size: int = batch_size
+ self.margin_softmax: callable = margin_softmax
+ self.sample_rate: float = sample_rate
+ self.embedding_size: int = embedding_size
+ self.prefix: str = prefix
+ self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
+ self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+
+ self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
+ self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
+
+ if resume:
+ try:
+ self.weight: torch.Tensor = torch.load(self.weight_name)
+ self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
+ if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
+ raise IndexError
+ logging.info("softmax weight resume successfully!")
+ logging.info("softmax weight mom resume successfully!")
+ except (FileNotFoundError, KeyError, IndexError):
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+ logging.info("softmax weight init!")
+ logging.info("softmax weight mom init!")
+ else:
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+ logging.info("softmax weight init successfully!")
+ logging.info("softmax weight mom init successfully!")
+ self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
+
+ self.index = None
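+ # when sampling is disabled (sample_rate == 1) every local class center is used and update() becomes a no-op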
+ if int(self.sample_rate) == 1:
+ self.update = lambda: 0
+ self.sub_weight = Parameter(self.weight)
+ self.sub_weight_mom = self.weight_mom
+ else:
+ self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
+
+ def save_params(self):
+ """ Save softmax weight for each rank on prefix
+ """
+ torch.save(self.weight.data, self.weight_name)
+ torch.save(self.weight_mom, self.weight_mom_name)
+
+ @torch.no_grad()
+ def sample(self, total_label):
+ """
+ Sample all positive class centers in each rank, and randomly select negative class centers to fill a fixed
+ `num_sample`.
+
+ total_label: tensor
+ Label after all gather, which cross all GPUs.
+ """
+ index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
+ total_label[~index_positive] = -1
+ total_label[index_positive] -= self.class_start
+ if int(self.sample_rate) != 1:
+ positive = torch.unique(total_label[index_positive], sorted=True)
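+ # if there is room, keep every positive class (score 2.0 beats any random draw) and pad with random negative classes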
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local], device=self.device)
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1]
+ index = index.sort()[0]
+ else:
+ index = positive
+ self.index = index
+ total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
+ self.sub_weight = Parameter(self.weight[index])
+ self.sub_weight_mom = self.weight_mom[index]
+
+ def forward(self, total_features, norm_weight):
+ """ Partial fc forward, `logits = X * sample(W)`
+ """
+ torch.cuda.current_stream().wait_stream(self.stream)
+ logits = linear(total_features, norm_weight)
+ return logits
+
+ @torch.no_grad()
+ def update(self):
+ """ Set updated weight and weight_mom to memory bank.
+ """
+ self.weight_mom[self.index] = self.sub_weight_mom
+ self.weight[self.index] = self.sub_weight
+
+ def prepare(self, label, optimizer):
+ """
+ Get the sampled class centers used to compute the softmax.
+
+ label: tensor
+ Label tensor on each rank.
+ optimizer: opt
+ Optimizer for partial fc, which needs access to the weight momentum.
+ """
+ with torch.cuda.stream(self.stream):
+ total_label = torch.zeros(
+ size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
+ dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
+ self.sample(total_label)
+ optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
+ optimizer.param_groups[-1]['params'][0] = self.sub_weight
+ optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
+ norm_weight = normalize(self.sub_weight)
+ return total_label, norm_weight
+
+ def forward_backward(self, label, features, optimizer):
+ """
+ Partial fc forward and backward with model parallel
+
+ label: tensor
+ Label tensor on each rank(GPU)
+ features: tensor
+ Features tensor on each rank(GPU)
+ optimizer: optimizer
+ Optimizer for partial fc
+
+ Returns:
+ --------
+ x_grad: tensor
+ The gradient of features.
+ loss_v: tensor
+ Loss value for cross entropy.
+ """
+ total_label, norm_weight = self.prepare(label, optimizer)
+ total_features = torch.zeros(
+ size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
+ dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
+ total_features.requires_grad = True
+
+ logits = self.forward(total_features, norm_weight)
+ logits = self.margin_softmax(logits, total_label)
+
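+ # softmax cross-entropy is computed manually across ranks: subtract the global (all-reduced) max logit for numerical stability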
+ with torch.no_grad():
+ max_fc = torch.max(logits, dim=1, keepdim=True)[0]
+ dist.all_reduce(max_fc, dist.ReduceOp.MAX)
+
+ # calculate exp(logits) and all-reduce
+ logits_exp = torch.exp(logits - max_fc)
+ logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
+ dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
+
+ # calculate prob
+ logits_exp.div_(logits_sum_exp)
+
+ # get one-hot
+ grad = logits_exp
+ index = torch.where(total_label != -1)[0]
+ one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
+ one_hot.scatter_(1, total_label[index, None], 1)
+
+ # calculate loss
+ loss = torch.zeros(grad.size()[0], 1, device=grad.device)
+ loss[index] = grad[index].gather(1, total_label[index, None])
+ dist.all_reduce(loss, dist.ReduceOp.SUM)
+ loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ # calculate grad
+ grad[index] -= one_hot
+ grad.div_(self.batch_size * self.world_size)
+
+ logits.backward(grad)
+ if total_features.grad is not None:
+ total_features.grad.detach_()
+ x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
+ # feature gradient all-reduce
+ dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
+ x_grad = x_grad * self.world_size
+ # backward backbone
+ return x_grad, loss_v
diff --git a/src/face3d/models/arcface_torch/requirement.txt b/src/face3d/models/arcface_torch/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f72c1b3ba814ae1e0bc1c1f56402026978b9e870
--- /dev/null
+++ b/src/face3d/models/arcface_torch/requirement.txt
@@ -0,0 +1,5 @@
+tensorboard
+easydict
+mxnet
+onnx
+scikit-learn
diff --git a/src/face3d/models/arcface_torch/run.sh b/src/face3d/models/arcface_torch/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..61af4b4950eb11334e55362e3e3c5e2796979a01
--- /dev/null
+++ b/src/face3d/models/arcface_torch/run.sh
@@ -0,0 +1,2 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
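+# kill any training processes that are still alive once the launcher returns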
+ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
diff --git a/src/face3d/models/arcface_torch/torch2onnx.py b/src/face3d/models/arcface_torch/torch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc26ab82e552331bc8d75b34e81000418f4d38ec
--- /dev/null
+++ b/src/face3d/models/arcface_torch/torch2onnx.py
@@ -0,0 +1,59 @@
+import numpy as np
+import onnx
+import torch
+
+
+def convert_onnx(net, path_module, output, opset=11, simplify=False):
+ assert isinstance(net, torch.nn.Module)
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = img.astype(np.float64)
+ img = (img / 255. - 0.5) / 0.5 # torch style norm
+ img = img.transpose((2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+
+ weight = torch.load(path_module)
+ net.load_state_dict(weight)
+ net.eval()
+ torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
+ model = onnx.load(output)
+ graph = model.graph
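+ # make the batch dimension symbolic so the exported model accepts any batch size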
+ graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ if simplify:
+ from onnxsim import simplify
+ model, check = simplify(model)
+ assert check, "Simplified ONNX model could not be validated"
+ onnx.save(model, output)
+
+
+if __name__ == '__main__':
+ import os
+ import argparse
+ from backbones import get_model
+
+ parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
+ parser.add_argument('input', type=str, help='input backbone.pth file or path')
+ parser.add_argument('--output', type=str, default=None, help='output onnx path')
+ parser.add_argument('--network', type=str, default=None, help='backbone network')
+ parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')
+ args = parser.parse_args()
+ input_file = args.input
+ if os.path.isdir(input_file):
+ input_file = os.path.join(input_file, "backbone.pth")
+ assert os.path.exists(input_file)
+ model_name = os.path.basename(os.path.dirname(input_file)).lower()
+ params = model_name.split("_")
+ if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+ if args.network is None:
+ args.network = params[2]
+ assert args.network is not None
+ print(args)
+ backbone_onnx = get_model(args.network, dropout=0)
+
+ output_path = args.output
+ if output_path is None:
+ output_path = os.path.join(os.path.dirname(__file__), 'onnx')
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+ assert os.path.isdir(output_path)
+ output_file = os.path.join(output_path, "%s.onnx" % model_name)
+ convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
diff --git a/src/face3d/models/arcface_torch/train.py b/src/face3d/models/arcface_torch/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..55eca2d0ad9463415970e09bccab8b722e496704
--- /dev/null
+++ b/src/face3d/models/arcface_torch/train.py
@@ -0,0 +1,141 @@
+import argparse
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torch.utils.data.distributed
+from torch.nn.utils import clip_grad_norm_
+
+import losses
+from backbones import get_model
+from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
+from partial_fc import PartialFC
+from utils.utils_amp import MaxClipGradScaler
+from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
+from utils.utils_config import get_config
+from utils.utils_logging import AverageMeter, init_logging
+
+
+def main(args):
+ cfg = get_config(args.config)
+ try:
+ world_size = int(os.environ['WORLD_SIZE'])
+ rank = int(os.environ['RANK'])
+ dist.init_process_group('nccl')
+ except KeyError:
+ world_size = 1
+ rank = 0
+ dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
+
+ local_rank = args.local_rank
+ torch.cuda.set_device(local_rank)
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ if cfg.rec == "synthetic":
+ train_set = SyntheticDataset(local_rank=local_rank)
+ else:
+ train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
+
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
+ train_loader = DataLoaderX(
+ local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
+ sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
+ backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
+
+ if cfg.resume:
+ try:
+ backbone_pth = os.path.join(cfg.output, "backbone.pth")
+ backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
+ if rank == 0:
+ logging.info("backbone resume successfully!")
+ except (FileNotFoundError, KeyError, IndexError, RuntimeError):
+ if rank == 0:
+ logging.info("resume fail, backbone init successfully!")
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank])
+ backbone.train()
+ margin_softmax = losses.get_loss(cfg.loss)
+ module_partial_fc = PartialFC(
+ rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
+ batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
+ sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
+
+ opt_backbone = torch.optim.SGD(
+ params=[{'params': backbone.parameters()}],
+ lr=cfg.lr / 512 * cfg.batch_size * world_size,
+ momentum=0.9, weight_decay=cfg.weight_decay)
+ opt_pfc = torch.optim.SGD(
+ params=[{'params': module_partial_fc.parameters()}],
+ lr=cfg.lr / 512 * cfg.batch_size * world_size,
+ momentum=0.9, weight_decay=cfg.weight_decay)
+
+ num_image = len(train_set)
+ total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
+ cfg.total_step = num_image // total_batch_size * cfg.num_epoch
+
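+ # learning-rate schedule: linear warmup for warmup_step steps, then a 10x decay after each decay epoch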
+ def lr_step_func(current_step):
+ cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
+ if current_step < cfg.warmup_step:
+ return current_step / cfg.warmup_step
+ else:
+ return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])
+
+ scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
+ optimizer=opt_backbone, lr_lambda=lr_step_func)
+ scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
+ optimizer=opt_pfc, lr_lambda=lr_step_func)
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ val_target = cfg.val_targets
+ callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
+ callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
+ callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
+
+ loss = AverageMeter()
+ start_epoch = 0
+ global_step = 0
+ grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
+ for epoch in range(start_epoch, cfg.num_epoch):
+ train_sampler.set_epoch(epoch)
+ for step, (img, label) in enumerate(train_loader):
+ global_step += 1
+ features = F.normalize(backbone(img))
+ x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
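+ # backpropagate the returned feature gradient through the backbone (scaled first when fp16 is enabled)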
+ if cfg.fp16:
+ features.backward(grad_amp.scale(x_grad))
+ grad_amp.unscale_(opt_backbone)
+ clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+ grad_amp.step(opt_backbone)
+ grad_amp.update()
+ else:
+ features.backward(x_grad)
+ clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+ opt_backbone.step()
+
+ opt_pfc.step()
+ module_partial_fc.update()
+ opt_backbone.zero_grad()
+ opt_pfc.zero_grad()
+ loss.update(loss_v, 1)
+ callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
+ callback_verification(global_step, backbone)
+ scheduler_backbone.step()
+ scheduler_pfc.step()
+ callback_checkpoint(global_step, backbone, module_partial_fc)
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+ parser.add_argument('config', type=str, help='py config file')
+ parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
+ main(parser.parse_args())
diff --git a/src/face3d/models/arcface_torch/utils/__init__.py b/src/face3d/models/arcface_torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/utils/plot.py b/src/face3d/models/arcface_torch/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccc588e5c01ca550b69c385aeb3fd139c59fb88a
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/plot.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import roc_curve, auc
+
+image_path = "/data/anxiang/IJB_release/IJBC"
+files = [
+ "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
+]
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ t1 = pairs[:, 0].astype(int)
+ t2 = pairs[:, 1].astype(int)
+ label = pairs[:, 2].astype(int)
+ return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % 'ijbc'))
+
+methods = []
+scores = []
+for file in files:
+ methods.append(file.split('/')[-2])
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/src/face3d/models/arcface_torch/utils/utils_amp.py b/src/face3d/models/arcface_torch/utils/utils_amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ac2a03f4212faa129faed447a8f4519c0a00a8b
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_amp.py
@@ -0,0 +1,88 @@
+from typing import Dict, List
+
+import torch
+
+if tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) < (1, 9):
+ Iterable = torch._six.container_abcs.Iterable
+else:
+ import collections
+
+ Iterable = collections.abc.Iterable
+from torch.cuda.amp import GradScaler
+
+
+class _MultiDeviceReplicator(object):
+ """
+ Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
+ """
+
+ def __init__(self, master_tensor: torch.Tensor) -> None:
+ assert master_tensor.is_cuda
+ self.master = master_tensor
+ self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+ def get(self, device) -> torch.Tensor:
+ retval = self._per_device_tensors.get(device, None)
+ if retval is None:
+ retval = self.master.to(device=device, non_blocking=True, copy=True)
+ self._per_device_tensors[device] = retval
+ return retval
+
+
+class MaxClipGradScaler(GradScaler):
+ def __init__(self, init_scale, max_scale: float, growth_interval=100):
+ GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
+ self.max_scale = max_scale
+
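+ # clamp the dynamic loss scale so it never grows beyond max_scale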
+ def scale_clip(self):
+ if self.get_scale() == self.max_scale:
+ self.set_growth_factor(1)
+ elif self.get_scale() < self.max_scale:
+ self.set_growth_factor(2)
+ elif self.get_scale() > self.max_scale:
+ self._scale.fill_(self.max_scale)
+ self.set_growth_factor(1)
+
+ def scale(self, outputs):
+ """
+ Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+ Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
+ unmodified.
+
+ Arguments:
+ outputs (Tensor or iterable of Tensors): Outputs to scale.
+ """
+ if not self._enabled:
+ return outputs
+ self.scale_clip()
+ # Short-circuit for the common case.
+ if isinstance(outputs, torch.Tensor):
+ assert outputs.is_cuda
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(outputs.device)
+ assert self._scale is not None
+ return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+ # Invoke the more complex machinery only if we're treating multiple outputs.
+ stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
+
+ def apply_scale(val):
+ if isinstance(val, torch.Tensor):
+ assert val.is_cuda
+ if len(stash) == 0:
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(val.device)
+ assert self._scale is not None
+ stash.append(_MultiDeviceReplicator(self._scale))
+ return val * stash[0].get(val.device)
+ elif isinstance(val, Iterable):
+ iterable = map(apply_scale, val)
+ if isinstance(val, list) or isinstance(val, tuple):
+ return type(val)(iterable)
+ else:
+ return iterable
+ else:
+ raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+ return apply_scale(outputs)
diff --git a/src/face3d/models/arcface_torch/utils/utils_callbacks.py b/src/face3d/models/arcface_torch/utils/utils_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd2f56cba47c57de102710ff56eaac591e59f4da
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_callbacks.py
@@ -0,0 +1,117 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+from eval import verification
+from utils.utils_logging import AverageMeter
+
+
+class CallBackVerification(object):
+ def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
+ self.frequent: int = frequent
+ self.rank: int = rank
+ self.highest_acc: float = 0.0
+ self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+ self.ver_list: List[object] = []
+ self.ver_name_list: List[str] = []
+ if self.rank == 0:
+ self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+ def ver_test(self, backbone: torch.nn.Module, global_step: int):
+ results = []
+ for i in range(len(self.ver_list)):
+ acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
+ self.ver_list[i], backbone, 10, 10)
+ logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+ logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+ if acc2 > self.highest_acc_list[i]:
+ self.highest_acc_list[i] = acc2
+ logging.info(
+ '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+ results.append(acc2)
+
+ def init_dataset(self, val_targets, data_dir, image_size):
+ for name in val_targets:
+ path = os.path.join(data_dir, name + ".bin")
+ if os.path.exists(path):
+ data_set = verification.load_bin(path, image_size)
+ self.ver_list.append(data_set)
+ self.ver_name_list.append(name)
+
+ def __call__(self, num_update, backbone: torch.nn.Module):
+ if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
+ backbone.eval()
+ self.ver_test(backbone, num_update)
+ backbone.train()
+
+
+class CallBackLogging(object):
+ def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+ self.frequent: int = frequent
+ self.rank: int = rank
+ self.time_start = time.time()
+ self.total_step: int = total_step
+ self.batch_size: int = batch_size
+ self.world_size: int = world_size
+ self.writer = writer
+
+ self.init = False
+ self.tic = 0
+
+ def __call__(self,
+ global_step: int,
+ loss: AverageMeter,
+ epoch: int,
+ fp16: bool,
+ learning_rate: float,
+ grad_scaler: torch.cuda.amp.GradScaler):
+ if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+ if self.init:
+ try:
+ speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+ speed_total = speed * self.world_size
+ except ZeroDivisionError:
+ speed_total = float('inf')
+
+ time_now = (time.time() - self.time_start) / 3600
+ time_total = time_now / ((global_step + 1) / self.total_step)
+ time_for_end = time_total - time_now
+ if self.writer is not None:
+ self.writer.add_scalar('time_for_end', time_for_end, global_step)
+ self.writer.add_scalar('learning_rate', learning_rate, global_step)
+ self.writer.add_scalar('loss', loss.avg, global_step)
+ if fp16:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
+ "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step,
+ grad_scaler.get_scale(), time_for_end
+ )
+ else:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
+ "Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
+ )
+ logging.info(msg)
+ loss.reset()
+ self.tic = time.time()
+ else:
+ self.init = True
+ self.tic = time.time()
+
+
+class CallBackModelCheckpoint(object):
+ def __init__(self, rank, output="./"):
+ self.rank: int = rank
+ self.output: str = output
+
+ def __call__(self, global_step, backbone, partial_fc, ):
+ if global_step > 100 and self.rank == 0:
+ path_module = os.path.join(self.output, "backbone.pth")
+ torch.save(backbone.module.state_dict(), path_module)
+ logging.info("Pytorch Model Saved in '{}'".format(path_module))
+
+ if global_step > 100 and partial_fc is not None:
+ partial_fc.save_params()
diff --git a/src/face3d/models/arcface_torch/utils/utils_config.py b/src/face3d/models/arcface_torch/utils/utils_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c02eaf70fc0140aca7925f621c29a496f491cae
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_config.py
@@ -0,0 +1,16 @@
+import importlib
+import os.path as osp
+
+
+def get_config(config_file):
+ assert config_file.startswith('configs/'), 'config file setting must start with configs/'
+ temp_config_name = osp.basename(config_file)
+ temp_module_name = osp.splitext(temp_config_name)[0]
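+ # start from the defaults in configs/base.py, then override them with the job-specific config module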
+ config = importlib.import_module("configs.base")
+ cfg = config.config
+ config = importlib.import_module("configs.%s" % temp_module_name)
+ job_cfg = config.config
+ cfg.update(job_cfg)
+ if cfg.output is None:
+ cfg.output = osp.join('work_dirs', temp_module_name)
+ return cfg
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/utils/utils_logging.py b/src/face3d/models/arcface_torch/utils/utils_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..c787b6aae7cd037a4718df44d672b8ffa9e5c249
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_logging.py
@@ -0,0 +1,41 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value
+ """
+
+ def __init__(self):
+ self.val = None
+ self.avg = None
+ self.sum = None
+ self.count = None
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def init_logging(rank, models_root):
+ if rank == 0:
+ log_root = logging.getLogger()
+ log_root.setLevel(logging.INFO)
+ formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+ handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+ handler_stream = logging.StreamHandler(sys.stdout)
+ handler_file.setFormatter(formatter)
+ handler_stream.setFormatter(formatter)
+ log_root.addHandler(handler_file)
+ log_root.addHandler(handler_stream)
+ log_root.info('rank_id: %d' % rank)
diff --git a/src/face3d/models/arcface_torch/utils/utils_os.py b/src/face3d/models/arcface_torch/utils/utils_os.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/base_model.py b/src/face3d/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfe64a7f739ad8f8cfbf3073a2bf49e1468127fd
--- /dev/null
+++ b/src/face3d/models/base_model.py
@@ -0,0 +1,316 @@
+"""This script defines the base network model for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this function, you should first call <BaseModel.__init__(self, opt)>.
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.isTrain = False
+ self.device = torch.device('cpu')
+ self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.parallel_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def dict_grad_hook_factory(add_func=lambda x: x):
+ saved_dict = dict()
+
+ def hook_gen(name):
+ def grad_hook(grad):
+ saved_vals = add_func(grad)
+ saved_dict[name] = saved_vals
+ return grad_hook
+ return hook_gen, saved_dict
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+
+ if not self.isTrain or opt.continue_train:
+ load_suffix = opt.epoch
+ self.load_networks(load_suffix)
+
+
+ # self.print_networks(opt.verbose)
+
+ def parallelize(self, convert_sync_batchnorm=True):
+ if not self.opt.use_ddp:
+ for name in self.parallel_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+ else:
+ for name in self.model_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ if convert_sync_batchnorm:
+ module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
+ setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
+ device_ids=[self.device.index],
+ find_unused_parameters=True, broadcast_buffers=True))
+
+ # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
+ for name in self.parallel_names:
+ if isinstance(name, str) and name not in self.model_names:
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+
+ # put state_dict of optimizer to gpu device
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ for optim in self.optimizers:
+ for state in optim.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(self.device)
+
+ def data_dependent_initialize(self, data):
+ pass
+
+ def train(self):
+ """Make models train mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.train()
+
+ def eval(self):
+ """Make models eval mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
+ It also calls <compute_visuals> to produce additional visualization results.
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self, name='A'):
+ """ Return image paths that are used to load current data"""
+ return self.image_paths if name =='A' else self.image_paths_B
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == 'plateau':
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)[:, :3, ...]
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if not os.path.isdir(self.save_dir):
+ os.makedirs(self.save_dir)
+
+ save_filename = 'epoch_%s.pth' % (epoch)
+ save_path = os.path.join(self.save_dir, save_filename)
+
+ save_dict = {}
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel) or isinstance(net,
+ torch.nn.parallel.DistributedDataParallel):
+ net = net.module
+ save_dict[name] = net.state_dict()
+
+
+ for i, optim in enumerate(self.optimizers):
+ save_dict['opt_%02d'%i] = optim.state_dict()
+
+ for i, sched in enumerate(self.schedulers):
+ save_dict['sched_%02d'%i] = sched.state_dict()
+
+ torch.save(save_dict, save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+ else:
+ load_dir = self.save_dir
+ load_filename = 'epoch_%s.pth' % (epoch)
+ load_path = os.path.join(load_dir, load_filename)
+ state_dict = torch.load(load_path, map_location=self.device)
+ print('loading the model from %s' % load_path)
+
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ net.load_state_dict(state_dict[name])
+
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ print('loading the optim from %s' % load_path)
+ for i, optim in enumerate(self.optimizers):
+ optim.load_state_dict(state_dict['opt_%02d'%i])
+
+ try:
+ print('loading the sched from %s' % load_path)
+ for i, sched in enumerate(self.schedulers):
+ sched.load_state_dict(state_dict['sched_%02d'%i])
+ except:
+ print('Failed to load schedulers, set schedulers according to epoch count manually')
+ for i, sched in enumerate(self.schedulers):
+ sched.last_epoch = self.opt.epoch_count - 1
+
+
+
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+ def generate_visuals_for_evaluation(self, data, mode):
+ return {}
diff --git a/src/face3d/models/bfm.py b/src/face3d/models/bfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a75db682f02dd1979d4a7de1d11dd3aa5cdf5279
--- /dev/null
+++ b/src/face3d/models/bfm.py
@@ -0,0 +1,331 @@
+"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+from src.face3d.util.load_mats import transferBFM09
+import os
+
+def perspective_projection(focal, center):
+ # return p.T (N, 3) @ (3, 3)
+ return np.array([
+ focal, 0, center,
+ 0, focal, center,
+ 0, 0, 1
+ ]).reshape([3, 3]).astype(np.float32).transpose()
+
+class SH:
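+ # 'a' and 'c' are the standard constants of the first three spherical-harmonics bands used by the lighting model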
+ def __init__(self):
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
+ self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
+
+
+
+class ParametricFaceModel:
+ def __init__(self,
+ bfm_folder='./BFM',
+ recenter=True,
+ camera_distance=10.,
+ init_lit=np.array([
+ 0.8, 0, 0, 0, 0, 0, 0, 0, 0
+ ]),
+ focal=1015.,
+ center=112.,
+ is_train=True,
+ default_name='BFM_model_front.mat'):
+
+ if not os.path.isfile(os.path.join(bfm_folder, default_name)):
+ transferBFM09(bfm_folder)
+
+ model = loadmat(os.path.join(bfm_folder, default_name))
+ # mean face shape. [3*N,1]
+ self.mean_shape = model['meanshape'].astype(np.float32)
+ # identity basis. [3*N,80]
+ self.id_base = model['idBase'].astype(np.float32)
+ # expression basis. [3*N,64]
+ self.exp_base = model['exBase'].astype(np.float32)
+ # mean face texture. [3*N,1] (0-255)
+ self.mean_tex = model['meantex'].astype(np.float32)
+ # texture basis. [3*N,80]
+ self.tex_base = model['texBase'].astype(np.float32)
+ # indices of the faces adjacent to each vertex. starts from 0. [N,8]
+ self.point_buf = model['point_buf'].astype(np.int64) - 1
+ # vertex indices for each face. starts from 0. [F,3]
+ self.face_buf = model['tri'].astype(np.int64) - 1
+ # vertex indices for 68 landmarks. starts from 0. [68,1]
+ self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
+
+ if is_train:
+ # vertex indices for small face region to compute photometric error. starts from 0.
+ self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
+ # vertex indices for each face from small face region. starts from 0. [f,3]
+ self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
+ # vertex indices for pre-defined skin region to compute reflectance loss
+ self.skin_mask = np.squeeze(model['skinmask'])
+
+ if recenter:
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+
+ self.persc_proj = perspective_projection(focal, center)
+ self.device = 'cpu'
+ self.camera_distance = camera_distance
+ self.SH = SH()
+ self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
+
+
+ def to(self, device):
+ self.device = device
+ for key, value in self.__dict__.items():
+ if type(value).__module__ == np.__name__:
+ setattr(self, key, torch.tensor(value).to(device))
+
+
+ def compute_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
+ exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+
+ def compute_texture(self, tex_coeff, normalize=True):
+ """
+ Return:
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
+
+ Parameters:
+ tex_coeff -- torch.tensor, size (B, 80)
+ """
+ batch_size = tex_coeff.shape[0]
+ face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
+ if normalize:
+ face_texture = face_texture / 255.
+ return face_texture.reshape([batch_size, -1, 3])
+
+
+ def compute_norm(self, face_shape):
+ """
+ Return:
+ vertex_norm -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+
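+ # face normals from two triangle edges, accumulated over each vertex's adjacent faces (point_buf) and renormalized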
+ v1 = face_shape[:, self.face_buf[:, 0]]
+ v2 = face_shape[:, self.face_buf[:, 1]]
+ v3 = face_shape[:, self.face_buf[:, 2]]
+ e1 = v1 - v2
+ e2 = v2 - v3
+ face_norm = torch.cross(e1, e2, dim=-1)
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
+
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
+ return vertex_norm
+
+
+ def compute_color(self, face_texture, face_norm, gamma):
+ """
+ Return:
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
+
+ Parameters:
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
+ gamma -- torch.tensor, size (B, 27), SH coeffs
+ """
+ batch_size = gamma.shape[0]
+ v_num = face_texture.shape[1]
+ a, c = self.SH.a, self.SH.c
+ gamma = gamma.reshape([batch_size, 3, 9])
+ gamma = gamma + self.init_lit
+ gamma = gamma.permute(0, 2, 1)
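+ # Y: the 9 SH basis functions evaluated at every vertex normal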
+ Y = torch.cat([
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
+ -a[1] * c[1] * face_norm[..., 1:2],
+ a[1] * c[1] * face_norm[..., 2:],
+ -a[1] * c[1] * face_norm[..., :1],
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
+ ], dim=-1)
+ r = Y @ gamma[..., :1]
+ g = Y @ gamma[..., 1:2]
+ b = Y @ gamma[..., 2:]
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
+ return face_color
+
+
+ def compute_rotation(self, angles):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ angles -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = angles.shape[0]
+ ones = torch.ones([batch_size, 1]).to(self.device)
+ zeros = torch.zeros([batch_size, 1]).to(self.device)
+ x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
+
+ rot_x = torch.cat([
+ ones, zeros, zeros,
+ zeros, torch.cos(x), -torch.sin(x),
+ zeros, torch.sin(x), torch.cos(x)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat([
+ torch.cos(y), zeros, torch.sin(y),
+ zeros, ones, zeros,
+ -torch.sin(y), zeros, torch.cos(y)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat([
+ torch.cos(z), -torch.sin(z), zeros,
+ torch.sin(z), torch.cos(z), zeros,
+ zeros, zeros, ones
+ ], dim=1).reshape([batch_size, 3, 3])
+
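+ # full rotation R = Rz @ Ry @ Rx, transposed so that row-vector points can be right-multiplied (pts @ R)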
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
+
+ def to_camera(self, face_shape):
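+ # move the face into the camera frame: depth becomes camera_distance - z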
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
+ return face_shape
+
+ def to_image(self, face_shape):
+ """
+ Return:
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+ # to image_plane
+ face_proj = face_shape @ self.persc_proj
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
+
+ return face_proj
+
+
+ def transform(self, face_shape, rot, trans):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ rot -- torch.tensor, size (B, 3, 3)
+ trans -- torch.tensor, size (B, 3)
+ """
+ return face_shape @ rot + trans.unsqueeze(1)
+
+
+ def get_landmarks(self, face_proj):
+ """
+ Return:
+ face_lms -- torch.tensor, size (B, 68, 2)
+
+ Parameters:
+ face_proj -- torch.tensor, size (B, N, 2)
+ """
+ return face_proj[:, self.keypoints]
+
+ def split_coeff(self, coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+ def compute_for_render(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ rotation = self.compute_rotation(coef_dict['angle'])
+
+
+ face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+ def compute_for_render_woRotation(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ #rotation = self.compute_rotation(coef_dict['angle'])
+
+
+ #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm # @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+
+if __name__ == '__main__':
+ transferBFM09()
\ No newline at end of file
diff --git a/src/face3d/models/facerecon_model.py b/src/face3d/models/facerecon_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7de8ca6eebc50ff1ed52c5ba37d31b43f977b5e1
--- /dev/null
+++ b/src/face3d/models/facerecon_model.py
@@ -0,0 +1,220 @@
+"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+from src.face3d.models.base_model import BaseModel
+from src.face3d.models import networks
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
+from src.face3d.util import util
+from src.face3d.util.nvdiffrast import MeshRenderer
+# from src.face3d.util.preprocess import estimate_norm_torch
+
+import trimesh
+from scipy.io import savemat
+
+class FaceReconModel(BaseModel):
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train=False):
+ """ Configures options specific for CUT model
+ """
+ # net structure and parameters
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
+ parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth')
+ parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
+ parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+ # renderer parameters
+ parser.add_argument('--focal', type=float, default=1015.)
+ parser.add_argument('--center', type=float, default=112.)
+ parser.add_argument('--camera_d', type=float, default=10.)
+ parser.add_argument('--z_near', type=float, default=5.)
+ parser.add_argument('--z_far', type=float, default=15.)
+
+ if is_train:
+ # training parameters
+            parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r34', 'r50'], help='face recog network structure')
+ parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
+ parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
+ parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
+
+
+ # augmentation parameters
+ parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
+ parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
+ parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
+
+ # loss weights
+ parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
+            parser.add_argument('--w_color', type=float, default=1.92, help='weight for color loss')
+ parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
+ parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
+ parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
+ parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
+ parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
+ parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
+ parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
+
+ opt, _ = parser.parse_known_args()
+ parser.set_defaults(
+ focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
+ )
+ if is_train:
+ parser.set_defaults(
+ use_crop_face=True, use_predef_M=False
+ )
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+
+ self.visual_names = ['output_vis']
+ self.model_names = ['net_recon']
+ self.parallel_names = self.model_names + ['renderer']
+
+ self.facemodel = ParametricFaceModel(
+ bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
+ is_train=self.isTrain, default_name=opt.bfm_model
+ )
+
+ fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
+ self.renderer = MeshRenderer(
+ rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
+ )
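+        # Pinhole-camera field of view: fov = 2 * atan(center / focal), converted to degrees;
+        # with the default center=112 and focal=1015 this is roughly 12.6 degrees.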
+
+ if self.isTrain:
+ self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
+
+ self.net_recog = networks.define_net_recog(
+ net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
+ )
+ # loss func name: (compute_%s_loss) % loss_name
+ self.compute_feat_loss = perceptual_loss
+            self.compute_color_loss = photo_loss
+ self.compute_lm_loss = landmark_loss
+ self.compute_reg_loss = reg_loss
+ self.compute_reflc_loss = reflectance_loss
+
+ self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
+ self.optimizers = [self.optimizer]
+ self.parallel_names += ['net_recog']
+        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ self.input_img = input['imgs'].to(self.device)
+ self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
+ self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
+ self.trans_m = input['M'].to(self.device) if 'M' in input else None
+ self.image_paths = input['im_paths'] if 'im_paths' in input else None
+
+ def forward(self, output_coeff, device):
+ self.facemodel.to(device)
+ self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
+ self.facemodel.compute_for_render(output_coeff)
+ self.pred_mask, _, self.pred_face = self.renderer(
+ self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
+
+ self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
+
+
+ def compute_losses(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+
+ assert self.net_recog.training == False
+ trans_m = self.trans_m
+ if not self.opt.use_predef_M:
+ trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
+
+ pred_feat = self.net_recog(self.pred_face, trans_m)
+ gt_feat = self.net_recog(self.input_img, self.trans_m)
+ self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
+
+ face_mask = self.pred_mask
+ if self.opt.use_crop_face:
+ face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
+
+ face_mask = face_mask.detach()
+        self.loss_color = self.opt.w_color * self.compute_color_loss(
+ self.pred_face, self.input_img, self.atten_mask * face_mask)
+
+ loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
+ self.loss_reg = self.opt.w_reg * loss_reg
+ self.loss_gamma = self.opt.w_gamma * loss_gamma
+
+ self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
+
+ self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
+
+ self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
+ + self.loss_lm + self.loss_reflc
+
+
+    def optimize_parameters(self, isTrain=True):
+        """Update network weights; it will be called in every training iteration."""
+        self.forward()
+        self.compute_losses()
+ if isTrain:
+ self.optimizer.zero_grad()
+ self.loss_all.backward()
+ self.optimizer.step()
+
+ def compute_visuals(self):
+ with torch.no_grad():
+ input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
+ output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
+ output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
+
+ if self.gt_lm is not None:
+ gt_lm_numpy = self.gt_lm.cpu().numpy()
+ pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
+
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw, output_vis_numpy), axis=-2)
+ else:
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw), axis=-2)
+
+ self.output_vis = torch.tensor(
+ output_vis_numpy / 255., dtype=torch.float32
+ ).permute(0, 3, 1, 2).to(self.device)
+
+ def save_mesh(self, name):
+
+ recon_shape = self.pred_vertex # get reconstructed shape
+ recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
+ recon_shape = recon_shape.cpu().numpy()[0]
+ recon_color = self.pred_color
+ recon_color = recon_color.cpu().numpy()[0]
+ tri = self.facemodel.face_buf.cpu().numpy()
+ mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
+ mesh.export(name)
+
+ def save_coeff(self,name):
+
+ pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
+ pred_lm = self.pred_lm.cpu().numpy()
+ pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate
+ pred_coeffs['lm68'] = pred_lm
+ savemat(name,pred_coeffs)
+
+
+
diff --git a/src/face3d/models/losses.py b/src/face3d/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d6a85870af1ef2b857e4a3fdd4b2f7fc991317
--- /dev/null
+++ b/src/face3d/models/losses.py
@@ -0,0 +1,113 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from kornia.geometry import warp_affine
+import torch.nn.functional as F
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+### perceptual level loss
+class PerceptualLoss(nn.Module):
+ def __init__(self, recog_net, input_size=112):
+ super(PerceptualLoss, self).__init__()
+ self.recog_net = recog_net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+    def forward(self, imageA, imageB, M):
+ """
+ 1 - cosine distance
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
+ imageB --same as imageA
+ """
+
+ imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
+ imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
+
+ # freeze bn
+ self.recog_net.eval()
+
+ id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
+ id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
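+# Note: perceptual_loss below assumes id_featureA / id_featureB are already L2-normalized
+# (as produced by RecogNetWrapper), so their dot product is a cosine similarity and the
+# returned value is the batch mean of (1 - cosine).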
+def perceptual_loss(id_featureA, id_featureB):
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+### image level loss
+def photo_loss(imageA, imageB, mask, eps=1e-6):
+ """
+    l2 norm (with sqrt; eps is used to ensure backward stability, otherwise NaN may occur)
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
+ imageB --same as imageA
+ """
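+    # Masked per-pixel RGB L2 distance, averaged over the number of masked pixels;
+    # the max(..., 1) guard below avoids division by zero when the mask is empty.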
+ loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
+ loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
+ return loss
+
+def landmark_loss(predict_lm, gt_lm, weight=None):
+ """
+ weighted mse loss
+ Parameters:
+ predict_lm --torch.tensor (B, 68, 2)
+ gt_lm --torch.tensor (B, 68, 2)
+ weight --numpy.array (1, 68)
+ """
+    if weight is None:
+ weight = np.ones([68])
+ weight[28:31] = 20
+ weight[-8:] = 20
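+        # Assuming the standard 68-point iBUG ordering, indices 28-30 (nose bridge) and the
+        # last 8 points (inner mouth) receive a 20x weight.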
+ weight = np.expand_dims(weight, 0)
+ weight = torch.tensor(weight).to(predict_lm.device)
+ loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
+ loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
+ return loss
+
+
+### regulization
+def reg_loss(coeffs_dict, opt=None):
+ """
+ l2 norm without the sqrt, from yu's implementation (mse)
+ tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
+ Parameters:
+ coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
+
+ """
+ # coefficient regularization to ensure plausible 3d faces
+ if opt:
+ w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
+ else:
+        w_id, w_exp, w_tex = 1, 1, 1
+ creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
+ w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
+ w_tex * torch.sum(coeffs_dict['tex'] ** 2)
+ creg_loss = creg_loss / coeffs_dict['id'].shape[0]
+
+ # gamma regularization to ensure a nearly-monochromatic light
+ gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
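+    # gamma holds 9 spherical-harmonics coefficients per RGB channel (B, 3, 9); penalizing the
+    # deviation from the mean over the three color channels pushes R, G and B towards the same
+    # lighting, i.e. nearly monochromatic illumination.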
+ gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
+ gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
+
+ return creg_loss, gamma_loss
+
+def reflectance_loss(texture, mask):
+ """
+    minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo
+ Parameters:
+ texture --torch.tensor, (B, N, 3)
+ mask --torch.tensor, (N), 1 or 0
+
+ """
+ mask = mask.reshape([1, mask.shape[0], 1])
+ texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
+ loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
+ return loss
+
diff --git a/src/face3d/models/networks.py b/src/face3d/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..ead9cdcb8720b845c233de79dc8a8d1668492108
--- /dev/null
+++ b/src/face3d/models/networks.py
@@ -0,0 +1,521 @@
+"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import torch
+from torch import Tensor
+import torch.nn as nn
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+from .arcface_torch.backbones import get_model
+from kornia.geometry import warp_affine
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+def filter_state_dict(state_dict, remove_name='fc'):
+ new_state_dict = {}
+ for key in state_dict:
+ if remove_name in key:
+ continue
+ new_state_dict[key] = state_dict[key]
+ return new_state_dict
+
+def get_scheduler(optimizer, opt):
+ """Return a learning rate scheduler
+
+ Parameters:
+ optimizer -- the optimizer of the network
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == 'linear':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == 'cosine':
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+ else:
+        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+ return scheduler
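+# Typical usage (sketch): create the scheduler once after the optimizer, then call
+# scheduler.step() once per epoch (or scheduler.step(metric) for the 'plateau' policy,
+# which expects a validation metric).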
+
+
+def define_net_recon(net_recon, use_last_fc=False, init_path=None):
+ return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
+
+def define_net_recog(net_recog, pretrained_path=None):
+ net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
+ net.eval()
+ return net
+
+class ReconNetWrapper(nn.Module):
+ fc_dim=257
+ def __init__(self, net_recon, use_last_fc=False, init_path=None):
+ super(ReconNetWrapper, self).__init__()
+ self.use_last_fc = use_last_fc
+ if net_recon not in func_dict:
+            raise NotImplementedError('network [%s] is not implemented' % net_recon)
+ func, last_dim = func_dict[net_recon]
+ backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
+ if init_path and os.path.isfile(init_path):
+ state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
+ backbone.load_state_dict(state_dict)
+ print("loading init net_recon %s from %s" %(net_recon, init_path))
+ self.backbone = backbone
+ if not use_last_fc:
+ self.final_layers = nn.ModuleList([
+ conv1x1(last_dim, 80, bias=True), # id layer
+ conv1x1(last_dim, 64, bias=True), # exp layer
+ conv1x1(last_dim, 80, bias=True), # tex layer
+ conv1x1(last_dim, 3, bias=True), # angle layer
+ conv1x1(last_dim, 27, bias=True), # gamma layer
+ conv1x1(last_dim, 2, bias=True), # tx, ty
+ conv1x1(last_dim, 1, bias=True) # tz
+ ])
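+            # The seven 1x1-conv heads output 80+64+80+3+27+2+1 = 257 channels in total,
+            # matching fc_dim above and the slicing in ParametricFaceModel.split_coeff.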
+ for m in self.final_layers:
+ nn.init.constant_(m.weight, 0.)
+ nn.init.constant_(m.bias, 0.)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ if not self.use_last_fc:
+ output = []
+ for layer in self.final_layers:
+ output.append(layer(x))
+ x = torch.flatten(torch.cat(output, dim=1), 1)
+ return x
+
+
+class RecogNetWrapper(nn.Module):
+ def __init__(self, net_recog, pretrained_path=None, input_size=112):
+ super(RecogNetWrapper, self).__init__()
+ net = get_model(name=net_recog, fp16=False)
+ if pretrained_path:
+ state_dict = torch.load(pretrained_path, map_location='cpu')
+ net.load_state_dict(state_dict)
+ print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path))
+ for param in net.parameters():
+ param.requires_grad = False
+ self.net = net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+
+ def forward(self, image, M):
+ image = self.preprocess(resize_n_crop(image, M, self.input_size))
+ id_feature = F.normalize(self.net(image), dim=-1, p=2)
+ return id_feature
+
+
+# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ use_last_fc: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.use_last_fc = use_last_fc
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ if self.use_last_fc:
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+ stride: int = 1, dilate: bool = False) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ if self.use_last_fc:
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def _resnet(
+ arch: str,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ pretrained: bool,
+ progress: bool,
+ **kwargs: Any
+) -> ResNet:
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-101 32x8d model from
+    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-50-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-101-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+func_dict = {
+ 'resnet18': (resnet18, 512),
+ 'resnet50': (resnet50, 2048)
+}
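+# Note: '--net_recon resnet34' is accepted by the option parser but has no entry here;
+# a (resnet34, 512) entry would be needed for ReconNetWrapper to use it.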
diff --git a/src/face3d/models/template_model.py b/src/face3d/models/template_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dac7b33d5889777eb63c9882a3b9fa094dcab293
--- /dev/null
+++ b/src/face3d/models/template_model.py
@@ -0,0 +1,100 @@
+"""Model class template
+
+This module provides a template for users to implement custom models.
+You can specify '--model template' to use this model.
+The class name should be consistent with both the filename and its model option.
+The filename should be <model>_model.py
+The class name should be <Model>Model
+It implements a simple image-to-image translation baseline based on regression loss.
+Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
+    min_<netG> ||netG(data_A) - data_B||_1
+You need to implement the following functions:
+    <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
+    <__init__>: Initialize this model class.
+    <set_input>: Unpack input data and perform data pre-processing.
+    <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
+    <optimize_parameters>: Update network weights; it will be called in every training iteration.
+"""
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class TemplateModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new model-specific options and rewrite default values for existing options.
+
+ Parameters:
+ parser -- the option parser
+ is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
+ if is_train:
+ parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
+ self.loss_names = ['loss_G']
+ # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
+ self.visual_names = ['data_A', 'data_B', 'output']
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
+ # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
+ self.model_names = ['G']
+ # define networks; you can use opt.isTrain to specify different behaviors for training and test.
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
+ if self.isTrain: # only defined during training time
+ # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
+ # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
+ self.criterionLoss = torch.nn.L1Loss()
+ # define and initialize optimizers. You can define one optimizer for each network.
+ # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = [self.optimizer]
+
+        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B
+ self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
+ self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
+ self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
+
+ def forward(self):
+        """Run forward pass. This will be called by both <optimize_parameters> and <test>."""
+ self.output = self.netG(self.data_A) # generate output image given the input data_A
+
+ def backward(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        # calculate the intermediate results if necessary; here self.output has been computed during the <forward> function
+ # calculate loss given the input and intermediate results
+ self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+ self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
+
+ def optimize_parameters(self):
+ """Update network weights; it will be called in every training iteration."""
+ self.forward() # first call forward to calculate intermediate results
+ self.optimizer.zero_grad() # clear network G's existing gradients
+ self.backward() # calculate gradients for network G
+ self.optimizer.step() # update gradients for network G
diff --git a/src/face3d/options/__init__.py b/src/face3d/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90
--- /dev/null
+++ b/src/face3d/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/src/face3d/options/base_options.py b/src/face3d/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8f921d5a43434ae802a55a0fa3889c4b7ab9f6d
--- /dev/null
+++ b/src/face3d/options/base_options.py
@@ -0,0 +1,169 @@
+"""This script contains base options for Deep3DFaceRecon_pytorch
+"""
+
+import argparse
+import os
+from util import util
+import numpy as np
+import torch
+import face3d.models as models
+import face3d.data as data
+
+
+class BaseOptions():
+ """This class defines options used during both training and test time.
+
+ It also implements several helper functions such as parsing, printing, and saving the options.
+ It also gathers additional options defined in functions in both dataset class and model class.
+ """
+
+ def __init__(self, cmd_line=None):
+        """Reset the class; indicates the class hasn't been initialized"""
+ self.initialized = False
+ self.cmd_line = None
+ if cmd_line is not None:
+ self.cmd_line = cmd_line.split()
+
+ def initialize(self, parser):
+ """Define the common options that are used in both training and test."""
+ # basic parameters
+ parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+        parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visualization')
+ parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation')
+ parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel')
+ parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
+ parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses')
+ parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard')
+        parser.add_argument('--world_size', type=int, default=1, help='world size (number of processes) for distributed training')
+
+ # model parameters
+ parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
+
+ # additional parameters
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+        """Initialize our parser with basic options (only once).
+        Add additional model-specific and dataset-specific options.
+        These options are defined in the <modify_commandline_options> function
+        in model and dataset classes.
+ """
+ if not self.initialized: # check if it has been initialized
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args()
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line)
+
+ # set cuda visible devices
+ os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
+
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = models.get_option_setter(model_name)
+ parser = model_option_setter(parser, self.isTrain)
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args() # parse again with new defaults
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
+
+ # modify dataset-related parser options
+ if opt.dataset_mode:
+ dataset_name = opt.dataset_mode
+ dataset_option_setter = data.get_option_setter(dataset_name)
+ parser = dataset_option_setter(parser, self.isTrain)
+
+ # save and return the parser
+ self.parser = parser
+ if self.cmd_line is None:
+ return parser.parse_args()
+ else:
+ return parser.parse_args(self.cmd_line)
+
+ def print_options(self, opt):
+ """Print and save options
+
+ It will print both current options and default values(if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+ try:
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+ except PermissionError as error:
+ print("permission error {}".format(error))
+ pass
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ gpu_ids.append(id)
+ opt.world_size = len(gpu_ids)
+ # if len(opt.gpu_ids) > 0:
+ # torch.cuda.set_device(gpu_ids[0])
+ if opt.world_size == 1:
+ opt.use_ddp = False
+
+ if opt.phase != 'test':
+ # set continue_train automatically
+ if opt.pretrained_name is None:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ else:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
+ if os.path.isdir(model_dir):
+ model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
+ if os.path.isdir(model_dir) and len(model_pths) != 0:
+ opt.continue_train= True
+
+ # update the latest epoch count
+ if opt.continue_train:
+ if opt.epoch == 'latest':
+ epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
+ if len(epoch_counts) != 0:
+ opt.epoch_count = max(epoch_counts) + 1
+ else:
+ opt.epoch_count = int(opt.epoch) + 1
+
+
+ self.print_options(opt)
+ self.opt = opt
+ return self.opt
diff --git a/src/face3d/options/inference_options.py b/src/face3d/options/inference_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..c453965959ab4cfb31acbc424f994db68c3d4df5
--- /dev/null
+++ b/src/face3d/options/inference_options.py
@@ -0,0 +1,23 @@
+from face3d.options.base_options import BaseOptions
+
+
+class InferenceOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')
+ parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')
+ parser.add_argument('--save_split_files', action='store_true', help='save split files or not')
+ parser.add_argument('--inference_batch_size', type=int, default=8)
+
+        # Dropout and Batchnorm have different behaviors during training and test.
+ self.isTrain = False
+ return parser
diff --git a/src/face3d/options/test_options.py b/src/face3d/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff3ad142779850d1d5a1640bc00f70d34d4a862
--- /dev/null
+++ b/src/face3d/options/test_options.py
@@ -0,0 +1,21 @@
+"""This script contains the test options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
+
+        # Dropout and Batchnorm have different behaviors during training and test.
+ self.isTrain = False
+ return parser
diff --git a/src/face3d/options/train_options.py b/src/face3d/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..1337bfdd5f372b5c686a91b394a2aadbe5741f44
--- /dev/null
+++ b/src/face3d/options/train_options.py
@@ -0,0 +1,53 @@
+"""This script contains the training options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+from util import util
+
+class TrainOptions(BaseOptions):
+ """This class includes training options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser)
+ # dataset parameters
+ # for train
+ parser.add_argument('--data_root', type=str, default='./', help='dataset root')
+ parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
+ parser.add_argument('--batch_size', type=int, default=32)
+ parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
+ parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation')
+
+ # for val
+ parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
+ parser.add_argument('--batch_size_val', type=int, default=32)
+
+
+ # visualization parameters
+ parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+
+ # network saving and loading parameters
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+ parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
+ parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
+ parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
+
+ # training parameters
+ parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
+ parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
+ parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
+        parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epochs')
+
+ self.isTrain = True
+ return parser
diff --git a/src/face3d/util/BBRegressorParam_r.mat b/src/face3d/util/BBRegressorParam_r.mat
new file mode 100644
index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084
Binary files /dev/null and b/src/face3d/util/BBRegressorParam_r.mat differ
diff --git a/src/face3d/util/__init__.py b/src/face3d/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..04eecb58b62f8c9d11d17606c6241d278a48b9b9
--- /dev/null
+++ b/src/face3d/util/__init__.py
@@ -0,0 +1,3 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
+from src.face3d.util import *
+
diff --git a/src/face3d/util/detect_lm68.py b/src/face3d/util/detect_lm68.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7e40997289e17405e1fb6c408d21adce7b626ce
--- /dev/null
+++ b/src/face3d/util/detect_lm68.py
@@ -0,0 +1,106 @@
+import os
+import cv2
+import numpy as np
+from scipy.io import loadmat
+import tensorflow as tf
+from util.preprocess import align_for_lm
+from shutil import move
+
+mean_face = np.loadtxt('util/test_mean_face.txt')
+mean_face = mean_face.reshape([68, 2])
+
+def save_label(labels, save_path):
+ np.savetxt(save_path, labels)
+
+def draw_landmarks(img, landmark, save_name):
+ landmark = landmark
+ lm_img = np.zeros([img.shape[0], img.shape[1], 3])
+ lm_img[:] = img.astype(np.float32)
+ landmark = np.round(landmark).astype(np.int32)
+
+ for i in range(len(landmark)):
+ for j in range(-1, 1):
+ for k in range(-1, 1):
+ if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
+ img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
+ landmark[i, 0]+k > 0 and \
+ landmark[i, 0]+k < img.shape[1]:
+ lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
+ :] = np.array([0, 0, 255])
+ lm_img = lm_img.astype(np.uint8)
+
+ cv2.imwrite(save_name, lm_img)
+
+
+def load_data(img_name, txt_name):
+ return cv2.imread(img_name), np.loadtxt(txt_name)
+
+# create tensorflow graph for landmark detector
+def load_lm_graph(graph_filename):
+ with tf.gfile.GFile(graph_filename, 'rb') as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+
+ with tf.Graph().as_default() as graph:
+ tf.import_graph_def(graph_def, name='net')
+ img_224 = graph.get_tensor_by_name('net/input_imgs:0')
+ output_lm = graph.get_tensor_by_name('net/lm:0')
+ lm_sess = tf.Session(graph=graph)
+
+ return lm_sess,img_224,output_lm
+
+# landmark detection
+def detect_68p(img_path,sess,input_op,output_op):
+ print('detecting landmarks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ vis_path = os.path.join(img_path, 'vis')
+ remove_path = os.path.join(img_path, 'remove')
+ save_path = os.path.join(img_path, 'landmarks')
+ if not os.path.isdir(vis_path):
+ os.makedirs(vis_path)
+ if not os.path.isdir(remove_path):
+ os.makedirs(remove_path)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
+ full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
+
+ # if an image does not have detected 5 facial landmarks, remove it from the training list
+ if not os.path.isfile(full_txt_name):
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # load data
+ img, five_points = load_data(full_image_name, full_txt_name)
+ input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
+
+ # if the alignment fails, remove corresponding image from the training list
+ if scale == 0:
+ move(full_txt_name, os.path.join(
+ remove_path, txt_name))
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # detect landmarks
+ input_img = np.reshape(
+ input_img, [1, 224, 224, 3]).astype(np.float32)
+ landmark = sess.run(
+ output_op, feed_dict={input_op: input_img})
+
+ # transform back to original image coordinate
+ landmark = landmark.reshape([68, 2]) + mean_face
+ landmark[:, 1] = 223 - landmark[:, 1]
+ landmark = landmark / scale
+ landmark[:, 0] = landmark[:, 0] + bbox[0]
+ landmark[:, 1] = landmark[:, 1] + bbox[1]
+ landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
+
+ if i % 100 == 0:
+ draw_landmarks(img, landmark, os.path.join(vis_path, name))
+ save_label(landmark, os.path.join(save_path, txt_name))
diff --git a/src/face3d/util/generate_list.py b/src/face3d/util/generate_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..943d906781063c3584a7e5b5c784f8aac0694985
--- /dev/null
+++ b/src/face3d/util/generate_list.py
@@ -0,0 +1,34 @@
+"""This script is to generate training list files for Deep3DFaceRecon_pytorch
+"""
+
+import os
+
+# save path to training data
+def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):
+ save_path = os.path.join(save_folder, mode)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+ with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in lms_list])
+
+ with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in imgs_list])
+
+ with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in msks_list])
+
+# check if the path is valid
+def check_list(rlms_list, rimgs_list, rmsks_list):
+ lms_list, imgs_list, msks_list = [], [], []
+ for i in range(len(rlms_list)):
+ flag = 'false'
+ lm_path = rlms_list[i]
+ im_path = rimgs_list[i]
+ msk_path = rmsks_list[i]
+ if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
+ flag = 'true'
+ lms_list.append(rlms_list[i])
+ imgs_list.append(rimgs_list[i])
+ msks_list.append(rmsks_list[i])
+ print(i, rlms_list[i], flag)
+ return lms_list, imgs_list, msks_list
diff --git a/src/face3d/util/html.py b/src/face3d/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc3262a1eafda34842e4dbad47bb6ba72f0c5a68
--- /dev/null
+++ b/src/face3d/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+ """This HTML class allows us to save images and write texts into a single HTML file.
+
+    It consists of functions such as <add_header> (add a text header to the HTML file),
+    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
+    It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+ """
+
+    def __init__(self, web_dir, title, refresh=0):
+        """Initialize the HTML classes
+
+        Parameters:
+            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
+            title (str)   -- the webpage name
+            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+        """
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        if not os.path.exists(self.img_dir):
+            os.makedirs(self.img_dir)
+
+        self.doc = dominate.document(title=title)
+        if refresh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ """Return the directory that stores images"""
+ return self.img_dir
+
+ def add_header(self, text):
+ """Insert a header to the HTML file
+
+ Parameters:
+ text (str) -- the header text
+ """
+ with self.doc:
+ h3(text)
+
+ def add_images(self, ims, txts, links, width=400):
+ """add images to the HTML file
+
+ Parameters:
+ ims (str list) -- a list of image paths
+ txts (str list) -- a list of image names shown on the website
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+ """
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
+ self.doc.add(self.t)
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ """save the current content to the HMTL file"""
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__': # we show an example usage here.
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims, txts, links = [], [], []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/src/face3d/util/load_mats.py b/src/face3d/util/load_mats.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a6fcc71de1d7dad8b0f81c67dc1c213764ff0b
--- /dev/null
+++ b/src/face3d/util/load_mats.py
@@ -0,0 +1,120 @@
+"""This script is to load 3D face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from PIL import Image
+from scipy.io import loadmat, savemat
+from array import array
+import os.path as osp
+
+# load expression basis
+def LoadExpBasis(bfm_folder='BFM'):
+ n_vertex = 53215
+ Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
+ exp_dim = array('i')
+ exp_dim.fromfile(Expbin, 1)
+ expMU = array('f')
+ expPC = array('f')
+ expMU.fromfile(Expbin, 3*n_vertex)
+ expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
+ Expbin.close()
+
+ expPC = np.array(expPC)
+ expPC = np.reshape(expPC, [exp_dim[0], -1])
+ expPC = np.transpose(expPC)
+
+ expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
+
+ return expPC, expEV
+
+
+# transfer original BFM09 to our face model
+def transferBFM09(bfm_folder='BFM'):
+ print('Transfer BFM09 to BFM_model_front......')
+ original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
+ shapePC = original_BFM['shapePC'] # shape basis
+ shapeEV = original_BFM['shapeEV'] # corresponding eigen value
+ shapeMU = original_BFM['shapeMU'] # mean face
+ texPC = original_BFM['texPC'] # texture basis
+ texEV = original_BFM['texEV'] # eigen value
+ texMU = original_BFM['texMU'] # mean texture
+
+ expPC, expEV = LoadExpBasis(bfm_folder)
+
+ # transfer BFM09 to our face model
+
+ idBase = shapePC*np.reshape(shapeEV, [-1, 199])
+ idBase = idBase/1e5 # unify the scale to decimeter
+ idBase = idBase[:, :80] # use only first 80 basis
+
+ exBase = expPC*np.reshape(expEV, [-1, 79])
+ exBase = exBase/1e5 # unify the scale to decimeter
+ exBase = exBase[:, :64] # use only first 64 basis
+
+ texBase = texPC*np.reshape(texEV, [-1, 199])
+ texBase = texBase[:, :80] # use only first 80 basis
+
+ # our face model is cropped along face landmarks and contains only 35709 vertex.
+ # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
+ # thus we select corresponding vertex to get our face model.
+
+ index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
+ index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215)
+
+ index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
+ index_shape = index_shape['trimIndex'].astype(
+ np.int32) - 1 # starts from 0 (to 53490)
+ index_shape = index_shape[index_exp]
+
+ idBase = np.reshape(idBase, [-1, 3, 80])
+ idBase = idBase[index_shape, :, :]
+ idBase = np.reshape(idBase, [-1, 80])
+
+ texBase = np.reshape(texBase, [-1, 3, 80])
+ texBase = texBase[index_shape, :, :]
+ texBase = np.reshape(texBase, [-1, 80])
+
+ exBase = np.reshape(exBase, [-1, 3, 64])
+ exBase = exBase[index_exp, :, :]
+ exBase = np.reshape(exBase, [-1, 64])
+
+ meanshape = np.reshape(shapeMU, [-1, 3])/1e5
+ meanshape = meanshape[index_shape, :]
+ meanshape = np.reshape(meanshape, [1, -1])
+
+ meantex = np.reshape(texMU, [-1, 3])
+ meantex = meantex[index_shape, :]
+ meantex = np.reshape(meantex, [1, -1])
+
+ # other info contains triangles, region used for computing photometric loss,
+ # region used for skin texture regularization, and 68 landmarks index etc.
+ other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
+ frontmask2_idx = other_info['frontmask2_idx']
+ skinmask = other_info['skinmask']
+ keypoints = other_info['keypoints']
+ point_buf = other_info['point_buf']
+ tri = other_info['tri']
+ tri_mask2 = other_info['tri_mask2']
+
+ # save our face model
+ savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
+ 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
+
+
+# load landmarks for standard face, which is used for image preprocessing
+def load_lm3d(bfm_folder):
+
+ Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
+ Lm3D = Lm3D['lm']
+
+ # calculate 5 facial landmarks using 68 landmarks
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
+ Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
+ Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
+
+ return Lm3D
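+
+# A hedged usage sketch (editor comment): load_lm3d only needs a folder containing
+# 'similarity_Lm3D_all.mat' and returns the (5, 3) standard landmarks consumed by
+# util/preprocess.py:align_img. The folder name below is illustrative.
+#
+#   lm3d_std = load_lm3d('BFM')   # -> numpy array of shape (5, 3)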
+
+
+if __name__ == '__main__':
+ transferBFM09()
\ No newline at end of file
diff --git a/src/face3d/util/nvdiffrast.py b/src/face3d/util/nvdiffrast.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3245859c650afbfe841a66b74cddefaf28820d9
--- /dev/null
+++ b/src/face3d/util/nvdiffrast.py
@@ -0,0 +1,126 @@
+"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
+    Note: the antialiasing step is missing in the current version.
+"""
+import pytorch3d.ops
+import torch
+import torch.nn.functional as F
+import kornia
+from kornia.geometry.camera import pixel2cam
+import numpy as np
+from typing import List
+from scipy.io import loadmat
+from torch import nn
+
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+ look_at_view_transform,
+ FoVPerspectiveCameras,
+ DirectionalLights,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+ TexturesUV,
+)
+
+# def ndc_projection(x=0.1, n=1.0, f=50.0):
+# return np.array([[n/x, 0, 0, 0],
+# [ 0, n/-x, 0, 0],
+# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
+# [ 0, 0, -1, 0]]).astype(np.float32)
+
+class MeshRenderer(nn.Module):
+ def __init__(self,
+ rasterize_fov,
+ znear=0.1,
+ zfar=10,
+ rasterize_size=224):
+ super(MeshRenderer, self).__init__()
+
+ # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
+ # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
+ # torch.diag(torch.tensor([1., -1, -1, 1])))
+ self.rasterize_size = rasterize_size
+ self.fov = rasterize_fov
+ self.znear = znear
+ self.zfar = zfar
+
+ self.rasterizer = None
+
+ def forward(self, vertex, tri, feat=None):
+ """
+ Return:
+ mask -- torch.tensor, size (B, 1, H, W)
+ depth -- torch.tensor, size (B, 1, H, W)
+ features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
+
+ Parameters:
+ vertex -- torch.tensor, size (B, N, 3)
+ tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
+ feat(optional) -- torch.tensor, size (B, N ,C), features
+ """
+ device = vertex.device
+ rsize = int(self.rasterize_size)
+ # ndc_proj = self.ndc_proj.to(device)
+        # convert 3D vertices to homogeneous coordinates; the y direction is the same as v
+ if vertex.shape[-1] == 3:
+ vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
+ vertex[..., 0] = -vertex[..., 0]
+
+
+ # vertex_ndc = vertex @ ndc_proj.t()
+ if self.rasterizer is None:
+ self.rasterizer = MeshRasterizer()
+ print("create rasterizer on device cuda:%d"%device.index)
+
+ # ranges = None
+ # if isinstance(tri, List) or len(tri.shape) == 3:
+ # vum = vertex_ndc.shape[1]
+ # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
+ # fstartidx = torch.cumsum(fnum, dim=0) - fnum
+ # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
+ # for i in range(tri.shape[0]):
+ # tri[i] = tri[i] + i*vum
+ # vertex_ndc = torch.cat(vertex_ndc, dim=0)
+ # tri = torch.cat(tri, dim=0)
+
+        # for range_mode vertex: [B*N, 4], tri: [B*M, 3]; for instance_mode vertex: [B, N, 4], tri: [M, 3]
+ tri = tri.type(torch.int32).contiguous()
+
+ # rasterize
+ cameras = FoVPerspectiveCameras(
+ device=device,
+ fov=self.fov,
+ znear=self.znear,
+ zfar=self.zfar,
+ )
+
+ raster_settings = RasterizationSettings(
+ image_size=rsize
+ )
+
+ # print(vertex.shape, tri.shape)
+ mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1)))
+
+ fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
+ rast_out = fragments.pix_to_face.squeeze(-1)
+ depth = fragments.zbuf
+
+ # render depth
+ depth = depth.permute(0, 3, 1, 2)
+ mask = (rast_out > 0).float().unsqueeze(1)
+ depth = mask * depth
+
+
+ image = None
+ if feat is not None:
+ attributes = feat.reshape(-1,3)[mesh.faces_packed()]
+ image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
+ fragments.bary_coords,
+ attributes)
+ # print(image.shape)
+ image = image.squeeze(-2).permute(0, 3, 1, 2)
+ image = mask * image
+
+ return mask, depth, image
+
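+# A hedged usage sketch (editor comment, illustrative values only). The renderer assumes
+# pytorch3d is available and the input tensors live on a CUDA device (the lazy rasterizer
+# creation prints device.index):
+#
+#   renderer = MeshRenderer(rasterize_fov=12.6, znear=5., zfar=15., rasterize_size=224)
+#   # vertex: (B, N, 3) camera-space vertices, tri: (M, 3) int32 triangles, feat: (B, N, 3) colors
+#   mask, depth, image = renderer(vertex.cuda(), tri.cuda(), feat.cuda())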
diff --git a/src/face3d/util/preprocess.py b/src/face3d/util/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b77a3a4058c208e5ba8cb1cfbb563954a5f7a3e2
--- /dev/null
+++ b/src/face3d/util/preprocess.py
@@ -0,0 +1,103 @@
+"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from scipy.io import loadmat
+from PIL import Image
+import cv2
+import os
+from skimage import transform as trans
+import torch
+import warnings
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+# calculating least square problem for image alignment
+def POS(xp, x):
+ npts = xp.shape[1]
+
+ A = np.zeros([2*npts, 8])
+
+ A[0:2*npts-1:2, 0:3] = x.transpose()
+ A[0:2*npts-1:2, 3] = 1
+
+ A[1:2*npts:2, 4:7] = x.transpose()
+ A[1:2*npts:2, 7] = 1
+
+ b = np.reshape(xp.transpose(), [2*npts, 1])
+
+ k, _, _, _ = np.linalg.lstsq(A, b)
+
+ R1 = k[0:3]
+ R2 = k[4:7]
+ sTx = k[3]
+ sTy = k[7]
+ s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
+ t = np.stack([sTx, sTy], axis=0)
+
+ return t, s
+
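+# A hedged note on POS above (editor comment, not upstream documentation): each 2D/3D
+# landmark pair contributes two rows to the linear system A k = b, where
+# k = [s*r1, tx, s*r2, ty] encodes a scaled-orthographic projection u = s*r1.X + tx,
+# v = s*r2.X + ty. Because r1 and r2 are (roughly) unit rows of a rotation, the norms of
+# k[0:3] and k[4:7] both estimate the scale s, and their average is returned together
+# with the translation t = (tx, ty).
+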
+# resize and crop images for face reconstruction
+def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+ w0, h0 = img.size
+ w = (w0*s).astype(np.int32)
+ h = (h0*s).astype(np.int32)
+ left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+ right = left + target_size
+ up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+ below = up + target_size
+
+ img = img.resize((w, h), resample=Image.BICUBIC)
+ img = img.crop((left, up, right, below))
+
+ if mask is not None:
+ mask = mask.resize((w, h), resample=Image.BICUBIC)
+ mask = mask.crop((left, up, right, below))
+
+ lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+ t[1] + h0/2], axis=1)*s
+ lm = lm - np.reshape(
+ np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+ return img, lm, mask
+
+# utils for face reconstruction
+def extract_5p(lm):
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
+ lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
+ lm5p = lm5p[[1, 2, 0, 3, 4], :]
+ return lm5p
+
+# utils for face reconstruction
+def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
+ """
+ Return:
+ transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
+ img_new --PIL.Image (target_size, target_size, 3)
+ lm_new --numpy.array (68, 2), y direction is opposite to v direction
+ mask_new --PIL.Image (target_size, target_size)
+
+ Parameters:
+ img --PIL.Image (raw_H, raw_W, 3)
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ lm3D --numpy.array (5, 3)
+ mask --PIL.Image (raw_H, raw_W, 3)
+ """
+
+ w0, h0 = img.size
+ if lm.shape[0] != 5:
+ lm5p = extract_5p(lm)
+ else:
+ lm5p = lm
+
+ # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
+ t, s = POS(lm5p.transpose(), lm3D.transpose())
+ s = rescale_factor/s
+
+ # processing the image
+ img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ trans_params = np.array([w0, h0, s, t[0], t[1]])
+
+ return trans_params, img_new, lm_new, mask_new
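+
+# A hedged usage sketch (editor comment; the paths and landmark file are hypothetical):
+#
+#   from src.face3d.util.load_mats import load_lm3d
+#   img = Image.open('portrait.png').convert('RGB')               # hypothetical input photo
+#   lm68 = np.loadtxt('portrait_landmarks.txt').reshape(68, 2)    # hypothetical 68-point detector output
+#   trans_params, img_new, lm_new, mask_new = align_img(img, lm68, load_lm3d('BFM'))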
diff --git a/src/face3d/util/skin_mask.py b/src/face3d/util/skin_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8a74e4c3b40d13b0258b83a12f56321a85bb179
--- /dev/null
+++ b/src/face3d/util/skin_mask.py
@@ -0,0 +1,125 @@
+"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
+"""
+
+import math
+import numpy as np
+import os
+import cv2
+
+class GMM:
+ def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
+ self.dim = dim # feature dimension
+ self.num = num # number of Gaussian components
+ self.w = w # weights of Gaussian components (a list of scalars)
+        self.mu = mu             # means of Gaussian components (a list of 1xdim vectors)
+        self.cov = cov           # covariance matrices of Gaussian components (a list of dimxdim matrices)
+        self.cov_det = cov_det   # pre-computed determinants of the covariance matrices (a list of scalars)
+        self.cov_inv = cov_inv   # pre-computed inverses of the covariance matrices (a list of dimxdim matrices)
+
+ self.factor = [0]*num
+ for i in range(self.num):
+ self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5
+
+ def likelihood(self, data):
+ assert(data.shape[1] == self.dim)
+ N = data.shape[0]
+ lh = np.zeros(N)
+
+ for i in range(self.num):
+ data_ = data - self.mu[i]
+
+ tmp = np.matmul(data_,self.cov_inv[i]) * data_
+ tmp = np.sum(tmp,axis=1)
+ power = -0.5 * tmp
+
+ p = np.array([math.exp(power[j]) for j in range(N)])
+ p = p/self.factor[i]
+ lh += p*self.w[i]
+
+ return lh
+
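+# A hedged note on GMM.likelihood above (editor comment): for each component i it evaluates
+# the multivariate normal density
+#     N_i(x) = exp(-0.5 * (x - mu_i)^T cov_inv_i (x - mu_i)) / ((2*pi)^(dim/2) * sqrt(cov_det_i))
+# using the pre-computed inverse and determinant, and returns the mixture
+#     lh(x) = sum_i w_i * N_i(x).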
+
+def _rgb2ycbcr(rgb):
+ m = np.array([[65.481, 128.553, 24.966],
+ [-37.797, -74.203, 112],
+ [112, -93.786, -18.214]])
+ shape = rgb.shape
+ rgb = rgb.reshape((shape[0] * shape[1], 3))
+ ycbcr = np.dot(rgb, m.transpose() / 255.)
+ ycbcr[:, 0] += 16.
+ ycbcr[:, 1:] += 128.
+ return ycbcr.reshape(shape)
+
+
+def _bgr2ycbcr(bgr):
+ rgb = bgr[..., ::-1]
+ return _rgb2ycbcr(rgb)
+
+
+gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
+gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]),
+ np.array([150.19858, 105.18467, 155.51428]),
+ np.array([183.92976, 107.62468, 152.71820]),
+ np.array([114.90524, 113.59782, 151.38217])]
+gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
+gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),
+ np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),
+ np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),
+ np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]
+
+gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
+
+gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
+gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),
+ np.array([110.91392, 125.52969, 130.19237]),
+ np.array([129.75864, 129.96107, 126.96808]),
+ np.array([112.29587, 128.85121, 129.05431])]
+gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]
+gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),
+ np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),
+ np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),
+ np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]
+
+gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
+
+prior_skin = 0.8
+prior_nonskin = 1 - prior_skin
+
+
+# calculate skin attention mask
+def skinmask(imbgr):
+ im = _bgr2ycbcr(imbgr)
+
+ data = im.reshape((-1,3))
+
+ lh_skin = gmm_skin.likelihood(data)
+ lh_nonskin = gmm_nonskin.likelihood(data)
+
+ tmp1 = prior_skin * lh_skin
+ tmp2 = prior_nonskin * lh_nonskin
+ post_skin = tmp1 / (tmp1+tmp2) # posterior probability
+
+ post_skin = post_skin.reshape((im.shape[0],im.shape[1]))
+
+ post_skin = np.round(post_skin*255)
+ post_skin = post_skin.astype(np.uint8)
+ post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3
+
+ return post_skin
+
+
+def get_skin_mask(img_path):
+ print('generating skin masks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ save_path = os.path.join(img_path, 'mask')
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ img = cv2.imread(full_image_name).astype(np.float32)
+ skin_img = skinmask(img)
+ cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
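+
+# A hedged usage sketch (editor comment; the folder name is illustrative): get_skin_mask
+# walks a flat folder of jpg/png crops and writes per-pixel skin-probability maps
+# (0-255, replicated to 3 channels) into '<img_path>/mask'.
+#
+#   get_skin_mask('datasets/img')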
diff --git a/src/face3d/util/test_mean_face.txt b/src/face3d/util/test_mean_face.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3a46d4db7699ffed8f898fcee64099631509946d
--- /dev/null
+++ b/src/face3d/util/test_mean_face.txt
@@ -0,0 +1,136 @@
+-5.228591537475585938e+01
+2.078247070312500000e-01
+-5.064269638061523438e+01
+-1.315765380859375000e+01
+-4.952939224243164062e+01
+-2.592591094970703125e+01
+-4.793047332763671875e+01
+-3.832135772705078125e+01
+-4.512159729003906250e+01
+-5.059623336791992188e+01
+-3.917720794677734375e+01
+-6.043736648559570312e+01
+-2.929953765869140625e+01
+-6.861183166503906250e+01
+-1.719801330566406250e+01
+-7.572736358642578125e+01
+-1.961936950683593750e+00
+-7.862001037597656250e+01
+1.467941284179687500e+01
+-7.607844543457031250e+01
+2.744073486328125000e+01
+-6.915261840820312500e+01
+3.855677795410156250e+01
+-5.950350570678710938e+01
+4.478240966796875000e+01
+-4.867547225952148438e+01
+4.714337158203125000e+01
+-3.800830078125000000e+01
+4.940315246582031250e+01
+-2.496297454833984375e+01
+5.117234802246093750e+01
+-1.241538238525390625e+01
+5.190507507324218750e+01
+8.244247436523437500e-01
+-4.150688934326171875e+01
+2.386329650878906250e+01
+-3.570307159423828125e+01
+3.017010498046875000e+01
+-2.790358734130859375e+01
+3.212951660156250000e+01
+-1.941773223876953125e+01
+3.156523132324218750e+01
+-1.138106536865234375e+01
+2.841992187500000000e+01
+5.993263244628906250e+00
+2.895182800292968750e+01
+1.343590545654296875e+01
+3.189880371093750000e+01
+2.203153991699218750e+01
+3.302221679687500000e+01
+2.992478942871093750e+01
+3.099150085449218750e+01
+3.628388977050781250e+01
+2.765748596191406250e+01
+-1.933914184570312500e+00
+1.405374145507812500e+01
+-2.153038024902343750e+00
+5.772636413574218750e+00
+-2.270050048828125000e+00
+-2.121643066406250000e+00
+-2.218330383300781250e+00
+-1.068978118896484375e+01
+-1.187252044677734375e+01
+-1.997912597656250000e+01
+-6.879402160644531250e+00
+-2.143579864501953125e+01
+-1.227821350097656250e+00
+-2.193494415283203125e+01
+4.623237609863281250e+00
+-2.152721405029296875e+01
+9.721397399902343750e+00
+-1.953671264648437500e+01
+-3.648714447021484375e+01
+9.811126708984375000e+00
+-3.130242919921875000e+01
+1.422447967529296875e+01
+-2.212834930419921875e+01
+1.493019866943359375e+01
+-1.500880432128906250e+01
+1.073588562011718750e+01
+-2.095037078857421875e+01
+9.054298400878906250e+00
+-3.050099182128906250e+01
+8.704177856445312500e+00
+1.173237609863281250e+01
+1.054329681396484375e+01
+1.856353759765625000e+01
+1.535009765625000000e+01
+2.893331909179687500e+01
+1.451992797851562500e+01
+3.452944946289062500e+01
+1.065280151367187500e+01
+2.875990295410156250e+01
+8.654792785644531250e+00
+1.942100524902343750e+01
+9.422447204589843750e+00
+-2.204488372802734375e+01
+-3.983994293212890625e+01
+-1.324458312988281250e+01
+-3.467377471923828125e+01
+-6.749649047851562500e+00
+-3.092894744873046875e+01
+-9.183349609375000000e-01
+-3.196458435058593750e+01
+4.220649719238281250e+00
+-3.090406036376953125e+01
+1.089889526367187500e+01
+-3.497008514404296875e+01
+1.874589538574218750e+01
+-4.065438079833984375e+01
+1.124106597900390625e+01
+-4.438417816162109375e+01
+5.181709289550781250e+00
+-4.649170684814453125e+01
+-1.158607482910156250e+00
+-4.680406951904296875e+01
+-7.918922424316406250e+00
+-4.671575164794921875e+01
+-1.452505493164062500e+01
+-4.416526031494140625e+01
+-2.005007171630859375e+01
+-3.997841644287109375e+01
+-1.054919433593750000e+01
+-3.849683380126953125e+01
+-1.051826477050781250e+00
+-3.794863128662109375e+01
+6.412681579589843750e+00
+-3.804645538330078125e+01
+1.627674865722656250e+01
+-4.039697265625000000e+01
+6.373878479003906250e+00
+-4.087213897705078125e+01
+-8.551712036132812500e-01
+-4.157129669189453125e+01
+-1.014953613281250000e+01
+-4.128469085693359375e+01
diff --git a/src/face3d/util/util.py b/src/face3d/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d689ca138fc0fbf5bec794511ea0f9e638f9ea9
--- /dev/null
+++ b/src/face3d/util/util.py
@@ -0,0 +1,208 @@
+"""This script contains basic utilities for Deep3DFaceRecon_pytorch
+"""
+from __future__ import print_function
+import numpy as np
+import torch
+from PIL import Image
+import os
+import importlib
+import argparse
+from argparse import Namespace
+import torchvision
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def copyconf(default_opt, **kwargs):
+ conf = Namespace(**vars(default_opt))
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+ return conf
+
+def genvalconf(train_opt, **kwargs):
+ conf = Namespace(**vars(train_opt))
+ attr_dict = train_opt.__dict__
+ for key, value in attr_dict.items():
+ if 'val' in key and key.split('_')[0] in attr_dict:
+ setattr(conf, key.split('_')[0], value)
+
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+
+ return conf
+
+def find_class_in_module(target_cls_name, module):
+ target_cls_name = target_cls_name.replace('_', '').lower()
+ clslib = importlib.import_module(module)
+ cls = None
+ for name, clsobj in clslib.__dict__.items():
+ if name.lower() == target_cls_name:
+ cls = clsobj
+
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
+
+ return cls
+
+
+def tensor2im(input_image, imtype=np.uint8):
+ """"Converts a Tensor array into a numpy image array.
+
+ Parameters:
+ input_image (tensor) -- the input image tensor array, range(0, 1)
+ imtype (type) -- the desired type of the converted numpy array
+ """
+ if not isinstance(input_image, np.ndarray):
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
+ image_tensor = input_image.data
+ else:
+ return input_image
+ image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array
+ if image_numpy.shape[0] == 1: # grayscale to RGB
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling
+ else: # if it is a numpy array, do nothing
+ image_numpy = input_image
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ """Calculate and print the mean of average absolute(gradients)
+
+ Parameters:
+ net (torch network) -- Torch network
+ name (str) -- the name of the network
+ """
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+ """Save a numpy image to the disk
+
+ Parameters:
+ image_numpy (numpy array) -- input numpy array
+ image_path (str) -- the path of the image
+ """
+
+ image_pil = Image.fromarray(image_numpy)
+ h, w, _ = image_numpy.shape
+
+ if aspect_ratio is None:
+ pass
+ elif aspect_ratio > 1.0:
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+ elif aspect_ratio < 1.0:
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ """Print the mean, min, max, median, std, and size of a numpy array
+
+ Parameters:
+ val (bool) -- if print the values of the numpy array
+ shp (bool) -- if print the shape of the numpy array
+ """
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ """create empty directories if they don't exist
+
+ Parameters:
+ paths (str list) -- a list of directory paths
+ """
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ """create a single empty directory if it didn't exist
+
+ Parameters:
+ path (str) -- a single directory path
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def correct_resize_label(t, size):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i, :1]
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
+ one_np = one_np[:, :, 0]
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
+ resized_t = torch.from_numpy(np.array(one_image)).long()
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+
+def correct_resize(t, size, mode=Image.BICUBIC):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i:i + 1]
+        one_image = Image.fromarray(tensor2im(one_t)).resize(size, mode)
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+def draw_landmarks(img, landmark, color='r', step=2):
+ """
+ Return:
+ img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
+
+
+ Parameters:
+ img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
+ landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
+ color -- str, 'r' or 'b' (red or blue)
+ """
+ if color =='r':
+ c = np.array([255., 0, 0])
+ else:
+ c = np.array([0, 0, 255.])
+
+ _, H, W, _ = img.shape
+ img, landmark = img.copy(), landmark.copy()
+ landmark[..., 1] = H - 1 - landmark[..., 1]
+ landmark = np.round(landmark).astype(np.int32)
+ for i in range(landmark.shape[1]):
+ x, y = landmark[:, i, 0], landmark[:, i, 1]
+ for j in range(-step, step):
+ for k in range(-step, step):
+ u = np.clip(x + j, 0, W - 1)
+ v = np.clip(y + k, 0, H - 1)
+ for m in range(landmark.shape[0]):
+ img[m, v[m], u[m]] = c
+ return img
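+
+
+# A minimal, hedged usage sketch (editor addition, not upstream code): tensor2im expects a
+# (C, H, W) float tensor in [0, 1] and returns an (H, W, C) uint8 array that save_image
+# writes to disk. 'tensor2im_demo.png' is only an illustrative output path.
+if __name__ == '__main__':
+    _fake = torch.rand(3, 224, 224)                      # stand-in for a network output in [0, 1]
+    save_image(tensor2im(_fake), 'tensor2im_demo.png')   # writes a 224x224 RGB image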
diff --git a/src/face3d/util/visualizer.py b/src/face3d/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4023a6d4086acba9bc88e079f625194d324d7c9e
--- /dev/null
+++ b/src/face3d/util/visualizer.py
@@ -0,0 +1,227 @@
+"""This script defines the visualizer for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+from torch.utils.tensorboard import SummaryWriter
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+ """Save images to the disk.
+
+ Parameters:
+        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+ image_path (str) -- the string is used to create image paths
+ aspect_ratio (float) -- the aspect ratio of saved images
+ width (int) -- the images will be resized to width x width
+
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+ """
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ image_name = '%s/%s.png' % (label, name)
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+ """This class includes several functions that can display/save images and print/save logging information.
+
+    It uses the 'torch.utils.tensorboard' SummaryWriter for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+ """
+
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+        Step 3: create an HTML object for saving HTML files
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the option
+ self.use_html = opt.isTrain and not opt.no_html
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ self.saved = False
+ if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ def reset(self):
+ """Reset the self.saved status"""
+ self.saved = False
+
+
+ def display_current_results(self, visuals, total_iters, epoch, save_result):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ save_result (bool) - - if save the current results to an HTML file
+ """
+ for label, image in visuals.items():
+ self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
+
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
+ self.saved = True
+ # save images to the disk
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims, txts, links = [], [], []
+
+ for label, image_numpy in visuals.items():
+                    image_numpy = util.tensor2im(image_numpy)
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+
+ def plot_current_losses(self, total_iters, losses):
+ # G_loss_collection = {}
+ # D_loss_collection = {}
+ # for name, value in losses.items():
+ # if 'G' in name or 'NCE' in name or 'idt' in name:
+ # G_loss_collection[name] = value
+ # else:
+ # D_loss_collection[name] = value
+ # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
+ # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
+ for name, value in losses.items():
+ self.writer.add_scalar(name, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
+
+
+class MyVisualizer:
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+        Step 3: create an HTML object for saving HTML files
+ Step 4: create a logging file to store training losses
+ """
+        self.opt = opt  # cache the option
+ self.name = opt.name
+ self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
+
+ if opt.phase != 'test':
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+
+ def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
+ add_image=True):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ dataset (str) - - 'train' or 'val' or 'test'
+ """
+ # if (not add_image) and (not save_results): return
+
+ for label, image in visuals.items():
+ for i in range(image.shape[0]):
+ image_numpy = util.tensor2im(image[i])
+ if add_image:
+ self.writer.add_image(label + '%s_%02d'%(dataset, i + count),
+ image_numpy, total_iters, dataformats='HWC')
+
+ if save_results:
+ save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ if name is not None:
+ img_path = os.path.join(save_path, '%s.png' % name)
+ else:
+ img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))
+ util.save_image(image_numpy, img_path)
+
+
+ def plot_current_losses(self, total_iters, losses, dataset='train'):
+ for name, value in losses.items():
+ self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (
+ dataset, epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
diff --git a/src/face3d/visualize.py b/src/face3d/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..23a1110806a0ddf37d4aa549c023d1c3f7114e3e
--- /dev/null
+++ b/src/face3d/visualize.py
@@ -0,0 +1,48 @@
+# check the sync of 3dmm feature and the audio
+import cv2
+import numpy as np
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.facerecon_model import FaceReconModel
+import torch
+import subprocess, platform
+import scipy.io as scio
+from tqdm import tqdm
+
+# draft
+def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64):
+
+ coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
+
+ coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
+
+    coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0)   # (num_frames, 257)
+
+    coeff_full[:, 80:144] = coeff_pred[:, 0:64]     # 64 dim expression
+    coeff_full[:, 224:227] = coeff_pred[:, 64:67]   # 3 dim head rotation (euler angles)
+    coeff_full[:, 254:] = coeff_pred[:, 67:]        # 3 dim translation
+
+ tmp_video_path = '/tmp/face3dtmp.mp4'
+
+ facemodel = FaceReconModel(args)
+
+ video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
+
+ for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
+ cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ rendered_img = facemodel.pred_face
+ rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
+ out_img = rendered_img[:, :, :3].astype(np.uint8)
+
+ video.write(np.uint8(out_img[:,:,::-1]))
+
+ video.release()
+
+ command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
+ subprocess.call(command, shell=platform.system() != 'Windows')
+
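+# A hedged usage sketch (editor comment): the .mat paths below are hypothetical outputs of
+# the 3DMM extraction and audio-to-coefficient stages, and ffmpeg must be on PATH because
+# gen_composed_video muxes the rendered frames with the driving audio.
+#
+#   gen_composed_video(args, 'cuda:0',
+#                      'results/first_frame_coeff.mat',   # contains key 'full_3dmm'
+#                      'results/pred_coeff.mat',          # contains key 'coeff_3dmm'
+#                      'driven_audio.wav',
+#                      'results/face3d_vis.mp4')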
diff --git a/src/facerender/animate.py b/src/facerender/animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..781f5a3318a086049cc6b74393073ddda7001d5e
--- /dev/null
+++ b/src/facerender/animate.py
@@ -0,0 +1,257 @@
+import os
+import cv2
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+from src.facerender.modules.make_animation import make_animation
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+class AnimateFromCoeff():
+
+ def __init__(self, sadtalker_path, device):
+
+ with open(sadtalker_path['facerender_yaml']) as f:
+ config = yaml.safe_load(f)
+
+ generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+ kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+ **config['model_params']['common_params'])
+ he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+ **config['model_params']['common_params'])
+ mapping = MappingNet(**config['model_params']['mapping_params'])
+
+ generator.to(device)
+ kp_extractor.to(device)
+ he_estimator.to(device)
+ mapping.to(device)
+ for param in generator.parameters():
+ param.requires_grad = False
+ for param in kp_extractor.parameters():
+ param.requires_grad = False
+ for param in he_estimator.parameters():
+ param.requires_grad = False
+ for param in mapping.parameters():
+ param.requires_grad = False
+
+ if sadtalker_path is not None:
+ if 'checkpoint' in sadtalker_path: # use safe tensor
+ self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
+ else:
+ self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+ else:
+ raise AttributeError("Checkpoint should be specified for video head pose estimator.")
+
+ if sadtalker_path['mappingnet_checkpoint'] is not None:
+ self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
+ else:
+ raise AttributeError("Checkpoint should be specified for video head pose estimator.")
+
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.he_estimator = he_estimator
+ self.mapping = mapping
+
+ self.kp_extractor.eval()
+ self.generator.eval()
+ self.he_estimator.eval()
+ self.mapping.eval()
+
+ self.device = device
+
+ def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
+ def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
+ optimizer_discriminator=None, optimizer_kp_detector=None,
+ optimizer_he_estimator=None, device="cpu"):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if generator is not None:
+ generator.load_state_dict(checkpoint['generator'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if he_estimator is not None:
+ he_estimator.load_state_dict(checkpoint['he_estimator'])
+ if discriminator is not None:
+ try:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ except:
+                print('No discriminator in the state-dict. Discriminator will be randomly initialized')
+ if optimizer_generator is not None:
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+ if optimizer_discriminator is not None:
+ try:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+ except RuntimeError as e:
+                print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
+ if optimizer_kp_detector is not None:
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+ if optimizer_he_estimator is not None:
+ optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+ return checkpoint['epoch']
+
+ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if mapping is not None:
+ mapping.load_state_dict(checkpoint['mapping'])
+ if discriminator is not None:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ if optimizer_mapping is not None:
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
+ if optimizer_discriminator is not None:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+
+ return checkpoint['epoch']
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+        # keep the float cast before moving to the device (previously the .to() call
+        # re-read the raw tensor and silently dropped the dtype conversion)
+        if 'yaw_c_seq' in x:
+            yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor).to(self.device)
+        else:
+            yaw_c_seq = None
+        if 'pitch_c_seq' in x:
+            pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor).to(self.device)
+        else:
+            pitch_c_seq = None
+        if 'roll_c_seq' in x:
+            roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor).to(self.device)
+        else:
+            roll_c_seq = None
+
+ frame_num = x['frame_num']
+
+ predictions_video = make_animation(source_image, source_semantics, target_semantics,
+ self.generator, self.kp_extractor, self.he_estimator, self.mapping,
+ yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
+
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+ predictions_video = predictions_video[:frame_num]
+
+ video = []
+ for idx in range(predictions_video.shape[0]):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+        ### the generated video is 256x256, so we resize it to restore the original aspect ratio
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+        #### paste back, then run the face/background enhancers
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+ os.remove(enhanced_path)
+
+ os.remove(path)
+ os.remove(new_audio_path)
+
+ return return_path
+
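+# A hedged sketch of the inputs generate() consumes (editor comment; the key names are
+# taken from the code above, the values and shapes are assumptions for illustration only):
+#
+#   animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+#   x = {
+#       'source_image': src_img_tensor,             # (B, 3, H, W) float tensor of the cropped source face
+#       'source_semantics': src_coeff_tensor,       # 3DMM coefficient window of the source frame
+#       'target_semantics_list': tgt_coeff_tensor,  # per-frame 3DMM coefficient windows
+#       'frame_num': num_frames,                    # number of valid frames to keep
+#       'video_name': 'demo',
+#       'audio_path': 'driven_audio.wav',
+#       # optional: 'yaw_c_seq', 'pitch_c_seq', 'roll_c_seq' for user-controlled head pose
+#   }
+#   result_path = animate_from_coeff.generate(x, save_dir, pic_path, crop_info,
+#                                             enhancer=None, preprocess='crop', img_size=256)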
diff --git a/src/facerender/modules/dense_motion.py b/src/facerender/modules/dense_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..a286ead2e84ed1961335d34a3b50ab38f25e4495
--- /dev/null
+++ b/src/facerender/modules/dense_motion.py
@@ -0,0 +1,121 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+
+class DenseMotionNetwork(nn.Module):
+ """
+    Module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
+ """
+
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
+ estimate_occlusion_map=False):
+ super(DenseMotionNetwork, self).__init__()
+ # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)
+
+ self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)
+
+ self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
+ self.norm = BatchNorm3d(compress, affine=True)
+
+ if estimate_occlusion_map:
+ # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
+ else:
+ self.occlusion = None
+
+ self.num_kp = num_kp
+
+
+ def create_sparse_motions(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape
+ identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
+ identity_grid = identity_grid.view(1, 1, d, h, w, 3)
+ coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)
+
+ # if 'jacobian' in kp_driving:
+ if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
+ jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
+ jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
+ jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
+ coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
+ coordinate_grid = coordinate_grid.squeeze(-1)
+
+
+ driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
+
+ #adding background feature
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3
+
+ # sparse_motions = driving_to_source
+
+ return sparse_motions
+
+ def create_deformed_feature(self, feature, sparse_motions):
+ bs, _, d, h, w = feature.shape
+ feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
+ feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
+ sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!!
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions)
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
+ return sparse_deformed
+
+ def create_heatmap_representations(self, feature, kp_driving, kp_source):
+ spatial_size = feature.shape[3:]
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
+ heatmap = gaussian_driving - gaussian_source
+
+ # adding background feature
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
+ heatmap = torch.cat([zeros, heatmap], dim=1)
+ heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
+ return heatmap
+
+ def forward(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape
+
+ feature = self.compress(feature)
+ feature = self.norm(feature)
+ feature = F.relu(feature)
+
+ out_dict = dict()
+ sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
+ deformed_feature = self.create_deformed_feature(feature, sparse_motion)
+
+ heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)
+
+ input_ = torch.cat([heatmap, deformed_feature], dim=2)
+ input_ = input_.view(bs, -1, d, h, w)
+
+ # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w)
+
+ prediction = self.hourglass(input_)
+
+
+ mask = self.mask(prediction)
+ mask = F.softmax(mask, dim=1)
+ out_dict['mask'] = mask
+ mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
+
+ zeros_mask = torch.zeros_like(mask)
+ mask = torch.where(mask < 1e-3, zeros_mask, mask)
+
+ sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
+ deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w)
+ deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
+
+ out_dict['deformation'] = deformation
+
+ if self.occlusion:
+ bs, c, d, h, w = prediction.shape
+ prediction = prediction.view(bs, -1, h, w)
+ occlusion_map = torch.sigmoid(self.occlusion(prediction))
+ out_dict['occlusion_map'] = occlusion_map
+
+ return out_dict
diff --git a/src/facerender/modules/discriminator.py b/src/facerender/modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4459b07cb075c9f9d345f9b3dffc02cd859313b
--- /dev/null
+++ b/src/facerender/modules/discriminator.py
@@ -0,0 +1,90 @@
+from torch import nn
+import torch.nn.functional as F
+from src.facerender.modules.util import kp2gaussian
+import torch
+
+
+class DownBlock2d(nn.Module):
+ """
+ Simple block for processing video (encoder).
+ """
+
+ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
+
+ if sn:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ if norm:
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
+ else:
+ self.norm = None
+ self.pool = pool
+
+ def forward(self, x):
+ out = x
+ out = self.conv(out)
+ if self.norm:
+ out = self.norm(out)
+ out = F.leaky_relu(out, 0.2)
+ if self.pool:
+ out = F.avg_pool2d(out, (2, 2))
+ return out
+
+
+class Discriminator(nn.Module):
+ """
+ Discriminator similar to Pix2Pix
+ """
+
+ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
+ sn=False, **kwargs):
+ super(Discriminator, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(
+ DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
+ if sn:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ def forward(self, x):
+ feature_maps = []
+ out = x
+
+ for down_block in self.down_blocks:
+ feature_maps.append(down_block(out))
+ out = feature_maps[-1]
+ prediction_map = self.conv(out)
+
+ return feature_maps, prediction_map
+
+
+class MultiScaleDiscriminator(nn.Module):
+ """
+    Multi-scale discriminator: one Discriminator instance is run per requested scale.
+ """
+
+ def __init__(self, scales=(), **kwargs):
+ super(MultiScaleDiscriminator, self).__init__()
+ self.scales = scales
+ discs = {}
+ for scale in scales:
+ discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
+ self.discs = nn.ModuleDict(discs)
+
+ def forward(self, x):
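+        # x is a dict of images keyed by 'prediction_<scale>' (the module names in
+        # self.discs encode the same scales with '.' replaced by '-')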
+ out_dict = {}
+ for scale, disc in self.discs.items():
+ scale = str(scale).replace('-', '.')
+ key = 'prediction_' + scale
+ feature_maps, prediction_map = disc(x[key])
+ out_dict['feature_maps_' + scale] = feature_maps
+ out_dict['prediction_map_' + scale] = prediction_map
+ return out_dict
diff --git a/src/facerender/modules/generator.py b/src/facerender/modules/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a9edcb3b328d3afc99072b2461d7ca69919f813
--- /dev/null
+++ b/src/facerender/modules/generator.py
@@ -0,0 +1,255 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock
+from src.facerender.modules.dense_motion import DenseMotionNetwork
+
+
+class OcclusionAwareGenerator(nn.Module):
+ """
+ Generator follows NVIDIA architecture.
+ """
+
+ def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+ num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+ super(OcclusionAwareGenerator, self).__init__()
+
+ if dense_motion_params is not None:
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+ estimate_occlusion_map=estimate_occlusion_map,
+ **dense_motion_params)
+ else:
+ self.dense_motion_network = None
+
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+ self.reshape_channel = reshape_channel
+ self.reshape_depth = reshape_depth
+
+ self.resblocks_3d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+ out_features = block_expansion * (2 ** (num_down_blocks))
+ self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+ self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+ self.resblocks_2d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))
+
+ up_blocks = []
+ for i in range(num_down_blocks):
+ in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
+ out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
+ up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
+ self.estimate_occlusion_map = estimate_occlusion_map
+ self.image_channel = image_channel
+
+ def deform_input(self, inp, deformation):
+ _, d_old, h_old, w_old, _ = deformation.shape
+ _, _, d, h, w = inp.shape
+ if d_old != d or h_old != h or w_old != w:
+ deformation = deformation.permute(0, 4, 1, 2, 3)
+ deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+ deformation = deformation.permute(0, 2, 3, 4, 1)
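+        # `deformation` is an absolute sampling grid in normalized [-1, 1] coordinates,
+        # so it can be passed to grid_sample directly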
+ return F.grid_sample(inp, deformation)
+
+ def forward(self, source_image, kp_driving, kp_source):
+ # Encoding (downsampling) part
+ out = self.first(source_image)
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ out = self.second(out)
+ bs, c, h, w = out.shape
+ # print(out.shape)
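+        # lift the 2D feature map into a 3D feature volume:
+        # (bs, reshape_channel * reshape_depth, h, w) -> (bs, reshape_channel, reshape_depth, h, w)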
+        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
+ feature_3d = self.resblocks_3d(feature_3d)
+
+ # Transforming feature representation according to deformation and occlusion
+ output_dict = {}
+ if self.dense_motion_network is not None:
+ dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+ kp_source=kp_source)
+ output_dict['mask'] = dense_motion['mask']
+
+ if 'occlusion_map' in dense_motion:
+ occlusion_map = dense_motion['occlusion_map']
+ output_dict['occlusion_map'] = occlusion_map
+ else:
+ occlusion_map = None
+ deformation = dense_motion['deformation']
+ out = self.deform_input(feature_3d, deformation)
+
+ bs, c, d, h, w = out.shape
+ out = out.view(bs, c*d, h, w)
+ out = self.third(out)
+ out = self.fourth(out)
+
+ if occlusion_map is not None:
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+ out = out * occlusion_map
+
+ # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image
+
+ # Decoding part
+ out = self.resblocks_2d(out)
+ for i in range(len(self.up_blocks)):
+ out = self.up_blocks[i](out)
+ out = self.final(out)
+        out = torch.sigmoid(out)
+
+ output_dict["prediction"] = out
+
+ return output_dict
+
+
+class SPADEDecoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ ic = 256
+ oc = 64
+ norm_G = 'spadespectralinstance'
+ label_nc = 256
+
+ self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
+ self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
+ self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
+ self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
+ self.up = nn.Upsample(scale_factor=2)
+
+ def forward(self, feature):
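+        # shape walk-through, assuming a (bs, 256, 64, 64) input feature map (e.g. a
+        # 256x256 source image with block_expansion=64 and num_down_blocks=2 in the
+        # generator below):
+        #   fc            -> (bs, 512, 64, 64)
+        #   G_middle_0..5 -> (bs, 512, 64, 64)
+        #   up + up_0     -> (bs, 256, 128, 128)
+        #   up + up_1     -> (bs,  64, 256, 256)
+        #   conv_img      -> (bs,   3, 256, 256), mapped to [0, 1] by the sigmoid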
+ seg = feature
+ x = self.fc(feature)
+ x = self.G_middle_0(x, seg)
+ x = self.G_middle_1(x, seg)
+ x = self.G_middle_2(x, seg)
+ x = self.G_middle_3(x, seg)
+ x = self.G_middle_4(x, seg)
+ x = self.G_middle_5(x, seg)
+ x = self.up(x)
+ x = self.up_0(x, seg) # 256, 128, 128
+ x = self.up(x)
+ x = self.up_1(x, seg) # 64, 256, 256
+
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
+ # x = torch.tanh(x)
+        x = torch.sigmoid(x)
+
+ return x
+
+
+class OcclusionAwareSPADEGenerator(nn.Module):
+
+ def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+ num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+ super(OcclusionAwareSPADEGenerator, self).__init__()
+
+ if dense_motion_params is not None:
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+ estimate_occlusion_map=estimate_occlusion_map,
+ **dense_motion_params)
+ else:
+ self.dense_motion_network = None
+
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
+
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+ self.reshape_channel = reshape_channel
+ self.reshape_depth = reshape_depth
+
+ self.resblocks_3d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+ out_features = block_expansion * (2 ** (num_down_blocks))
+ self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+ self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+ self.estimate_occlusion_map = estimate_occlusion_map
+ self.image_channel = image_channel
+
+ self.decoder = SPADEDecoder()
+
+ def deform_input(self, inp, deformation):
+ _, d_old, h_old, w_old, _ = deformation.shape
+ _, _, d, h, w = inp.shape
+ if d_old != d or h_old != h or w_old != w:
+ deformation = deformation.permute(0, 4, 1, 2, 3)
+ deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+ deformation = deformation.permute(0, 2, 3, 4, 1)
+ return F.grid_sample(inp, deformation)
+
+ def forward(self, source_image, kp_driving, kp_source):
+ # Encoding (downsampling) part
+ out = self.first(source_image)
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ out = self.second(out)
+ bs, c, h, w = out.shape
+ # print(out.shape)
+        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
+ feature_3d = self.resblocks_3d(feature_3d)
+
+ # Transforming feature representation according to deformation and occlusion
+ output_dict = {}
+ if self.dense_motion_network is not None:
+ dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+ kp_source=kp_source)
+ output_dict['mask'] = dense_motion['mask']
+
+ # import pdb; pdb.set_trace()
+
+ if 'occlusion_map' in dense_motion:
+ occlusion_map = dense_motion['occlusion_map']
+ output_dict['occlusion_map'] = occlusion_map
+ else:
+ occlusion_map = None
+ deformation = dense_motion['deformation']
+ out = self.deform_input(feature_3d, deformation)
+
+ bs, c, d, h, w = out.shape
+ out = out.view(bs, c*d, h, w)
+ out = self.third(out)
+ out = self.fourth(out)
+
+ # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map)
+
+ if occlusion_map is not None:
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+ out = out * occlusion_map
+
+ # Decoding part
+ out = self.decoder(out)
+
+ output_dict["prediction"] = out
+
+ return output_dict
\ No newline at end of file
diff --git a/src/facerender/modules/keypoint_detector.py b/src/facerender/modules/keypoint_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..62a38a962b2f1a4326aac771aced353ec5e22a96
--- /dev/null
+++ b/src/facerender/modules/keypoint_detector.py
@@ -0,0 +1,179 @@
+from torch import nn
+import torch
+import torch.nn.functional as F
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck
+
+
+class KPDetector(nn.Module):
+ """
+ Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint.
+ """
+
+ def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,
+ num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):
+ super(KPDetector, self).__init__()
+
+ self.predictor = KPHourglass(block_expansion, in_features=image_channel,
+ max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)
+
+ # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)
+ self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)
+
+ if estimate_jacobian:
+ self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
+ # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)
+ self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)
+ '''
+ initial as:
+ [[1 0 0]
+ [0 1 0]
+ [0 0 1]]
+ '''
+ self.jacobian.weight.data.zero_()
+ self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
+ else:
+ self.jacobian = None
+
+ self.temperature = temperature
+ self.scale_factor = scale_factor
+ if self.scale_factor != 1:
+ self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)
+
+ def gaussian2kp(self, heatmap):
+ """
+ Extract the mean from a heatmap
+ """
+ shape = heatmap.shape
+ heatmap = heatmap.unsqueeze(-1)
+ grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
+ value = (heatmap * grid).sum(dim=(2, 3, 4))
+ kp = {'value': value}
+
+ return kp
+
+ def forward(self, x):
+ if self.scale_factor != 1:
+ x = self.down(x)
+
+ feature_map = self.predictor(x)
+ prediction = self.kp(feature_map)
+
+ final_shape = prediction.shape
+ heatmap = prediction.view(final_shape[0], final_shape[1], -1)
+ heatmap = F.softmax(heatmap / self.temperature, dim=2)
+ heatmap = heatmap.view(*final_shape)
+
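+        # soft-argmax: the temperature-scaled softmax heatmap is reduced to keypoint
+        # coordinates by taking its expectation over a [-1, 1] coordinate grid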
+ out = self.gaussian2kp(heatmap)
+
+ if self.jacobian is not None:
+ jacobian_map = self.jacobian(feature_map)
+ jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],
+ final_shape[3], final_shape[4])
+ heatmap = heatmap.unsqueeze(2)
+
+ jacobian = heatmap * jacobian_map
+ jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)
+ jacobian = jacobian.sum(dim=-1)
+ jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)
+ out['jacobian'] = jacobian
+
+ return out
+
+
+class HEEstimator(nn.Module):
+ """
+ Estimating head pose and expression.
+ """
+
+ def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
+ super(HEEstimator, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
+ self.norm1 = BatchNorm2d(block_expansion, affine=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
+ self.norm2 = BatchNorm2d(256, affine=True)
+
+ self.block1 = nn.Sequential()
+ for i in range(3):
+ self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))
+
+ self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
+ self.norm3 = BatchNorm2d(512, affine=True)
+ self.block2 = ResBottleneck(in_features=512, stride=2)
+
+ self.block3 = nn.Sequential()
+ for i in range(3):
+ self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))
+
+ self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
+ self.norm4 = BatchNorm2d(1024, affine=True)
+ self.block4 = ResBottleneck(in_features=1024, stride=2)
+
+ self.block5 = nn.Sequential()
+ for i in range(5):
+ self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))
+
+ self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
+ self.norm5 = BatchNorm2d(2048, affine=True)
+ self.block6 = ResBottleneck(in_features=2048, stride=2)
+
+ self.block7 = nn.Sequential()
+ for i in range(2):
+ self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))
+
+ self.fc_roll = nn.Linear(2048, num_bins)
+ self.fc_pitch = nn.Linear(2048, num_bins)
+ self.fc_yaw = nn.Linear(2048, num_bins)
+
+ self.fc_t = nn.Linear(2048, 3)
+
+ self.fc_exp = nn.Linear(2048, 3*num_kp)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = F.relu(out)
+ out = self.maxpool(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+
+ out = self.block1(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+ out = F.relu(out)
+ out = self.block2(out)
+
+ out = self.block3(out)
+
+ out = self.conv4(out)
+ out = self.norm4(out)
+ out = F.relu(out)
+ out = self.block4(out)
+
+ out = self.block5(out)
+
+ out = self.conv5(out)
+ out = self.norm5(out)
+ out = F.relu(out)
+ out = self.block6(out)
+
+ out = self.block7(out)
+
+ out = F.adaptive_avg_pool2d(out, 1)
+ out = out.view(out.shape[0], -1)
+
+ yaw = self.fc_roll(out)
+ pitch = self.fc_pitch(out)
+ roll = self.fc_yaw(out)
+ t = self.fc_t(out)
+ exp = self.fc_exp(out)
+
+ return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
+
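+
+if __name__ == "__main__":
+    # Minimal smoke test, a sketch only: the hyper-parameters below are small,
+    # hypothetical values chosen so the module runs quickly on CPU; they are not
+    # the values used by any released config. Run from the repository root, e.g.
+    #   python -m src.facerender.modules.keypoint_detector
+    detector = KPDetector(block_expansion=8, feature_channel=4, num_kp=5,
+                          image_channel=3, max_features=64, reshape_channel=256,
+                          reshape_depth=4, num_blocks=3, temperature=0.1,
+                          estimate_jacobian=False, scale_factor=0.25)
+    kp = detector(torch.randn(2, 3, 64, 64))
+    print(kp['value'].shape)  # canonical keypoints, expected shape (2, 5, 3)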
diff --git a/src/facerender/modules/make_animation.py b/src/facerender/modules/make_animation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3360c53501a064f35d7db21a5361f89aa9658b42
--- /dev/null
+++ b/src/facerender/modules/make_animation.py
@@ -0,0 +1,170 @@
+from scipy.spatial import ConvexHull
+import torch
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+
+def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
+ use_relative_movement=False, use_relative_jacobian=False):
+ if adapt_movement_scale:
+ source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
+ driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+ else:
+ adapt_movement_scale = 1
+
+ kp_new = {k: v for k, v in kp_driving.items()}
+
+ if use_relative_movement:
+ kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+ kp_value_diff *= adapt_movement_scale
+ kp_new['value'] = kp_value_diff + kp_source['value']
+
+ if use_relative_jacobian:
+ jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+ kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+
+ return kp_new
+
+def headpose_pred_to_degree(pred):
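+    # the pose head outputs logits over 66 angle bins; the expected bin index under
+    # the softmax is mapped to degrees: degree = (sum_i p_i * i) * 3 - 99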
+ device = pred.device
+ idx_tensor = [idx for idx in range(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)
+    pred = F.softmax(pred, dim=1)
+ degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+ return degree
+
+def get_rotation_matrix(yaw, pitch, roll):
+ yaw = yaw / 180 * 3.14
+ pitch = pitch / 180 * 3.14
+ roll = roll / 180 * 3.14
+
+ roll = roll.unsqueeze(1)
+ pitch = pitch.unsqueeze(1)
+ yaw = yaw.unsqueeze(1)
+
+ pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
+ torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
+ torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
+ pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+ yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),
+ torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
+ -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
+ yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+ roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),
+ torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
+ torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
+ roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
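+    # compose the full rotation as R = R_pitch @ R_yaw @ R_roll (batched einsum)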
+ rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
+
+ return rot_mat
+
+def keypoint_transformation(kp_canonical, he, wo_exp=False):
+ kp = kp_canonical['value'] # (bs, k, 3)
+    yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
+ yaw = headpose_pred_to_degree(yaw)
+ pitch = headpose_pred_to_degree(pitch)
+ roll = headpose_pred_to_degree(roll)
+
+ if 'yaw_in' in he:
+ yaw = he['yaw_in']
+ if 'pitch_in' in he:
+ pitch = he['pitch_in']
+ if 'roll_in' in he:
+ roll = he['roll_in']
+
+ rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
+
+ t, exp = he['t'], he['exp']
+ if wo_exp:
+ exp = exp*0
+
+ # keypoint rotation
+ kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+ # keypoint translation
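+    # only one component of the predicted translation is applied; the x- and
+    # z-components (indices 0 and 2) are zeroed out below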
+ t[:, 0] = t[:, 0]*0
+ t[:, 2] = t[:, 2]*0
+ t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
+ kp_t = kp_rotated + t
+
+ # add expression deviation
+ exp = exp.view(exp.shape[0], -1, 3)
+ kp_transformed = kp_t + exp
+
+ return {'value': kp_transformed}
+
+
+
+def make_animation(source_image, source_semantics, target_semantics,
+ generator, kp_detector, he_estimator, mapping,
+ yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+ use_exp=True, use_half=False):
+ with torch.no_grad():
+ predictions = []
+
+ kp_canonical = kp_detector(source_image)
+ he_source = mapping(source_semantics)
+ kp_source = keypoint_transformation(kp_canonical, he_source)
+
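+        # canonical keypoints are extracted once from the source image; every frame
+        # then re-poses them with the head pose / expression predicted by `mapping`
+        # from that frame's target semantics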
+ for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
+            # target_semantics carries a per-frame dimension at index 1,
+            # unlike source_semantics which describes a single frame
+            # print(target_semantics.shape, source_semantics.shape)
+ target_semantics_frame = target_semantics[:, frame_idx]
+ he_driving = mapping(target_semantics_frame)
+ if yaw_c_seq is not None:
+ he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+ if pitch_c_seq is not None:
+ he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
+ if roll_c_seq is not None:
+ he_driving['roll_in'] = roll_c_seq[:, frame_idx]
+
+ kp_driving = keypoint_transformation(kp_canonical, he_driving)
+
+ kp_norm = kp_driving
+ out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+ '''
+ source_image_new = out['prediction'].squeeze(1)
+ kp_canonical_new = kp_detector(source_image_new)
+ he_source_new = he_estimator(source_image_new)
+ kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
+ kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
+ out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
+ '''
+ predictions.append(out['prediction'])
+ predictions_ts = torch.stack(predictions, dim=1)
+ return predictions_ts
+
+class AnimateModel(torch.nn.Module):
+ """
+ Merge all generator related updates into single model for better multi-gpu usage
+ """
+
+ def __init__(self, generator, kp_extractor, mapping):
+ super(AnimateModel, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.mapping = mapping
+
+ self.kp_extractor.eval()
+ self.generator.eval()
+ self.mapping.eval()
+
+ def forward(self, x):
+
+ source_image = x['source_image']
+ source_semantics = x['source_semantics']
+ target_semantics = x['target_semantics']
+ yaw_c_seq = x['yaw_c_seq']
+ pitch_c_seq = x['pitch_c_seq']
+ roll_c_seq = x['roll_c_seq']
+
+        # make_animation's signature is (..., generator, kp_detector, he_estimator, mapping, ...);
+        # he_estimator is not used at inference time, so pass None explicitly and bind
+        # self.mapping to the `mapping` parameter.
+        predictions_video = make_animation(source_image, source_semantics, target_semantics,
+                                           self.generator, self.kp_extractor, None,
+                                           self.mapping, use_exp=True,
+                                           yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq)
+
+ return predictions_video
\ No newline at end of file
diff --git a/src/facerender/modules/mapping.py b/src/facerender/modules/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e3a1c2d1770996080c08e9daafb346f05d7bcdd
--- /dev/null
+++ b/src/facerender/modules/mapping.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MappingNet(nn.Module):
+ def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
+ super( MappingNet, self).__init__()
+
+ self.layer = layer
+ nonlinearity = nn.LeakyReLU(0.1)
+
+ self.first = nn.Sequential(
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+ for i in range(layer):
+ net = nn.Sequential(nonlinearity,
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+ setattr(self, 'encoder' + str(i), net)
+
+ self.pooling = nn.AdaptiveAvgPool1d(1)
+ self.output_nc = descriptor_nc
+
+ self.fc_roll = nn.Linear(descriptor_nc, num_bins)
+ self.fc_pitch = nn.Linear(descriptor_nc, num_bins)
+ self.fc_yaw = nn.Linear(descriptor_nc, num_bins)
+ self.fc_t = nn.Linear(descriptor_nc, 3)
+ self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp)
+
+ def forward(self, input_3dmm):
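+        # input_3dmm: (bs, coeff_nc, T) window of coefficients; the first conv and
+        # each dilated encoder conv are unpadded and trim 3 frames from both ends,
+        # so T must be at least 6 * (layer + 1) + 1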
+ out = self.first(input_3dmm)
+ for i in range(self.layer):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out) + out[:,:,3:-3]
+ out = self.pooling(out)
+ out = out.view(out.shape[0], -1)
+ #print('out:', out.shape)
+
+ yaw = self.fc_yaw(out)
+ pitch = self.fc_pitch(out)
+ roll = self.fc_roll(out)
+ t = self.fc_t(out)
+ exp = self.fc_exp(out)
+
+ return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
\ No newline at end of file
diff --git a/src/facerender/modules/util.py b/src/facerender/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b916deefbb8b957ad6ab3cd7403c28513e5ae18e
--- /dev/null
+++ b/src/facerender/modules/util.py
@@ -0,0 +1,564 @@
+from torch import nn
+
+import torch.nn.functional as F
+import torch
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+import torch.nn.utils.spectral_norm as spectral_norm
+
+
+def kp2gaussian(kp, spatial_size, kp_variance):
+ """
+ Transform a keypoint into gaussian like representation
+ """
+ mean = kp['value']
+
+ coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
+ number_of_leading_dimensions = len(mean.shape) - 1
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+ coordinate_grid = coordinate_grid.view(*shape)
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
+ coordinate_grid = coordinate_grid.repeat(*repeats)
+
+ # Preprocess kp shape
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
+ mean = mean.view(*shape)
+
+ mean_sub = (coordinate_grid - mean)
+
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+ return out
+
+def make_coordinate_grid_2d(spatial_size, type):
+ """
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
+ """
+ h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+
+ return meshed
+
+
+def make_coordinate_grid(spatial_size, type):
+ d, h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+ z = torch.arange(d).type(type)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+ z = (2 * (z / (d - 1)) - 1)
+
+ yy = y.view(1, -1, 1).repeat(d, 1, w)
+ xx = x.view(1, 1, -1).repeat(d, h, 1)
+ zz = z.view(-1, 1, 1).repeat(1, h, w)
+
+ meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
+
+ return meshed
+
+
+class ResBottleneck(nn.Module):
+ def __init__(self, in_features, stride):
+ super(ResBottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1)
+ self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride)
+ self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1)
+ self.norm1 = BatchNorm2d(in_features//4, affine=True)
+ self.norm2 = BatchNorm2d(in_features//4, affine=True)
+ self.norm3 = BatchNorm2d(in_features, affine=True)
+
+ self.stride = stride
+ if self.stride != 1:
+ self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride)
+ self.norm4 = BatchNorm2d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv3(out)
+ out = self.norm3(out)
+ if self.stride != 1:
+ x = self.skip(x)
+ x = self.norm4(x)
+ out += x
+ out = F.relu(out)
+ return out
+
+
+class ResBlock2d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock2d, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = BatchNorm2d(in_features, affine=True)
+ self.norm2 = BatchNorm2d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class ResBlock3d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock3d, self).__init__()
+ self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = BatchNorm3d(in_features, affine=True)
+ self.norm2 = BatchNorm3d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class UpBlock2d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock2d, self).__init__()
+
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+
+ def forward(self, x):
+ out = F.interpolate(x, scale_factor=2)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+class UpBlock3d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock3d, self).__init__()
+
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm3d(out_features, affine=True)
+
+ def forward(self, x):
+ # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear')
+ out = F.interpolate(x, scale_factor=(1, 2, 2))
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class DownBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class DownBlock3d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock3d, self).__init__()
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups, stride=(1, 2, 2))
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm3d(out_features, affine=True)
+ self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class SameBlock2d(nn.Module):
+ """
+ Simple block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
+ super(SameBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+ kernel_size=kernel_size, padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+ if lrelu:
+ self.ac = nn.LeakyReLU()
+ else:
+ self.ac = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = self.ac(out)
+ return out
+
+
+class Encoder(nn.Module):
+ """
+ Hourglass Encoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Encoder, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ def forward(self, x):
+ outs = [x]
+ for down_block in self.down_blocks:
+ outs.append(down_block(outs[-1]))
+ return outs
+
+
+class Decoder(nn.Module):
+ """
+ Hourglass Decoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Decoder, self).__init__()
+
+ up_blocks = []
+
+ for i in range(num_blocks)[::-1]:
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+ out_filters = min(max_features, block_expansion * (2 ** i))
+ up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+ # self.out_filters = block_expansion
+ self.out_filters = block_expansion + in_features
+
+ self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
+ self.norm = BatchNorm3d(self.out_filters, affine=True)
+
+ def forward(self, x):
+ out = x.pop()
+ # for up_block in self.up_blocks[:-1]:
+ for up_block in self.up_blocks:
+ out = up_block(out)
+ skip = x.pop()
+ out = torch.cat([out, skip], dim=1)
+ # out = self.up_blocks[-1](out)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class Hourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Hourglass, self).__init__()
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
+ self.out_filters = self.decoder.out_filters
+
+ def forward(self, x):
+ return self.decoder(self.encoder(x))
+
+
+class KPHourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):
+ super(KPHourglass, self).__init__()
+
+ self.down_blocks = nn.Sequential()
+ for i in range(num_blocks):
+ self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+
+ in_filters = min(max_features, block_expansion * (2 ** num_blocks))
+ self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)
+
+ self.up_blocks = nn.Sequential()
+ for i in range(num_blocks):
+ in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i)))
+ out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))
+ self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.reshape_depth = reshape_depth
+ self.out_filters = out_filters
+
+ def forward(self, x):
+ out = self.down_blocks(x)
+ out = self.conv(out)
+ bs, c, h, w = out.shape
+ out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w)
+ out = self.up_blocks(out)
+
+ return out
+
+
+
+class AntiAliasInterpolation2d(nn.Module):
+ """
+ Band-limited downsampling, for better preservation of the input signal.
+ """
+ def __init__(self, channels, scale):
+ super(AntiAliasInterpolation2d, self).__init__()
+ sigma = (1 / scale - 1) / 2
+ kernel_size = 2 * round(sigma * 4) + 1
+ self.ka = kernel_size // 2
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+ kernel_size = [kernel_size, kernel_size]
+ sigma = [sigma, sigma]
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid(
+ [
+ torch.arange(size, dtype=torch.float32)
+ for size in kernel_size
+ ]
+ )
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+ self.scale = scale
+ inv_scale = 1 / scale
+ self.int_inv_scale = int(inv_scale)
+
+ def forward(self, input):
+ if self.scale == 1.0:
+ return input
+
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
+
+ return out
+
+
+class SPADE(nn.Module):
+ def __init__(self, norm_nc, label_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
+ nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+
+ def forward(self, x, segmap):
+ normalized = self.param_free_norm(x)
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class SPADEResnetBlock(nn.Module):
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+ self.use_se = use_se
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+ # apply spectral norm if specified
+ if 'spectral' in norm_G:
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+ # define normalization layers
+ self.norm_0 = SPADE(fin, label_nc)
+ self.norm_1 = SPADE(fmiddle, label_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(fin, label_nc)
+
+ def forward(self, x, seg1):
+ x_s = self.shortcut(x, seg1)
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg1):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg1))
+ else:
+ x_s = x
+ return x_s
+
+ def actvn(self, x):
+ return F.leaky_relu(x, 2e-1)
+
+class audio2image(nn.Module):
+ def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params):
+ super().__init__()
+ # Attributes
+ self.generator = generator
+ self.kp_extractor = kp_extractor
+ self.he_estimator_video = he_estimator_video
+ self.he_estimator_audio = he_estimator_audio
+ self.train_params = train_params
+
+ def headpose_pred_to_degree(self, pred):
+ device = pred.device
+ idx_tensor = [idx for idx in range(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+        pred = F.softmax(pred, dim=1)
+ degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+
+ return degree
+
+ def get_rotation_matrix(self, yaw, pitch, roll):
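+        # note: the naming here differs from the module-level get_rotation_matrix()
+        # in make_animation.py -- roll_mat/pitch_mat/yaw_mat below are rotations
+        # about the x/y/z axes respectively, composed as roll @ pitch @ yaw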
+ yaw = yaw / 180 * 3.14
+ pitch = pitch / 180 * 3.14
+ roll = roll / 180 * 3.14
+
+ roll = roll.unsqueeze(1)
+ pitch = pitch.unsqueeze(1)
+ yaw = yaw.unsqueeze(1)
+
+ roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll),
+ torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
+ torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
+ roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
+ pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch),
+ torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
+ -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
+ pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+ yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),
+ torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
+ torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
+ yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+ rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
+
+ return rot_mat
+
+ def keypoint_transformation(self, kp_canonical, he):
+ kp = kp_canonical['value'] # (bs, k, 3)
+ yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
+ t, exp = he['t'], he['exp']
+
+ yaw = self.headpose_pred_to_degree(yaw)
+ pitch = self.headpose_pred_to_degree(pitch)
+ roll = self.headpose_pred_to_degree(roll)
+
+ rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
+
+ # keypoint rotation
+ kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+
+
+ # keypoint translation
+ t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)
+ kp_t = kp_rotated + t
+
+ # add expression deviation
+ exp = exp.view(exp.shape[0], -1, 3)
+ kp_transformed = kp_t + exp
+
+ return {'value': kp_transformed}
+
+ def forward(self, source_image, target_audio):
+ pose_source = self.he_estimator_video(source_image)
+ pose_generated = self.he_estimator_audio(target_audio)
+ kp_canonical = self.kp_extractor(source_image)
+ kp_source = self.keypoint_transformation(kp_canonical, pose_source)
+ kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated)
+ generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated)
+ return generated
\ No newline at end of file
diff --git a/src/facerender/pirender/base_function.py b/src/facerender/pirender/base_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..49fe4cf3d07c4a22f7d7db4bf0a97ebddc87dd72
--- /dev/null
+++ b/src/facerender/pirender/base_function.py
@@ -0,0 +1,368 @@
+import sys
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
+
+
+class LayerNorm2d(nn.Module):
+ def __init__(self, n_out, affine=True):
+ super(LayerNorm2d, self).__init__()
+ self.n_out = n_out
+ self.affine = affine
+
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
+ self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
+
+ def forward(self, x):
+ normalized_shape = x.size()[1:]
+ if self.affine:
+            return F.layer_norm(x, normalized_shape,
+                                self.weight.expand(normalized_shape),
+                                self.bias.expand(normalized_shape))
+
+ else:
+ return F.layer_norm(x, normalized_shape)
+
+class ADAINHourglass(nn.Module):
+ def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
+ super(ADAINHourglass, self).__init__()
+ self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
+ self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
+ self.output_nc = self.decoder.output_nc
+
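+    # the encoder returns a list of per-scale features; the decoder consumes the list
+    # back-to-front and concatenates skip connections when skip_connect is enabled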
+ def forward(self, x, z):
+ return self.decoder(self.encoder(x, z), z)
+
+
+
+class ADAINEncoder(nn.Module):
+ def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoder, self).__init__()
+ self.layers = layers
+ self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
+ for i in range(layers):
+ in_channels = min(ngf * (2**i), img_f)
+ out_channels = min(ngf *(2**(i+1)), img_f)
+ model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
+ setattr(self, 'encoder' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = self.input_layer(x)
+ out_list = [out]
+ for i in range(self.layers):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out, z)
+ out_list.append(out)
+ return out_list
+
+class ADAINDecoder(nn.Module):
+ """docstring for ADAINDecoder"""
+ def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
+ nonlinearity=nn.LeakyReLU(), use_spect=False):
+
+ super(ADAINDecoder, self).__init__()
+ self.encoder_layers = encoder_layers
+ self.decoder_layers = decoder_layers
+ self.skip_connect = skip_connect
+ use_transpose = True
+
+ for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
+ in_channels = min(ngf * (2**(i+1)), img_f)
+ in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
+ out_channels = min(ngf * (2**i), img_f)
+ model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
+ setattr(self, 'decoder' + str(i), model)
+
+ self.output_nc = out_channels*2 if self.skip_connect else out_channels
+
+ def forward(self, x, z):
+ out = x.pop() if self.skip_connect else x
+ for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
+ model = getattr(self, 'decoder' + str(i))
+ out = model(out, z)
+ out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
+ return out
+
+class ADAINEncoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoderBlock, self).__init__()
+ kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
+ kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
+ self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
+
+
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(output_nc, feature_nc)
+ self.actvn = nonlinearity
+
+ def forward(self, x, z):
+ x = self.conv_0(self.actvn(self.norm_0(x, z)))
+ x = self.conv_1(self.actvn(self.norm_1(x, z)))
+ return x
+
+class ADAINDecoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINDecoderBlock, self).__init__()
+ # Attributes
+ self.actvn = nonlinearity
+ hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
+
+ kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
+ if use_transpose:
+ kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
+ else:
+ kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}
+
+ # create conv layers
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
+ if use_transpose:
+ self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
+ self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
+ else:
+ self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ # define normalization layers
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(hidden_nc, feature_nc)
+ self.norm_s = ADAIN(input_nc, feature_nc)
+
+ def forward(self, x, z):
+ x_s = self.shortcut(x, z)
+ dx = self.conv_0(self.actvn(self.norm_0(x, z)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, z):
+ x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
+ return x_s
+
+
+def spectral_norm(module, use_spect=True):
+ """use spectral normal layer to stable the training process"""
+ if use_spect:
+ return SpectralNorm(module)
+ else:
+ return module
+
+
+class ADAIN(nn.Module):
+ def __init__(self, norm_nc, feature_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+
+ nhidden = 128
+ use_bias=True
+
+ self.mlp_shared = nn.Sequential(
+ nn.Linear(feature_nc, nhidden, bias=use_bias),
+ nn.ReLU()
+ )
+ self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
+ self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
+
+ def forward(self, x, feature):
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on feature
+ feature = feature.view(feature.size(0), -1)
+ actv = self.mlp_shared(feature)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ gamma = gamma.view(*gamma.size()[:2], 1,1)
+ beta = beta.view(*beta.size()[:2], 1,1)
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class FineEncoder(nn.Module):
+ """docstring for Encoder"""
+ def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineEncoder, self).__init__()
+ self.layers = layers
+ self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+ for i in range(layers):
+ in_channels = min(ngf*(2**i), img_f)
+ out_channels = min(ngf*(2**(i+1)), img_f)
+ model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'down' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x):
+ x = self.first(x)
+ out=[x]
+ for i in range(self.layers):
+ model = getattr(self, 'down'+str(i))
+ x = model(x)
+ out.append(x)
+ return out
+
+class FineDecoder(nn.Module):
+ """docstring for FineDecoder"""
+ def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineDecoder, self).__init__()
+ self.layers = layers
+ for i in range(layers)[::-1]:
+ in_channels = min(ngf*(2**(i+1)), img_f)
+ out_channels = min(ngf*(2**i), img_f)
+ up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+ jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+
+ setattr(self, 'up' + str(i), up)
+ setattr(self, 'res' + str(i), res)
+ setattr(self, 'jump' + str(i), jump)
+
+ self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
+
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = x.pop()
+ for i in range(self.layers)[::-1]:
+ res_model = getattr(self, 'res' + str(i))
+ up_model = getattr(self, 'up' + str(i))
+ jump_model = getattr(self, 'jump' + str(i))
+ out = res_model(out, z)
+ out = up_model(out)
+ out = jump_model(x.pop()) + out
+ out_image = self.final(out)
+ return out_image
+
+class FirstBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FirstBlock2d, self).__init__()
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class DownBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(DownBlock2d, self).__init__()
+
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+ pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity, pool)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class UpBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(UpBlock2d, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(F.interpolate(x, scale_factor=2))
+ return out
+
+class FineADAINResBlocks(nn.Module):
+ def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlocks, self).__init__()
+ self.num_block = num_block
+ for i in range(num_block):
+ model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'res'+str(i), model)
+
+ def forward(self, x, z):
+ for i in range(self.num_block):
+ model = getattr(self, 'res'+str(i))
+ x = model(x, z)
+ return x
+
+class Jump(nn.Module):
+ def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(Jump, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class FineADAINResBlock2d(nn.Module):
+ """
+ Define an Residual block for different types
+ """
+ def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.norm1 = ADAIN(input_nc, feature_nc)
+ self.norm2 = ADAIN(input_nc, feature_nc)
+
+ self.actvn = nonlinearity
+
+
+ def forward(self, x, z):
+ dx = self.actvn(self.norm1(self.conv1(x), z))
+ dx = self.norm2(self.conv2(x), z)
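+        # note: this second assignment overwrites dx and applies conv2 to the original
+        # x rather than to the first branch's output; kept as-is to avoid changing the
+        # behaviour expected by pretrained weights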
+ out = dx + x
+ return out
+
+class FinalBlock2d(nn.Module):
+ """
+ Define the output layer
+ """
+ def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
+ super(FinalBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+ if tanh_or_sigmoid == 'sigmoid':
+ out_nonlinearity = nn.Sigmoid()
+ else:
+ out_nonlinearity = nn.Tanh()
+
+ self.model = nn.Sequential(conv, out_nonlinearity)
+ def forward(self, x):
+ out = self.model(x)
+ return out
\ No newline at end of file
diff --git a/src/facerender/pirender/config.py b/src/facerender/pirender/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3f917385b5b1f7ed2809d963d3ad0f0c754459b
--- /dev/null
+++ b/src/facerender/pirender/config.py
@@ -0,0 +1,211 @@
+import collections
+import functools
+import os
+import re
+
+import yaml
+
+class AttrDict(dict):
+ """Dict as attribute trick."""
+
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+ for key, value in self.__dict__.items():
+ if isinstance(value, dict):
+ self.__dict__[key] = AttrDict(value)
+            elif isinstance(value, (list, tuple)):
+                # guard against empty sequences before peeking at value[0]
+                if value and isinstance(value[0], dict):
+                    self.__dict__[key] = [AttrDict(item) for item in value]
+                else:
+                    self.__dict__[key] = value
+
+ def yaml(self):
+ """Convert object to yaml dict and return."""
+ yaml_dict = {}
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ yaml_dict[key] = value.yaml()
+ elif isinstance(value, list):
+                if value and isinstance(value[0], AttrDict):
+ new_l = []
+ for item in value:
+ new_l.append(item.yaml())
+ yaml_dict[key] = new_l
+ else:
+ yaml_dict[key] = value
+ else:
+ yaml_dict[key] = value
+ return yaml_dict
+
+ def __repr__(self):
+ """Print all variables."""
+ ret_str = []
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ ret_str.append('{}:'.format(key))
+ child_ret_str = value.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ ret_str.append('{}:'.format(key))
+ for item in value:
+ # Treat as AttrDict above.
+ child_ret_str = item.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ return '\n'.join(ret_str)
+
+
+class Config(AttrDict):
+ r"""Configuration class. This should include every human specifiable
+ hyperparameter values for your training."""
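+ # Typical usage (illustrative, not part of the original docs):
+ #     cfg = Config('path/to/pirender.yaml', is_train=False)
+ # nested YAML options are then available as attributes, e.g. cfg.gen.param or cfg.data.num_workers.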
+
+ def __init__(self, filename=None, args=None, verbose=False, is_train=True):
+ super(Config, self).__init__()
+ # Set default parameters.
+ # Logging.
+
+ large_number = 1000000000
+ self.snapshot_save_iter = large_number
+ self.snapshot_save_epoch = large_number
+ self.snapshot_save_start_iter = 0
+ self.snapshot_save_start_epoch = 0
+ self.image_save_iter = large_number
+ self.eval_epoch = large_number
+ self.start_eval_epoch = large_number
+ self.eval_epoch = large_number
+ self.max_epoch = large_number
+ self.max_iter = large_number
+ self.logging_iter = 100
+ self.image_to_tensorboard=False
+ self.which_iter = 0 # args.which_iter
+ self.resume = False
+
+ self.checkpoints_dir = '/Users/shadowcun/Downloads/'
+ self.name = 'face'
+ self.phase = 'train' if is_train else 'test'
+
+ # Networks.
+ self.gen = AttrDict(type='generators.dummy')
+ self.dis = AttrDict(type='discriminators.dummy')
+
+ # Optimizers.
+ self.gen_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ self.dis_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ # Data.
+ self.data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0)
+ self.test_data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0,
+ test=AttrDict(is_lmdb=False,
+ roots='',
+ batch_size=1))
+ self.trainer = AttrDict(
+ model_average=False,
+ model_average_beta=0.9999,
+ model_average_start_iteration=1000,
+ model_average_batch_norm_estimation_iteration=30,
+ model_average_remove_sn=True,
+ image_to_tensorboard=False,
+ hparam_to_tensorboard=False,
+ distributed_data_parallel='pytorch',
+ delay_allreduce=True,
+ gan_relativistic=False,
+ gen_step=1,
+ dis_step=1)
+
+ # # Cudnn.
+ self.cudnn = AttrDict(deterministic=False,
+ benchmark=True)
+
+ # Others.
+ self.pretrained_weight = ''
+ self.inference_args = AttrDict()
+
+
+ # Update with given configurations.
+ assert os.path.exists(filename), 'File {} does not exist.'.format(filename)
+ loader = yaml.SafeLoader
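+ # extend SafeLoader so that scientific-notation floats without a decimal point (e.g. 1e-8)
+ # are resolved as floats instead of strings (a known gap in the default YAML 1.1 resolver)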
+ loader.add_implicit_resolver(
+ u'tag:yaml.org,2002:float',
+ re.compile(u'''^(?:
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+ |[-+]?\\.(?:inf|Inf|INF)
+ |\\.(?:nan|NaN|NAN))$''', re.X),
+ list(u'-+0123456789.'))
+ try:
+ with open(filename, 'r') as f:
+ cfg_dict = yaml.load(f, Loader=loader)
+ except EnvironmentError:
+ print('Please check the file named "%s".' % filename)
+ recursive_update(self, cfg_dict)
+
+ # Put common opts in both gen and dis.
+ if 'common' in cfg_dict:
+ self.common = AttrDict(**cfg_dict['common'])
+ self.gen.common = self.common
+ self.dis.common = self.common
+
+
+ if verbose:
+ print(' config '.center(80, '-'))
+ print(self.__repr__())
+ print(''.center(80, '-'))
+
+
+def rsetattr(obj, attr, val):
+ """Recursively find object and set value"""
+ pre, _, post = attr.rpartition('.')
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+ """Recursively find object and return value"""
+
+ def _getattr(obj, attr):
+ r"""Get attribute."""
+ return getattr(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def recursive_update(d, u):
+ """Recursively update AttrDict d with AttrDict u"""
+ for key, value in u.items():
+ if isinstance(value, collections.abc.Mapping):
+ d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ d.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ d.__dict__[key] = value
+ else:
+ d.__dict__[key] = value
+ return d
diff --git a/src/facerender/pirender/face_model.py b/src/facerender/pirender/face_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..51692c3f28f08e91d6956efcf528f5be51764721
--- /dev/null
+++ b/src/facerender/pirender/face_model.py
@@ -0,0 +1,178 @@
+import functools
+import torch
+import torch.nn as nn
+from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
+
+def convert_flow_to_deformation(flow):
+ r"""convert flow fields to deformations.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ deformation (tensor): The deformation used for warping
+ """
+ b,c,h,w = flow.shape
+ flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
+ grid = make_coordinate_grid(flow)
+ deformation = grid + flow_norm.permute(0,2,3,1)
+ return deformation
+
+def make_coordinate_grid(flow):
+ r"""obtain coordinate grid with the same size as the flow filed.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ grid (tensor): The grid with the same size as the input flow
+ """
+ b,c,h,w = flow.shape
+
+ x = torch.arange(w).to(flow)
+ y = torch.arange(h).to(flow)
+
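+ # map pixel indices to normalized coordinates in [-1, 1], as expected by grid_sample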
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+ meshed = meshed.expand(b, -1, -1, -1)
+ return meshed
+
+
+def warp_image(source_image, deformation):
+ r"""warp the input image according to the deformation
+
+ Args:
+ source_image (tensor): source images to be warped
+ deformation (tensor): deformations used to warp the images; value in range (-1, 1)
+ Returns:
+ output (tensor): the warped images
+ """
+ _, h_old, w_old, _ = deformation.shape
+ _, _, h, w = source_image.shape
+ if h_old != h or w_old != w:
+ deformation = deformation.permute(0, 3, 1, 2)
+ deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
+ deformation = deformation.permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(source_image, deformation)
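+ # Shape sketch (illustrative): for a flow of shape (B, 2, H, W) predicted by WarpingNet and a
+ # source image of shape (B, C, H, W):
+ #     deformation = convert_flow_to_deformation(flow)   # (B, H, W, 2), values roughly in (-1, 1)
+ #     warped = warp_image(source_image, deformation)    # (B, C, H, W)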
+
+
+class FaceGenerator(nn.Module):
+ def __init__(
+ self,
+ mapping_net,
+ warpping_net,
+ editing_net,
+ common
+ ):
+ super(FaceGenerator, self).__init__()
+ self.mapping_net = MappingNet(**mapping_net)
+ self.warpping_net = WarpingNet(**warpping_net, **common)
+ self.editing_net = EditingNet(**editing_net, **common)
+
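+ # stage == 'warp' returns only the warping-network output; otherwise the editing network
+ # additionally refines the warped image into the final 'fake_image'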
+ def forward(
+ self,
+ input_image,
+ driving_source,
+ stage=None
+ ):
+ if stage == 'warp':
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ else:
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
+ return output
+
+class MappingNet(nn.Module):
+ def __init__(self, coeff_nc, descriptor_nc, layer):
+ super( MappingNet, self).__init__()
+
+ self.layer = layer
+ nonlinearity = nn.LeakyReLU(0.1)
+
+ self.first = nn.Sequential(
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+ for i in range(layer):
+ net = nn.Sequential(nonlinearity,
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+ setattr(self, 'encoder' + str(i), net)
+
+ self.pooling = nn.AdaptiveAvgPool1d(1)
+ self.output_nc = descriptor_nc
+
+ def forward(self, input_3dmm):
+ out = self.first(input_3dmm)
+ for i in range(self.layer):
+ model = getattr(self, 'encoder' + str(i))
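+ # each dilated Conv1d (kernel 3, dilation 3, no padding) trims 3 steps from each end of the
+ # temporal axis, so the residual is cropped to out[:, :, 3:-3] to match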
+ out = model(out) + out[:,:,3:-3]
+ out = self.pooling(out)
+ return out
+
+class WarpingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ base_nc,
+ max_nc,
+ encoder_layer,
+ decoder_layer,
+ use_spect
+ ):
+ super( WarpingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}
+
+ self.descriptor_nc = descriptor_nc
+ self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
+ max_nc, encoder_layer, decoder_layer, **kwargs)
+
+ self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
+ nonlinearity,
+ nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
+
+ self.pool = nn.AdaptiveAvgPool2d(1)
+
+ def forward(self, input_image, descriptor):
+ final_output={}
+ output = self.hourglass(input_image, descriptor)
+ final_output['flow_field'] = self.flow_out(output)
+
+ deformation = convert_flow_to_deformation(final_output['flow_field'])
+ final_output['warp_image'] = warp_image(input_image, deformation)
+ return final_output
+
+
+class EditingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ layer,
+ base_nc,
+ max_nc,
+ num_res_blocks,
+ use_spect):
+ super(EditingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
+ self.descriptor_nc = descriptor_nc
+
+ # encoder part
+ self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
+ self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+
+ def forward(self, input_image, warp_image, descriptor):
+ x = torch.cat([input_image, warp_image], 1)
+ x = self.encoder(x)
+ gen_image = self.decoder(x, descriptor)
+ return gen_image
diff --git a/src/facerender/pirender_animate.py b/src/facerender/pirender_animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..406c081d77ff7b6cb9c0f3ed3721408e0d5657be
--- /dev/null
+++ b/src/facerender/pirender_animate.py
@@ -0,0 +1,130 @@
+import os
+import cv2
+from tqdm import tqdm
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+
+from src.facerender.pirender.config import Config
+from src.facerender.pirender.face_model import FaceGenerator
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+class AnimateFromCoeff_PIRender():
+
+ def __init__(self, sadtalker_path, device):
+
+ opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False)
+ opt.device = device
+ self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device)
+ checkpoint_path = sadtalker_path['pirender_checkpoint']
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False)
+ print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path))
+ self.net_G = self.net_G_ema.eval()
+ self.device = device
+
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
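+ # keys expected in x (as used below): 'source_image', 'source_semantics',
+ # 'target_semantics_list', 'frame_num', 'video_name', 'audio_path'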
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we restore the original aspect ratio here
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+ #### paste back, then run the enhancer
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+ os.remove(enhanced_path)
+
+ os.remove(path)
+ os.remove(new_audio_path)
+
+ return return_path
+
diff --git a/src/facerender/sync_batchnorm/__init__.py b/src/facerender/sync_batchnorm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf
--- /dev/null
+++ b/src/facerender/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/src/facerender/sync_batchnorm/batchnorm.py b/src/facerender/sync_batchnorm/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f4e763f0366dffa10320116413f8c7181a8aeb1
--- /dev/null
+++ b/src/facerender/sync_batchnorm/batchnorm.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# File : batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+from .comm import SyncMaster
+
+__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+def _sum_ft(tensor):
+ """sum over the first and last dimention"""
+ return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+ """add new dementions at the front and the tail"""
+ return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+ self._sync_master = SyncMaster(self._data_parallel_master)
+
+ self._is_parallel = False
+ self._parallel_id = None
+ self._slave_pipe = None
+
+ def forward(self, input):
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
+ # Resize the input to (B, C, -1).
+ input_shape = input.size()
+ input = input.view(input.size(0), self.num_features, -1)
+
+ # Compute the sum and square-sum.
+ sum_size = input.size(0) * input.size(2)
+ input_sum = _sum_ft(input)
+ input_ssum = _sum_ft(input ** 2)
+
+ # Reduce-and-broadcast the statistics.
+ if self._parallel_id == 0:
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+ else:
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+ # Compute the output.
+ if self.affine:
+ # MJY:: Fuse the multiplication for speed.
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+ else:
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+ # Reshape it.
+ return output.view(input_shape)
+
+ def __data_parallel_replicate__(self, ctx, copy_id):
+ self._is_parallel = True
+ self._parallel_id = copy_id
+
+ # parallel_id == 0 means master device.
+ if self._parallel_id == 0:
+ ctx.sync_master = self._sync_master
+ else:
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+ def _data_parallel_master(self, intermediates):
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+ # Always using same "device order" makes the ReduceAdd operation faster.
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+ to_reduce = [i[1][:2] for i in intermediates]
+ to_reduce = [j for i in to_reduce for j in i] # flatten
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+ sum_size = sum([i[1].sum_size for i in intermediates])
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+ outputs = []
+ for i, rec in enumerate(intermediates):
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+ return outputs
+
+ def _compute_mean_std(self, sum_, ssum, size):
+ """Compute the mean and standard-deviation with sum and square-sum. This method
+ also maintains the moving average on the master device."""
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+ mean = sum_ / size
+ sumvar = ssum - sum_ * mean
+ unbias_var = sumvar / (size - 1)
+ bias_var = sumvar / size
+
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+ return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+ mini-batch.
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalizes the tensor on each device using
+ the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of size
+ `batch_size x num_features [x width]`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 2 and input.dim() != 3:
+ raise ValueError('expected 2D or 3D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+ of 3d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalizes the tensor on each device using
+ the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+ of 4d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalizes the tensor on each device using
+ the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+ or Spatio-temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x depth x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/src/facerender/sync_batchnorm/comm.py b/src/facerender/sync_batchnorm/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b
--- /dev/null
+++ b/src/facerender/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+ def __init__(self):
+ self._result = None
+ self._lock = threading.Lock()
+ self._cond = threading.Condition(self._lock)
+
+ def put(self, result):
+ with self._lock:
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
+ self._result = result
+ self._cond.notify()
+
+ def get(self):
+ with self._lock:
+ if self._result is None:
+ self._cond.wait()
+
+ res = self._result
+ self._result = None
+ return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+ """Pipe for master-slave communication."""
+
+ def run_slave(self, msg):
+ self.queue.put((self.identifier, msg))
+ ret = self.result.get()
+ self.queue.put(True)
+ return ret
+
+
+class SyncMaster(object):
+ """An abstract `SyncMaster` object.
+
+ - During replication, as data parallelism triggers a callback on each module, all slave devices should
+ call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
+ collected and passed to the registered callback.
+ - After receiving the messages, the master device gathers the information and determines the message to be
+ passed back to each slave device.
+ """
+
+ def __init__(self, master_callback):
+ """
+
+ Args:
+ master_callback: a callback to be invoked after having collected messages from slave devices.
+ """
+ self._master_callback = master_callback
+ self._queue = queue.Queue()
+ self._registry = collections.OrderedDict()
+ self._activated = False
+
+ def __getstate__(self):
+ return {'master_callback': self._master_callback}
+
+ def __setstate__(self, state):
+ self.__init__(state['master_callback'])
+
+ def register_slave(self, identifier):
+ """
+ Register a slave device.
+
+ Args:
+ identifier: an identifier, usually is the device id.
+
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+ """
+ if self._activated:
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
+ self._activated = False
+ self._registry.clear()
+ future = FutureResult()
+ self._registry[identifier] = _MasterRegistry(future)
+ return SlavePipe(identifier, self._queue, future)
+
+ def run_master(self, master_msg):
+ """
+ Main entry for the master device in each forward pass.
+ The messages are first collected from each device (including the master device), and then
+ the registered callback is invoked to compute the message to be sent back to each device
+ (including the master device).
+
+ Args:
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+ Returns: the message to be sent back to the master device.
+
+ """
+ self._activated = True
+
+ intermediates = [(0, master_msg)]
+ for i in range(self.nr_slaves):
+ intermediates.append(self._queue.get())
+
+ results = self._master_callback(intermediates)
+ assert results[0][0] == 0, 'The first result should belong to the master.'
+
+ for i, res in results:
+ if i == 0:
+ continue
+ self._registry[i].result.put(res)
+
+ for i in range(self.nr_slaves):
+ assert self._queue.get() is True
+
+ return results[0][1]
+
+ @property
+ def nr_slaves(self):
+ return len(self._registry)
diff --git a/src/facerender/sync_batchnorm/replicate.py b/src/facerender/sync_batchnorm/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06
--- /dev/null
+++ b/src/facerender/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+ 'CallbackContext',
+ 'execute_replication_callbacks',
+ 'DataParallelWithCallback',
+ 'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+ pass
+
+
+def execute_replication_callbacks(modules):
+ """
+ Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Note that, as all copies are isomorphic, we assign each sub-module a context
+ (shared among multiple copies of this module on different devices).
+ Through this context, different copies can share some information.
+
+ We guarantee that the callback on the master copy (the first copy) will be called before the callbacks
+ of any slave copies.
+ """
+ master_copy = modules[0]
+ nr_modules = len(list(master_copy.modules()))
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+ for i, module in enumerate(modules):
+ for j, m in enumerate(module.modules()):
+ if hasattr(m, '__data_parallel_replicate__'):
+ m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+ """
+ Data Parallel with a replication callback.
+
+ A replication callback `__data_parallel_replicate__` will be invoked on each module after it is created by
+ the original `replicate` function.
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ # sync_bn.__data_parallel_replicate__ will be invoked.
+ """
+
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+
+def patch_replication_callback(data_parallel):
+ """
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
+ Useful when you have a customized `DataParallel` implementation.
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+ > patch_replication_callback(sync_bn)
+ # this is equivalent to
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ """
+
+ assert isinstance(data_parallel, DataParallel)
+
+ old_replicate = data_parallel.replicate
+
+ @functools.wraps(old_replicate)
+ def new_replicate(module, device_ids):
+ modules = old_replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+ data_parallel.replicate = new_replicate
diff --git a/src/facerender/sync_batchnorm/unittest.py b/src/facerender/sync_batchnorm/unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524
--- /dev/null
+++ b/src/facerender/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+
+import numpy as np
+from torch.autograd import Variable
+
+
+def as_numpy(v):
+ if isinstance(v, Variable):
+ v = v.data
+ return v.cpu().numpy()
+
+
+class TorchTestCase(unittest.TestCase):
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+ npa, npb = as_numpy(a), as_numpy(b)
+ self.assertTrue(
+ np.allclose(npa, npb, atol=atol),
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+ )
diff --git a/src/generate_batch.py b/src/generate_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..95f21526feea846977707e97394132d43225c02a
--- /dev/null
+++ b/src/generate_batch.py
@@ -0,0 +1,120 @@
+import os
+
+from tqdm import tqdm
+import torch
+import numpy as np
+import random
+import scipy.io as scio
+import src.utils.audio as audio
+
+def crop_pad_audio(wav, audio_length):
+ if len(wav) > audio_length:
+ wav = wav[:audio_length]
+ elif len(wav) < audio_length:
+ wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
+ return wav
+
+def parse_audio_length(audio_length, sr, fps):
+ bit_per_frames = sr / fps
+
+ num_frames = int(audio_length / bit_per_frames)
+ audio_length = int(num_frames * bit_per_frames)
+
+ return audio_length, num_frames
+
+def generate_blink_seq(num_frames):
+ ratio = np.zeros((num_frames,1))
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = 80
+ if frame_id+start+9<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]
+ frame_id = frame_id+start+9
+ else:
+ break
+ return ratio
+
+def generate_blink_seq_randomly(num_frames):
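+    # insert a 5-frame blink (eye-close ratio ramping 0.5 -> 1.0 -> 0.5) at random intervals of
+    # roughly 10-70 frames; clips of 20 frames or fewer get no blink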
+ ratio = np.zeros((num_frames,1))
+ if num_frames<=20:
+ return ratio
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70)))
+ if frame_id+start+5<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]
+ frame_id = frame_id+start+5
+ else:
+ break
+ return ratio
+
+def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
+
+ syncnet_mel_step_size = 16
+ fps = 25
+
+ pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+
+
+ if idlemode:
+ num_frames = int(length_of_audio * 25)
+ indiv_mels = np.zeros((num_frames, 80, 16))
+ else:
+ wav = audio.load_wav(audio_path, 16000)
+ wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
+ wav = crop_pad_audio(wav, wav_length)
+ orig_mel = audio.melspectrogram(wav).T
+ spec = orig_mel.copy() # nframes 80
+ indiv_mels = []
+
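+ # for each output video frame, slice a 16-step mel window (80 mel bins) starting two frames
+ # earlier, with indices clamped to the valid range of the spectrogram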
+ for i in tqdm(range(num_frames), 'mel:'):
+ start_frame_num = i-2
+ start_idx = int(80. * (start_frame_num / float(fps)))
+ end_idx = start_idx + syncnet_mel_step_size
+ seq = list(range(start_idx, end_idx))
+ seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
+ m = spec[seq, :]
+ indiv_mels.append(m.T)
+ indiv_mels = np.asarray(indiv_mels) # T 80 16
+
+ ratio = generate_blink_seq_randomly(num_frames) # T
+ source_semantics_path = first_coeff_path
+ source_semantics_dict = scio.loadmat(source_semantics_path)
+ ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
+ ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
+
+ if ref_eyeblink_coeff_path is not None:
+ ratio[:num_frames] = 0
+ refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)
+ refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]
+ refeyeblink_num_frames = refeyeblink_coeff.shape[0]
+ if refeyeblink_num_frames < num_frames:
+ if len(new_degree_list) > frame_num:
+ new_degree_list = new_degree_list[:frame_num]
+ elif len(new_degree_list) < frame_num:
+ for _ in range(frame_num-len(new_degree_list)):
+ new_degree_list.append(new_degree_list[-1])
+ print(len(new_degree_list))
+ print(frame_num)
+
+ remainder = frame_num%batch_size
+ if remainder!=0:
+ for _ in range(batch_size-remainder):
+ new_degree_list.append(new_degree_list[-1])
+ new_degree_np = np.array(new_degree_list).reshape(batch_size, -1)
+ return new_degree_np
+
diff --git a/src/gradio_demo.py b/src/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1d2619fd9a67b37bea55bc91776afbcb3e50558
--- /dev/null
+++ b/src/gradio_demo.py
@@ -0,0 +1,170 @@
+import torch, uuid
+import os, sys, shutil, platform
+from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.facerender.animate import AnimateFromCoeff
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+
+from src.utils.init_path import init_path
+
+from pydub import AudioSegment
+
+
+def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
+ mp3_file = AudioSegment.from_file(file=mp3_filename)
+ mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
+
+
+class SadTalker():
+
+ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
+
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif platform.system() == 'Darwin': # macos
+ device = "mps"
+ else:
+ device = "cpu"
+
+ self.device = device
+
+ os.environ['TORCH_HOME']= checkpoint_path
+
+ self.checkpoint_path = checkpoint_path
+ self.config_path = config_path
+
+
+ def test(self, source_image, driven_audio, preprocess='crop',
+ still_mode=False, use_enhancer=False, batch_size=1, size=256,
+ pose_style = 0,
+ facerender='facevid2vid',
+ exp_scale=1.0,
+ use_ref_video = False,
+ ref_video = None,
+ ref_info = None,
+ use_idle_mode = False,
+ length_of_audio = 0, use_blink=True,
+ result_dir='./results/'):
+
+ self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
+ print(self.sadtalker_paths)
+
+ self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
+ self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
+
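+        # facevid2vid is the default renderer; on Apple Silicon (mps) the code falls back to PIRender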
+ if facerender == 'facevid2vid' and self.device != 'mps':
+ self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
+ elif facerender == 'pirender' or self.device == 'mps':
+ self.animate_from_coeff = AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device)
+ facerender = 'pirender'
+ else:
+ raise(RuntimeError('Unknown model: {}'.format(facerender)))
+
+
+ time_tag = str(uuid.uuid4())
+ save_dir = os.path.join(result_dir, time_tag)
+ os.makedirs(save_dir, exist_ok=True)
+
+ input_dir = os.path.join(save_dir, 'input')
+ os.makedirs(input_dir, exist_ok=True)
+
+ print(source_image)
+ pic_path = os.path.join(input_dir, os.path.basename(source_image))
+ shutil.move(source_image, input_dir)
+
+ if driven_audio is not None and os.path.isfile(driven_audio):
+ audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
+
+ #### mp3 to wav
+ if '.mp3' in audio_path:
+ mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
+ audio_path = audio_path.replace('.mp3', '.wav')
+ else:
+ shutil.move(driven_audio, input_dir)
+
+ elif use_idle_mode:
+ audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path
+ from pydub import AudioSegment
+ one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds
+ one_sec_segment.export(audio_path, format="wav")
+ else:
+ print(use_ref_video, ref_info)
+ assert use_ref_video == True and ref_info == 'all'
+
+ if use_ref_video and ref_info == 'all': # full ref mode
+ ref_video_videoname = os.path.basename(ref_video)
+ audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
+ print('new audiopath:',audio_path)
+ # if ref_video contains audio, set the audio from ref_video.
+ cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path)
+ os.system(cmd)
+
+ os.makedirs(save_dir, exist_ok=True)
+
+ #crop image and extract 3dmm from image
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+ os.makedirs(first_frame_dir, exist_ok=True)
+ first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)
+
+ if first_coeff_path is None:
+ raise AttributeError("No face is detected")
+
+ if use_ref_video:
+ print('using ref video for generation')
+ ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
+ ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
+ os.makedirs(ref_video_frame_dir, exist_ok=True)
+ print('3DMM Extraction for the reference video providing pose')
+ ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
+ else:
+ ref_video_coeff_path = None
+
+ if use_ref_video:
+ if ref_info == 'pose':
+ ref_pose_coeff_path = ref_video_coeff_path
+ ref_eyeblink_coeff_path = None
+ elif ref_info == 'blink':
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = ref_video_coeff_path
+ elif ref_info == 'pose+blink':
+ ref_pose_coeff_path = ref_video_coeff_path
+ ref_eyeblink_coeff_path = ref_video_coeff_path
+ elif ref_info == 'all':
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = None
+ else:
+ raise ValueError('error in ref_info')
+ else:
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = None
+
+ #audio2coeff
+ if use_ref_video and ref_info == 'all':
+ coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+ else:
+ batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, \
+ idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
+ coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+ #coeff2video
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, \
+ preprocess=preprocess, size=size, expression_scale = exp_scale, facemodel=facerender)
+ return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
+ video_name = data['video_name']
+ print(f'The generated video is named {video_name} in {save_dir}')
+
+ del self.preprocess_model
+ del self.audio_to_coeff
+ del self.animate_from_coeff
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ import gc; gc.collect()
+
+ return return_path
+
+
\ No newline at end of file
diff --git a/src/test_audio2coeff.py b/src/test_audio2coeff.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf19f494e2127b4ae9d6074b172fddb694d6e34
--- /dev/null
+++ b/src/test_audio2coeff.py
@@ -0,0 +1,123 @@
+import os
+import torch
+import numpy as np
+from scipy.io import savemat, loadmat
+from yacs.config import CfgNode as CN
+from scipy.signal import savgol_filter
+
+import safetensors
+import safetensors.torch
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2
+from src.audio2exp_models.audio2exp import Audio2Exp
+from src.utils.safetensor_helper import load_x_from_safetensor
+
+def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if model is not None:
+ model.load_state_dict(checkpoint['model'])
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+
+ return checkpoint['epoch']
+
+class Audio2Coeff():
+
+ def __init__(self, sadtalker_path, device):
+ #load config
+ fcfg_pose = open(sadtalker_path['audio2pose_yaml_path'])
+ cfg_pose = CN.load_cfg(fcfg_pose)
+ cfg_pose.freeze()
+ fcfg_exp = open(sadtalker_path['audio2exp_yaml_path'])
+ cfg_exp = CN.load_cfg(fcfg_exp)
+ cfg_exp.freeze()
+
+ # load audio2pose_model
+ self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device)
+ self.audio2pose_model = self.audio2pose_model.to(device)
+ self.audio2pose_model.eval()
+ for param in self.audio2pose_model.parameters():
+ param.requires_grad = False
+
+ try:
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose'))
+ else:
+ load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device)
+ except:
+ raise Exception("Failed in loading audio2pose_checkpoint")
+
+ # load audio2exp_model
+ netG = SimpleWrapperV2()
+ netG = netG.to(device)
+ for param in netG.parameters():
+ param.requires_grad = False
+ netG.eval()
+ try:
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp'))
+ else:
+ load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device)
+ except:
+ raise Exception("Failed in loading audio2exp_checkpoint")
+ self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False)
+ self.audio2exp_model = self.audio2exp_model.to(device)
+ for param in self.audio2exp_model.parameters():
+ param.requires_grad = False
+ self.audio2exp_model.eval()
+
+ self.device = device
+
+ def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None):
+
+ with torch.no_grad():
+ #test
+ results_dict_exp= self.audio2exp_model.test(batch)
+ exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64
+
+ #for class_id in range(1):
+ #class_id = 0#(i+10)%45
+ #class_id = random.randint(0,46) #46 styles can be selected
+ batch['class'] = torch.LongTensor([pose_style]).to(self.device)
+ results_dict_pose = self.audio2pose_model.test(batch)
+ pose_pred = results_dict_pose['pose_pred'] #bs T 6
+
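+            # smooth the predicted head pose with a Savitzky-Golay filter; the window length must be odd
+            # and not exceed the sequence length, hence the shorter fallback window for clips under 13 frames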
+ pose_len = pose_pred.shape[1]
+ if pose_len<13:
+ pose_len = int((pose_len-1)/2)*2+1
+ pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device)
+ else:
+ pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device)
+
+ coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70
+
+ coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy()
+
+ if ref_pose_coeff_path is not None:
+ coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path)
+
+ savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])),
+ {'coeff_3dmm': coeffs_pred_numpy})
+
+ return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name']))
+
+ def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
+ num_frames = coeffs_pred_numpy.shape[0]
+ refpose_coeff_dict = loadmat(ref_pose_coeff_path)
+ refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70]
+ refpose_num_frames = refpose_coeff.shape[0]
+ if refpose_num_frames < num_frames:
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
+ if hp.symmetric_mels:
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
+ else:
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+def _denormalize(D):
+ if hp.allow_clipping_in_normalization:
+ if hp.symmetric_mels:
+ return (((np.clip(D, -hp.max_abs_value,
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ + hp.min_level_db)
+ else:
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
+
+ if hp.symmetric_mels:
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
+ else:
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
diff --git a/src/utils/croper.py b/src/utils/croper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9a0ac58f97afdc95d40f2a400272b11fe38093
--- /dev/null
+++ b/src/utils/croper.py
@@ -0,0 +1,144 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import scipy
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+
+from src.face3d.extract_kp_videos_safe import KeypointExtractor
+from facexlib.alignment import landmark_98_to_68
+
+import numpy as np
+from PIL import Image
+
+class Preprocesser:
+ def __init__(self, device='cuda'):
+ self.predictor = KeypointExtractor(device)
+
+ def get_landmark(self, img_np):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ with torch.no_grad():
+ dets = self.predictor.det_net.detect_faces(img_np, 0.97)
+
+ if len(dets) == 0:
+ return None
+ det = dets[0]
+
+ img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
+ lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0]
+
+ #### keypoints to the original location
+ lm[:,0] += int(det[0])
+ lm[:,1] += int(det[1])
+
+ return lm
+
+ def align_face(self, img, lm, output_size=1024):
+ """
+ :param img: PIL Image of the source frame
+ :param lm: np.array shape=(68, 2) facial landmarks
+ :return: (rsize, crop, quad) parameters describing the aligned crop
+ """
+ lm_chin = lm[0: 17] # left-right
+ lm_eyebrow_left = lm[17: 22] # left-right
+ lm_eyebrow_right = lm[22: 27] # left-right
+ lm_nose = lm[27: 31] # top-down
+ lm_nostrils = lm[31: 36] # top-down
+ lm_eye_left = lm[36: 42] # left-clockwise
+ lm_eye_right = lm[42: 48] # left-clockwise
+ lm_mouth_outer = lm[48: 60] # left-clockwise
+ lm_mouth_inner = lm[60: 68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # horizontal crop axis: eye-to-eye vector combined with the rotated eye-to-mouth vector
+ x /= np.hypot(*x) # normalize to unit length (np.hypot gives the Euclidean norm of the vector)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # scale by the larger of the eye-to-eye and eye-to-mouth distances (the reference scale)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # crop quadrilateral: four corners offset from the face center c along x and y
+ qsize = np.hypot(*x) * 2 # side length of the quadrilateral, twice the reference scale
+
+ # Shrink.
+ # If the computed quadrilateral is too large, shrink the image proportionally.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
+ quad /= shrink
+ qsize /= shrink
+ else:
+ rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ # img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+ # if enable_padding and max(pad) > border - 4:
+ # pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # h, w, _ = img.shape
+ # y, x, _ = np.ogrid[:h, :w, :1]
+ # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+ # blur = qsize * 0.02
+ # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ # quad += pad[:2]
+
+ # Transform.
+ quad = (quad + 0.5).flatten()
+ lx = max(min(quad[0], quad[2]), 0)
+ ly = max(min(quad[1], quad[7]), 0)
+ rx = min(max(quad[4], quad[6]), img.size[0])
+        ry = min(max(quad[3], quad[5]), img.size[1])  # bound the y coordinate by the image height, not the width
+
+        # Return the crop geometry (no aligned image is saved here).
+ return rsize, crop, [lx, ly, rx, ry]
+
+ def crop(self, img_np_list, still=False, xsize=512): # first frame for all video
+ img_np = img_np_list[0]
+ lm = self.get_landmark(img_np)
+
+ if lm is None:
+            raise ValueError('cannot detect facial landmarks in the source image')
+ rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ for _i in range(len(img_np_list)):
+ _inp = img_np_list[_i]
+ _inp = cv2.resize(_inp, (rsize[0], rsize[1]))
+ _inp = _inp[cly:cry, clx:crx]
+ if not still:
+ _inp = _inp[ly:ry, lx:rx]
+ img_np_list[_i] = _inp
+ return img_np_list, crop, quad
+
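For reference, a minimal usage sketch of the cropper above, assuming the face-detection weights that Preprocesser loads are already downloaded; the example image is one of the files tracked under examples/ and is otherwise illustrative:

    import cv2
    from src.utils.croper import Preprocesser

    preprocesser = Preprocesser(device='cuda')   # or 'cpu'
    img = cv2.cvtColor(cv2.imread('examples/source_image/art_3.png'), cv2.COLOR_BGR2RGB)

    # crop() expects a list of RGB frames; for a single image pass a one-element list
    frames, crop_box, quad_box = preprocesser.crop([img], still=False, xsize=512)
    print(frames[0].shape, crop_box, quad_box)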
diff --git a/src/utils/face_enhancer.py b/src/utils/face_enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..15851a15966c963d7bd04f35eebdaa6b22a3d966
--- /dev/null
+++ b/src/utils/face_enhancer.py
@@ -0,0 +1,123 @@
+import os
+import torch
+
+from gfpgan import GFPGANer
+
+from tqdm import tqdm
+
+from src.utils.videoio import load_video_to_cv2
+
+import cv2
+
+
+class GeneratorWithLen(object):
+ """ From https://stackoverflow.com/a/7460929 """
+
+ def __init__(self, gen, length):
+ self.gen = gen
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ return self.gen
+
+def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+ return list(gen)
+
+def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+    """ Provide a generator with a __len__ method so that it can be passed to
+    functions that call len()."""
+
+ if os.path.isfile(images): # handle video to images
+ # TODO: Create a generator version of load_video_to_cv2
+ images = load_video_to_cv2(images)
+
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+ gen_with_len = GeneratorWithLen(gen, len(images))
+ return gen_with_len
+
+def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+    """ Provide a generator so that the enhanced images do not all need to be held
+    in memory at the same time. This saves a large amount of RAM compared to
+    enhancer_list, which materializes every frame. """
+
+ print('face enhancer....')
+ if not isinstance(images, list) and os.path.isfile(images): # handle video to images
+ images = load_video_to_cv2(images)
+
+ # ------------------------ set up GFPGAN restorer ------------------------
+ if method == 'gfpgan':
+ arch = 'clean'
+ channel_multiplier = 2
+ model_name = 'GFPGANv1.4'
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
+ elif method == 'RestoreFormer':
+ arch = 'RestoreFormer'
+ channel_multiplier = 2
+ model_name = 'RestoreFormer'
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
+ elif method == 'codeformer': # TODO:
+ arch = 'CodeFormer'
+ channel_multiplier = 2
+ model_name = 'CodeFormer'
+ url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+ else:
+ raise ValueError(f'Wrong model version {method}.')
+
+
+ # ------------------------ set up background upsampler ------------------------
+ if bg_upsampler == 'realesrgan':
+ if not torch.cuda.is_available(): # CPU
+ import warnings
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
+ 'If you really want to use it, please modify the corresponding codes.')
+ bg_upsampler = None
+ else:
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from realesrgan import RealESRGANer
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ bg_upsampler = RealESRGANer(
+ scale=2,
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+ model=model,
+ tile=400,
+ tile_pad=10,
+ pre_pad=0,
+ half=True) # need to set False in CPU mode
+ else:
+ bg_upsampler = None
+
+ # determine model paths
+ model_path = os.path.join('gfpgan/weights', model_name + '.pth')
+
+ if not os.path.isfile(model_path):
+ model_path = os.path.join('checkpoints', model_name + '.pth')
+
+ if not os.path.isfile(model_path):
+ # download pre-trained models from url
+ model_path = url
+
+ restorer = GFPGANer(
+ model_path=model_path,
+ upscale=2,
+ arch=arch,
+ channel_multiplier=channel_multiplier,
+ bg_upsampler=bg_upsampler)
+
+ # ------------------------ restore ------------------------
+ for idx in tqdm(range(len(images)), 'Face Enhancer:'):
+
+ img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
+
+ # restore faces and background if necessary
+ cropped_faces, restored_faces, r_img = restorer.enhance(
+ img,
+ has_aligned=False,
+ only_center_face=False,
+ paste_back=True)
+
+ r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
+ yield r_img
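A short usage sketch of the generator interface above; writing the enhanced frames out with imageio (and its ffmpeg plugin) is an assumption of this example, not something the module itself requires, and the file paths are illustrative:

    import imageio
    from src.utils.face_enhancer import enhancer_generator_with_len

    # enhance every frame of a rendered clip without holding them all in memory
    frames = enhancer_generator_with_len('results/generated.mp4',
                                         method='gfpgan', bg_upsampler='realesrgan')
    writer = imageio.get_writer('results/generated_enhanced.mp4', fps=25)
    for frame in frames:              # RGB uint8 arrays, yielded one at a time
        writer.append_data(frame)
    writer.close()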
diff --git a/src/utils/hparams.py b/src/utils/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..743c5c7d5a5a9e686f1ccd6fb3c2fb5cb382d62b
--- /dev/null
+++ b/src/utils/hparams.py
@@ -0,0 +1,160 @@
+from glob import glob
+import os
+
+class HParams:
+ def __init__(self, **kwargs):
+ self.data = {}
+
+ for key, value in kwargs.items():
+ self.data[key] = value
+
+ def __getattr__(self, key):
+ if key not in self.data:
+ raise AttributeError("'HParams' object has no attribute %s" % key)
+ return self.data[key]
+
+ def set_hparam(self, key, value):
+ self.data[key] = value
+
+
+# Default hyperparameters
+hparams = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+    # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+    # Does not work if n_fft is not a multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=16,
+ initial_learning_rate=1e-4,
+ nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=20,
+ checkpoint_interval=3000,
+ eval_interval=3000,
+ writer_interval=300,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=1000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+
+# Default hyperparameters
+hparamsdebug = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+    # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+    # Does not work if n_fft is not a multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=2,
+ initial_learning_rate=1e-3,
+ nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=0,
+ checkpoint_interval=10000,
+ eval_interval=10,
+ writer_interval=5,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=10000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+def hparams_debug_string():
+    values = hparams.data  # HParams stores everything in .data; it has no values() method
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
+ return "Hyperparameters:\n" + "\n".join(hp)
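A small sketch of how these hyperparameters are consumed; attribute reads go through HParams.__getattr__ and set_hparam overrides a single value in place:

    from src.utils.hparams import hparams, hparams_debug_string

    print(hparams.sample_rate, hparams.hop_size)   # 16000, 200
    hparams.set_hparam('batch_size', 8)            # override a single value at runtime
    print(hparams_debug_string())                  # dump every current value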
diff --git a/src/utils/init_path.py b/src/utils/init_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..18ca81eb81f564f44fd376667168807e4e976a36
--- /dev/null
+++ b/src/utils/init_path.py
@@ -0,0 +1,49 @@
+import os
+import glob
+
+def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):
+
+ if old_version:
+ #### load all the checkpoint of `pth`
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ use_safetensor = False
+ elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):
+ print('using safetensor as default')
+ sadtalker_paths = {
+ "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'),
+ }
+ use_safetensor = True
+ else:
+        print("WARNING: the new version of the model is distributed as a safetensors file, which you may need to download manually. Falling back to the old checkpoints this time.")
+ use_safetensor = False
+
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting'
+ sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml')
+ sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml')
+ sadtalker_paths['pirender_yaml_path'] = os.path.join(config_dir, 'facerender_pirender.yaml')
+ sadtalker_paths['pirender_checkpoint'] = os.path.join(checkpoint_dir, 'epoch_00190_iteration_000400000_checkpoint.pt')
+ sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml')
+
+ if 'full' in preprocess:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml')
+ else:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml')
+
+ return sadtalker_paths
\ No newline at end of file
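A usage sketch of the path resolver above; the directories follow the layout referenced elsewhere in this diff (checkpoints/ and src/config/) and are otherwise illustrative:

    from src.utils.init_path import init_path

    sadtalker_paths = init_path('./checkpoints', './src/config', size=256,
                                old_version=False, preprocess='crop')
    print(sadtalker_paths['use_safetensor'])
    print(sadtalker_paths['facerender_yaml'])   # src/config/facerender.yaml in 'crop' mode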
diff --git a/src/utils/model2safetensor.py b/src/utils/model2safetensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c485000d43ba9c230a0bc64ce8aeaaec6e2b29
--- /dev/null
+++ b/src/utils/model2safetensor.py
@@ -0,0 +1,141 @@
+import torch
+import yaml
+import os
+
+import safetensors
+from safetensors.torch import save_file
+from yacs.config import CfgNode as CN
+import sys
+
+sys.path.append('/apdcephfs/private_shadowcun/SadTalker')
+
+from src.face3d.models import networks
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2
+from src.test_audio2coeff import load_cpk
+
+size = 256
+############ face vid2vid
+config_path = os.path.join('src', 'config', 'facerender.yaml')
+current_root_path = '.'
+
+path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
+net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')
+checkpoint = torch.load(path_of_net_recon_model, map_location='cpu')
+net_recon.load_state_dict(checkpoint['net_recon'])
+
+with open(config_path) as f:
+ config = yaml.safe_load(f)
+
+generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+ **config['model_params']['common_params'])
+he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+ **config['model_params']['common_params'])
+mapping = MappingNet(**config['model_params']['mapping_params'])
+
+def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None,
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
+ optimizer_discriminator=None, optimizer_kp_detector=None,
+ optimizer_he_estimator=None, device="cpu"):
+
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if generator is not None:
+ generator.load_state_dict(checkpoint['generator'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if he_estimator is not None:
+ he_estimator.load_state_dict(checkpoint['he_estimator'])
+ if discriminator is not None:
+ try:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+        except (KeyError, RuntimeError):
+            print('No discriminator in the state-dict. Discriminator will be randomly initialized')
+ if optimizer_generator is not None:
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+ if optimizer_discriminator is not None:
+ try:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+        except (KeyError, RuntimeError):
+            print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
+ if optimizer_kp_detector is not None:
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+ if optimizer_he_estimator is not None:
+ optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+ return checkpoint['epoch']
+
+
+def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
+free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'
+load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+
+wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
+
+audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
+audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
+
+audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
+audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
+
+with open(audio2pose_yaml_path) as fcfg_pose:
+    cfg_pose = CN.load_cfg(fcfg_pose)
+cfg_pose.freeze()
+audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)
+audio2pose_model.eval()
+load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')
+
+# load audio2exp_model
+netG = SimpleWrapperV2()
+netG.eval()
+load_cpk(audio2exp_checkpoint, model=netG, device='cpu')
+
+class SadTalker(torch.nn.Module):
+ def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):
+ super(SadTalker, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.audio2exp = netG
+ self.audio2pose = audio2pose
+ self.face_3drecon = face_3drecon
+
+
+model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)
+
+# here, we want to convert it to safetensor
+save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors")
+
+### test
+load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None)
\ No newline at end of file
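A quick way to sanity-check the converted file produced by this script, using safe_open from the safetensors package already imported above; the path assumes the default size=256 output location:

    from safetensors import safe_open

    # list the flattened keys (kp_extractor.*, generator.*, audio2exp.*, ...) of the converted file
    with safe_open('checkpoints/SadTalker_V0.0.2_256.safetensors', framework='pt', device='cpu') as f:
        for key in sorted(f.keys())[:10]:
            print(key, tuple(f.get_tensor(key).shape))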
diff --git a/src/utils/paste_pic.py b/src/utils/paste_pic.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9989e21e48e64f620f9b148e65fdfe806c53b14
--- /dev/null
+++ b/src/utils/paste_pic.py
@@ -0,0 +1,69 @@
+import cv2, os
+import numpy as np
+from tqdm import tqdm
+import uuid
+
+from src.utils.videoio import save_video_with_watermark
+
+def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = cv2.imread(pic_path)
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(pic_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+        # only the first frame is needed as the paste-back background
+        still_reading, frame = video_stream.read()
+        video_stream.release()
+        if not still_reading:
+            raise ValueError('cannot read a frame from %s' % pic_path)
+        full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ video_stream = cv2.VideoCapture(video_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ crop_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ crop_frames.append(frame)
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_path = str(uuid.uuid4())+'.mp4'
+ out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+ for crop_frame in tqdm(crop_frames, 'seamlessClone:'):
+ p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1))
+
+ mask = 255*np.ones(p.shape, p.dtype)
+ location = ((ox1+ox2) // 2, (oy1+oy2) // 2)
+ gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE)
+ out_tmp.write(gen_img)
+
+ out_tmp.release()
+
+ save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False)
+ os.remove(tmp_path)
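The core operation above is cv2.seamlessClone, which blends the resized face patch back into the full image. A standalone sketch of that call on synthetic arrays (the real code builds the patch and centre from crop_info):

    import cv2
    import numpy as np

    background = np.full((256, 256, 3), 200, dtype=np.uint8)   # plain grey canvas
    patch = np.zeros((64, 64, 3), dtype=np.uint8)
    patch[:] = (0, 0, 255)                                      # red square (BGR)

    mask = 255 * np.ones(patch.shape, patch.dtype)              # clone the whole patch
    center = (128, 128)                                         # where the patch centre lands in the background
    blended = cv2.seamlessClone(patch, background, mask, center, cv2.NORMAL_CLONE)
    cv2.imwrite('blended.png', blended)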
diff --git a/src/utils/preprocess.py b/src/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f784e6c3d8562e1db1bbd850b9f01843cee3c97
--- /dev/null
+++ b/src/utils/preprocess.py
@@ -0,0 +1,170 @@
+import numpy as np
+import cv2, os, sys, torch
+from tqdm import tqdm
+from PIL import Image
+
+# 3dmm extraction
+import safetensors
+import safetensors.torch
+from src.face3d.util.preprocess import align_img
+from src.face3d.util.load_mats import load_lm3d
+from src.face3d.models import networks
+
+from scipy.io import loadmat, savemat
+from src.utils.croper import Preprocesser
+
+
+import warnings
+
+from src.utils.safetensor_helper import load_x_from_safetensor
+warnings.filterwarnings("ignore")
+
+def split_coeff(coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+        coeffs           -- torch.tensor, size (B, 257)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+
+
+class CropAndExtract():
+ def __init__(self, sadtalker_path, device):
+
+ self.propress = Preprocesser(device)
+ self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
+
+ if sadtalker_path['use_safetensor']:
+ checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))
+ else:
+ checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device))
+ self.net_recon.load_state_dict(checkpoint['net_recon'])
+
+ self.net_recon.eval()
+ self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting'])
+ self.device = device
+
+ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256):
+
+ pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]
+
+ landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt')
+ coeff_path = os.path.join(save_dir, pic_name+'.mat')
+ png_path = os.path.join(save_dir, pic_name+'.png')
+
+ #load input
+ if not os.path.isfile(input_path):
+ raise ValueError('input_path must be a valid path to video/image file')
+ elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_frames = [cv2.imread(input_path)]
+ fps = 25
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(frame)
+ if source_image_flag:
+ break
+
+ x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
+
+        #### crop the input frames according to the requested preprocess mode
+ if 'crop' in crop_or_resize.lower(): # default crop
+            x_full_frames, crop, quad = self.propress.crop(x_full_frames, still='ext' in crop_or_resize.lower(), xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ elif 'full' in crop_or_resize.lower():
+            x_full_frames, crop, quad = self.propress.crop(x_full_frames, still='ext' in crop_or_resize.lower(), xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ else: # resize mode
+ oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1]
+ crop_info = ((ox2 - ox1, oy2 - oy1), None, None)
+
+ frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]
+ if len(frames_pil) == 0:
+ print('No face is detected in the input file')
+            return None, None, None  # keep the arity consistent with the normal return value
+
+ # save crop info
+        # save the cropped/resized frame as a reference png (each iteration overwrites the same path)
+ cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+
+ # 2. get the landmark according to the detected face.
+ if not os.path.isfile(landmarks_path):
+ lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)
+ else:
+ print(' Using saved landmarks.')
+ lm = np.loadtxt(landmarks_path).astype(np.float32)
+ lm = lm.reshape([len(x_full_frames), -1, 2])
+
+ if not os.path.isfile(coeff_path):
+            # extract 3DMM parameters with the Deep3DFaceRecon_pytorch network
+ video_coeffs, full_coeffs = [], []
+ for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'):
+ frame = frames_pil[idx]
+ W,H = frame.size
+ lm1 = lm[idx].reshape([-1, 2])
+
+ if np.mean(lm1) == -1:
+ lm1 = (self.lm3d_std[:, :2]+1)/2.
+ lm1 = np.concatenate(
+ [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+ )
+ else:
+ lm1[:, -1] = H - 1 - lm1[:, -1]
+
+ trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+
+ trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
+ im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
+
+ with torch.no_grad():
+ full_coeff = self.net_recon(im_t)
+ coeffs = split_coeff(full_coeff)
+
+ pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+
+ pred_coeff = np.concatenate([
+ pred_coeff['exp'],
+ pred_coeff['angle'],
+ pred_coeff['trans'],
+ trans_params[2:][None],
+ ], 1)
+ video_coeffs.append(pred_coeff)
+ full_coeffs.append(full_coeff.cpu().numpy())
+
+ semantic_npy = np.array(video_coeffs)[:,0]
+
+ savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]})
+
+ return coeff_path, png_path, crop_info
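A usage sketch of CropAndExtract, assuming the checkpoints referenced by init_path are present on disk; the image path and output directory are illustrative:

    import os
    from src.utils.init_path import init_path
    from src.utils.preprocess import CropAndExtract

    sadtalker_paths = init_path('./checkpoints', './src/config', size=256)
    preprocess_model = CropAndExtract(sadtalker_paths, device='cuda')

    os.makedirs('./results', exist_ok=True)
    # returns the .mat of 3DMM coefficients, the cropped reference png and the crop geometry
    coeff_path, png_path, crop_info = preprocess_model.generate(
        'examples/source_image/art_3.png', './results',
        crop_or_resize='crop', source_image_flag=True, pic_size=256)
    print(coeff_path, png_path, crop_info[0])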
diff --git a/src/utils/safetensor_helper.py b/src/utils/safetensor_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cdbdd21e4ed656dfe2d31a57360afb3e96480b3
--- /dev/null
+++ b/src/utils/safetensor_helper.py
@@ -0,0 +1,8 @@
+
+
+def load_x_from_safetensor(checkpoint, key):
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if key in k:
+ x_generator[k.replace(key+'.', '')] = v
+ return x_generator
\ No newline at end of file
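A tiny illustration of what load_x_from_safetensor does: it keeps the entries whose key contains the given prefix and strips that prefix, matching the flattened state dict saved by model2safetensor.py (the dict below is synthetic):

    import torch
    from src.utils.safetensor_helper import load_x_from_safetensor

    flat = {
        'generator.conv.weight': torch.zeros(3),
        'kp_extractor.down.weight': torch.zeros(3),
        'audio2exp.fc.bias': torch.zeros(3),
    }
    print(load_x_from_safetensor(flat, 'kp_extractor'))
    # {'down.weight': tensor([0., 0., 0.])}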
diff --git a/src/utils/text2speech.py b/src/utils/text2speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..00d165b6cc7774fd200929aafa0ff3b15916111e
--- /dev/null
+++ b/src/utils/text2speech.py
@@ -0,0 +1,20 @@
+import os
+import tempfile
+from TTS.api import TTS
+
+
+class TTSTalker():
+ def __init__(self) -> None:
+ model_name = TTS.list_models()[0]
+ self.tts = TTS(model_name)
+
+ def test(self, text, language='en'):
+
+        tempf = tempfile.NamedTemporaryFile(
+            delete=False,
+            suffix='.wav',
+        )
+
+ self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name)
+
+ return tempf.name
\ No newline at end of file
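A short usage sketch of TTSTalker above; it only exercises the module's own test() method and lets the coqui TTS package download its default model on first use:

    from src.utils.text2speech import TTSTalker

    talker = TTSTalker()
    wav_path = talker.test('Hello from SadTalker.', language='en')
    print('synthesized audio written to', wav_path)   # a temporary .wav file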
diff --git a/src/utils/videoio.py b/src/utils/videoio.py
new file mode 100644
index 0000000000000000000000000000000000000000..d16ee667713a16e3f9644fcc3cb3e023bc2c9102
--- /dev/null
+++ b/src/utils/videoio.py
@@ -0,0 +1,41 @@
+import shutil
+import uuid
+
+import os
+
+import cv2
+
+def load_video_to_cv2(input_path):
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ return full_frames
+
+def save_video_with_watermark(video, audio, save_path, watermark=False):
+ temp_file = str(uuid.uuid4())+'.mp4'
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec mpeg4 "%s"' % (video, audio, temp_file)
+ os.system(cmd)
+
+ if watermark is False:
+ shutil.move(temp_file, save_path)
+ else:
+ # watermark
+ try:
+            ##### check whether we are running inside stable-diffusion-webui
+            import webui
+            from modules import paths
+            watermark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
+        except Exception:
+            # fall back to the root path of the sadtalker repository
+            dir_path = os.path.dirname(os.path.realpath(__file__))
+            watermark_path = dir_path+"/../../docs/sadtalker_logo.png"
+
+        cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watermark_path, save_path)
+ os.system(cmd)
+ os.remove(temp_file)
\ No newline at end of file
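A short sketch tying the two helpers above together; ffmpeg must be on PATH since save_video_with_watermark shells out to it, and the file names are illustrative:

    from src.utils.videoio import load_video_to_cv2, save_video_with_watermark

    frames = load_video_to_cv2('results/face.mp4')        # list of RGB uint8 frames
    print(len(frames), frames[0].shape)

    # mux the silent video with a driving audio track (no watermark overlay)
    save_video_with_watermark('results/face.mp4', 'examples/driven_audio/imagine.wav',
                              'results/face_with_audio.mp4', watermark=False)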
diff --git a/tts_voice.py b/tts_voice.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ee194c252f82ada41ccc14f33adb592e1a00985
--- /dev/null
+++ b/tts_voice.py
@@ -0,0 +1,26 @@
+tts_order_voice = {'英语 (美国)-Jenny-女': 'en-US-JennyNeural',
+ '英语 (美国)-Guy-男': 'en-US-GuyNeural',
+ '英语 (美国)-Ana-女': 'en-US-AnaNeural',
+ '英语 (美国)-Aria-女': 'en-US-AriaNeural',
+ '英语 (美国)-Christopher-男': 'en-US-ChristopherNeural',
+ '英语 (美国)-Eric-男': 'en-US-EricNeural',
+ '英语 (美国)-Michelle-女': 'en-US-MichelleNeural',
+ '英语 (美国)-Roger-男': 'en-US-RogerNeural',
+ '韩语 (韩国)-Sun-Hi-女': 'ko-KR-SunHiNeural',
+ '韩语 (韩国)-InJoon-男': 'ko-KR-InJoonNeural',
+ '日语 (日本)-Nanami-女': 'ja-JP-NanamiNeural',
+ '日语 (日本)-Keita-男': 'ja-JP-KeitaNeural',
+ '普通话 (中国大陆)-Xiaoxiao-女': 'zh-CN-XiaoxiaoNeural',
+ '普通话 (中国大陆)-Yunyang-男': 'zh-CN-YunyangNeural',
+ '普通话 (中国大陆)-Yunxi-男': 'zh-CN-YunxiNeural',
+ '普通话 (中国大陆)-Xiaoyi-女': 'zh-CN-XiaoyiNeural',
+ '普通话 (中国大陆)-Yunjian-男': 'zh-CN-YunjianNeural',
+ '普通话 (中国大陆)-Yunxia-男': 'zh-CN-YunxiaNeural',
+ '东北话 (中国大陆)-Xiaobei-女': 'zh-CN-liaoning-XiaobeiNeural',
+ '中原官话 (中国陕西)-Xiaoni-女': 'zh-CN-shaanxi-XiaoniNeural',
+ '粤语 (中国香港)-HiuMaan-女': 'zh-HK-HiuMaanNeural',
+ '粤语 (中国香港)-HiuGaai-女': 'zh-HK-HiuGaaiNeural',
+ '粤语 (中国香港)-WanLung-男': 'zh-HK-WanLungNeural',
+ '台湾普通话-HsiaoChen-女': 'zh-TW-HsiaoChenNeural',
+ '台湾普通话-HsiaoYu-女': 'zh-TW-HsiaoYuNeural',
+ '台湾普通话-YunJhe-男': 'zh-TW-YunJheNeural'}
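The mapping above pairs a human-readable label with a neural voice short name. A hedged sketch of how it might be consumed with the edge-tts package, which is an assumption of this example and is not referenced elsewhere in this diff:

    import asyncio
    import edge_tts                                  # assumption: pip install edge-tts

    from tts_voice import tts_order_voice

    voice = tts_order_voice['英语 (美国)-Jenny-女']    # -> 'en-US-JennyNeural'

    async def synthesize(text, voice, out_path='tts_output.mp3'):
        # Communicate streams the synthesized audio; save() writes it to disk
        await edge_tts.Communicate(text, voice).save(out_path)

    asyncio.run(synthesize('Hello from SadTalker.', voice))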
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd5b6185af6c9f1c270b8ba345bfc36d059e081
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,305 @@
+import os
+import sys
+import glob
+import argparse
+import logging
+import json
+import subprocess
+import numpy as np
+import torch
+import torchvision
+from scipy.io.wavfile import read
+from torch.nn import functional as F
+from commons import sequence_mask
+# NOTE: get_cmodel() and get_vocoder() below additionally assume that the WavLM and
+# hifigan packages are importable from the project root (WavLM, WavLMConfig and
+# hifigan.Generator); they are not imported in this file.
+
+MATPLOTLIB_FLAG = False
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+logger = logging
+
+
+def get_cmodel(rank):
+ checkpoint = torch.load('wavlm/WavLM-Large.pt')
+ cfg = WavLMConfig(checkpoint['cfg'])
+ cmodel = WavLM(cfg).cuda(rank)
+ cmodel.load_state_dict(checkpoint['model'])
+ cmodel.eval()
+ return cmodel
+
+
+def get_content(cmodel, y):
+ with torch.no_grad():
+ c = cmodel.extract_features(y.squeeze(1))[0]
+ c = c.transpose(1, 2)
+ return c
+
+
+def get_vocoder(rank):
+ with open("hifigan/config.json", "r") as f:
+ config = json.load(f)
+ config = hifigan.AttrDict(config)
+ vocoder = hifigan.Generator(config)
+ ckpt = torch.load("hifigan/generator_v1")
+ vocoder.load_state_dict(ckpt["generator"])
+ vocoder.eval()
+ vocoder.remove_weight_norm()
+ vocoder.cuda(rank)
+ return vocoder
+
+
+def transform(mel, height): # 68-92
+ #r = np.random.random()
+ #rate = r * 0.3 + 0.85 # 0.85-1.15
+ #height = int(mel.size(-2) * rate)
+ tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
+ if height >= mel.size(-2):
+ return tgt[:, :mel.size(-2), :]
+ else:
+ silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1)
+ silence += torch.randn_like(silence) / 10
+ return torch.cat((tgt, silence), 1)
+
+
+def stretch(mel, width): # 0.5-2
+ return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))
+
+
+def load_checkpoint(checkpoint_path, model, optimizer=None):
+ assert os.path.isfile(checkpoint_path)
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
+ iteration = checkpoint_dict['iteration']
+ learning_rate = checkpoint_dict['learning_rate']
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
+ saved_state_dict = checkpoint_dict['model']
+ if hasattr(model, 'module'):
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        try:
+            new_state_dict[k] = saved_state_dict[k]
+        except KeyError:
+            logger.info("%s is not in the checkpoint" % k)
+            new_state_dict[k] = v
+ if hasattr(model, 'module'):
+ model.module.load_state_dict(new_state_dict)
+ else:
+ model.load_state_dict(new_state_dict)
+ logger.info("Loaded checkpoint '{}' (iteration {})" .format(
+ checkpoint_path, iteration))
+ return model, optimizer, learning_rate, iteration
+
+
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
+ iteration, checkpoint_path))
+ if hasattr(model, 'module'):
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+ torch.save({'model': state_dict,
+ 'iteration': iteration,
+ 'optimizer': optimizer.state_dict(),
+ 'learning_rate': learning_rate}, checkpoint_path)
+
+
+def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
+ for k, v in scalars.items():
+ writer.add_scalar(k, v, global_step)
+ for k, v in histograms.items():
+ writer.add_histogram(k, v, global_step)
+ for k, v in images.items():
+ writer.add_image(k, v, global_step, dataformats='HWC')
+ for k, v in audios.items():
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+ f_list = glob.glob(os.path.join(dir_path, regex))
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+ x = f_list[-1]
+ print(x)
+ return x
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+ global MATPLOTLIB_FLAG
+ if not MATPLOTLIB_FLAG:
+ import matplotlib
+ matplotlib.use("Agg")
+ MATPLOTLIB_FLAG = True
+ mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger.setLevel(logging.WARNING)
+ import matplotlib.pylab as plt
+ import numpy as np
+
+ fig, ax = plt.subplots(figsize=(10,2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+ interpolation='none')
+ plt.colorbar(im, ax=ax)
+ plt.xlabel("Frames")
+ plt.ylabel("Channels")
+ plt.tight_layout()
+
+ fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ plt.close()
+ return data
+
+
+def plot_alignment_to_numpy(alignment, info=None):
+ global MATPLOTLIB_FLAG
+ if not MATPLOTLIB_FLAG:
+ import matplotlib
+ matplotlib.use("Agg")
+ MATPLOTLIB_FLAG = True
+ mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger.setLevel(logging.WARNING)
+ import matplotlib.pylab as plt
+ import numpy as np
+
+ fig, ax = plt.subplots(figsize=(6, 4))
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
+ interpolation='none')
+ fig.colorbar(im, ax=ax)
+ xlabel = 'Decoder timestep'
+ if info is not None:
+ xlabel += '\n\n' + info
+ plt.xlabel(xlabel)
+ plt.ylabel('Encoder timestep')
+ plt.tight_layout()
+
+ fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ plt.close()
+ return data
+
+
+def load_wav_to_torch(full_path):
+ sampling_rate, data = read(full_path)
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+def load_filepaths_and_text(filename, split="|"):
+ with open(filename, encoding='utf-8') as f:
+ filepaths_and_text = [line.strip().split(split) for line in f]
+ return filepaths_and_text
+
+
+def get_hparams(init=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
+ help='JSON file for configuration')
+ parser.add_argument('-m', '--model', type=str, required=True,
+ help='Model name')
+
+ args = parser.parse_args()
+ model_dir = os.path.join("./logs", args.model)
+
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+
+ config_path = args.config
+ config_save_path = os.path.join(model_dir, "config.json")
+ if init:
+ with open(config_path, "r") as f:
+ data = f.read()
+ with open(config_save_path, "w") as f:
+ f.write(data)
+ else:
+ with open(config_save_path, "r") as f:
+ data = f.read()
+ config = json.loads(data)
+
+ hparams = HParams(**config)
+ hparams.model_dir = model_dir
+ return hparams
+
+
+def get_hparams_from_dir(model_dir):
+ config_save_path = os.path.join(model_dir, "config.json")
+ with open(config_save_path, "r") as f:
+ data = f.read()
+ config = json.loads(data)
+
+    hparams = HParams(**config)
+ hparams.model_dir = model_dir
+ return hparams
+
+
+def get_hparams_from_file(config_path):
+ with open(config_path, "r") as f:
+ data = f.read()
+ config = json.loads(data)
+
+    hparams = HParams(**config)
+ return hparams
+
+
+def check_git_hash(model_dir):
+ source_dir = os.path.dirname(os.path.realpath(__file__))
+ if not os.path.exists(os.path.join(source_dir, ".git")):
+        logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
+ source_dir
+ ))
+ return
+
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+ path = os.path.join(model_dir, "githash")
+ if os.path.exists(path):
+ saved_hash = open(path).read()
+ if saved_hash != cur_hash:
+            logger.warning("git hash values are different. {}(saved) != {}(current)".format(
+ saved_hash[:8], cur_hash[:8]))
+ else:
+ open(path, "w").write(cur_hash)
+
+
+def get_logger(model_dir, filename="train.log"):
+ global logger
+ logger = logging.getLogger(os.path.basename(model_dir))
+ logger.setLevel(logging.DEBUG)
+
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+ h = logging.FileHandler(os.path.join(model_dir, filename))
+ h.setLevel(logging.DEBUG)
+ h.setFormatter(formatter)
+ logger.addHandler(h)
+ return logger
+
+
+class HParams():
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ if type(v) == dict:
+ v = HParams(**v)
+ self[k] = v
+
+ def keys(self):
+ return self.__dict__.keys()
+
+ def items(self):
+ return self.__dict__.items()
+
+ def values(self):
+ return self.__dict__.values()
+
+ def __len__(self):
+ return len(self.__dict__)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value):
+ return setattr(self, key, value)
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ def __repr__(self):
+ return self.__dict__.__repr__()
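A brief sketch of the HParams container defined above, loading a config the same way get_hparams_from_file does; the inline json is a synthetic stand-in for configs/base.json:

    import json
    from utils import HParams

    config = json.loads('{"train": {"batch_size": 16, "lr": 2e-4}, "model": {"hidden": 192}}')
    hps = HParams(**config)             # nested dicts become nested HParams

    print(hps.train.batch_size)         # 16 -- attribute access
    print('model' in hps)               # True -- __contains__
    print(len(hps), hps.model)          # 2 and the repr of the nested HParams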