diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..4c8dc31416176339d3ef7d347d291ab104bda1a7 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+PandaGPT.pdf filter=lfs diff=lfs merge=lfs -text
+PandaGPT.png filter=lfs diff=lfs merge=lfs -text
+code/assets/videos/world.mp4 filter=lfs diff=lfs merge=lfs -text
+code/pytorchvideo/.github/media/ava_slowfast.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/PandaGPT.pdf b/PandaGPT.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..95f171da1f669abe4c222831c8dfbff01a1503ef
--- /dev/null
+++ b/PandaGPT.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:808a3bb9c27e315246119bc401b802a51270ef65147acce4fe9f29f1d9c25b9a
+size 8340300
diff --git a/PandaGPT.png b/PandaGPT.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9ff0cf375a30983ae23208fc7b15e1030f1a61d
--- /dev/null
+++ b/PandaGPT.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3452d3ab3f66c0e716ad9da4cf87e4540bc8b1675be7983992d2443b9dbced29
+size 1313690
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d65e4faf342d7ab3df6b87d79b84dd52fdfa9776
--- /dev/null
+++ b/README.md
@@ -0,0 +1,249 @@
+
+
+
+
+# PandaGPT: One Model To Instruction-Follow Them All
+
+
+
+
+
+
+
+
+ 🌐 Project Page • 🤗 Online Demo • 🤗 Online Demo-2 (Runs fast for users from mainland China) • 📃 Paper • ⏬ Data • 🤖 Model • 📹 Video
+
+
+
+**Team:** [Yixuan Su](https://yxuansu.github.io/)\* , [Tian Lan](https://github.com/gmftbyGMFTBY)\* , [Huayang Li](https://sites.google.com/view/huayangli)\* , Jialu Xu, Yan Wang, and [Deng Cai](https://jcyk.github.io/)\* (Major contributors\* )
+
+****
+
+## Online Demo Demonstration:
+
+Below, we demonstrate some examples of our online [demo](https://huggingface.co/spaces/GMFTBY/PandaGPT). For more generated examples of PandaGPT, please refer to our [webpage](https://panda-gpt.github.io/) or our [paper](https://github.com/yxuansu/PandaGPT/blob/main/PandaGPT.pdf).
+
+
+
+
+
+(1) In this example, PandaGPT takes an input image and reasons over the user's input.
+
+
+
+
+
+(2) In this example, PandaGPT takes the joint input from two modalities, i.e., (1) an image 👀 of a car and (2) an audio clip 👂 of a thunderstorm.
+
+
+****
+
+
+
+## Catalogue:
+* 1. Introduction
+* 2. Running PandaGPT Demo
+ * 2.1. Environment Installation
+ * 2.2. Prepare ImageBind Checkpoint
+ * 2.3. Prepare Vicuna Checkpoint
+ * 2.4. Prepare Delta Weights of PandaGPT
+ * 2.5. Deploying Demo
+* 3. Train Your Own PandaGPT
+ * 3.1. Data Preparation
+ * 3.2. Training Configurations
+ * 3.3. Training PandaGPT
+* Usage and License Notices
+* Citation
+* Acknowledgments
+
+****
+
+
+
+### 1. Introduction: [Back to Top]
+
+
+
+
+
+**License:** The icons in the image are taken from [this website](https://www.flaticon.com).
+
+
+PandaGPT is the first foundation model capable of following instructions across six modalities, without the need for explicit supervision. It demonstrates a diverse set of multimodal capabilities such as complex understanding/reasoning, knowledge-grounded description, and multi-turn conversation.
+
+PandaGPT is a general-purpose instruction-following model that can both see 👀 and hear 👂. Our pilot experiments show that PandaGPT can perform complex tasks such as detailed image description generation, writing stories inspired by videos, and answering questions about audio. More interestingly, PandaGPT can take multimodal inputs simultaneously and compose their semantics naturally. For example, PandaGPT can connect how objects look in a photo with how they sound in audio.
+
+
+****
+
+
+
+### 2. Running PandaGPT Demo: [Back to Top]
+
+
+
+#### 2.1. Environment Installation:
+To install the required environment, please run
+```
+pip install -r requirements.txt
+```
+
+Then install the PyTorch package with the correct CUDA version, for example:
+```
+pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch/
+```
+
+
+
+#### 2.2. Prepare ImageBind Checkpoint:
+You can download the pre-trained ImageBind model using [this link](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth). After downloading, put the downloaded file (imagebind_huge.pth) in [[./pretrained_ckpt/imagebind_ckpt/]](./pretrained_ckpt/imagebind_ckpt/) directory.
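+
+If you prefer to script the download, here is a minimal sketch (it only assumes the URL and target directory given above):
+```python
+import os
+import urllib.request
+
+# target directory described above; created if it does not exist yet
+ckpt_dir = "./pretrained_ckpt/imagebind_ckpt"
+os.makedirs(ckpt_dir, exist_ok=True)
+
+# fetch the pre-trained ImageBind checkpoint to the expected location (this is a large file)
+url = "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth"
+urllib.request.urlretrieve(url, os.path.join(ckpt_dir, "imagebind_huge.pth"))
+```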
+
+
+
+#### 2.3. Prepare Vicuna Checkpoint:
+To prepare the pre-trained Vicuna model, please follow the instructions provided [[here]](./pretrained_ckpt#1-prepare-vicuna-checkpoint).
+
+
+
+
+#### 2.4. Prepare Delta Weights of PandaGPT:
+
+|**Base Language Model**|**Maximum Sequence Length**|**Huggingface Delta Weights Address**|
+|:-------------:|:-------------:|:-------------:|
+|Vicuna-7B (version 0)|512|[openllmplayground/pandagpt_7b_max_len_512](https://huggingface.co/openllmplayground/pandagpt_7b_max_len_512)|
+|Vicuna-7B (version 0)|1024|[openllmplayground/pandagpt_7b_max_len_1024](https://huggingface.co/openllmplayground/pandagpt_7b_max_len_1024)|
+|Vicuna-13B (version 0)|256|[openllmplayground/pandagpt_13b_max_len_256](https://huggingface.co/openllmplayground/pandagpt_13b_max_len_256)|
+|Vicuna-13B (version 0)|400|[openllmplayground/pandagpt_13b_max_len_400](https://huggingface.co/openllmplayground/pandagpt_13b_max_len_400)|
+
+We release the delta weights of PandaGPT trained with different strategies in the table above. After downloading, put the downloaded 7B/13B delta weights file (pytorch_model.pt) in the [./pretrained_ckpt/pandagpt_ckpt/7b/](./pretrained_ckpt/pandagpt_ckpt/7b/) or [./pretrained_ckpt/pandagpt_ckpt/13b/](./pretrained_ckpt/pandagpt_ckpt/13b/) directory. In our [online demo](https://huggingface.co/spaces/GMFTBY/PandaGPT), we use `openllmplayground/pandagpt_7b_max_len_1024` as our default model due to limited computational resources. Better results are expected when switching to `openllmplayground/pandagpt_13b_max_len_400`.
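+
+As a quick sanity check before deploying the demo, the sketch below verifies that the checkpoints are where the previous steps place them (the Vicuna directory name is an assumption; use whatever directory you prepared in Section 2.3):
+```python
+import os
+
+required = [
+    "./pretrained_ckpt/imagebind_ckpt/imagebind_huge.pth",  # ImageBind encoder (Section 2.2)
+    "./pretrained_ckpt/vicuna_ckpt/7b_v0",                   # merged Vicuna checkpoint (Section 2.3; name may differ)
+    "./pretrained_ckpt/pandagpt_ckpt/7b/pytorch_model.pt",   # PandaGPT delta weights (this section)
+]
+for path in required:
+    print("OK      " if os.path.exists(path) else "MISSING ", path)
+```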
+
+
+
+#### 2.5. Deploying Demo:
+Upon completion of the previous steps, you can run the demo locally as follows:
+```bash
+cd ./code/
+CUDA_VISIBLE_DEVICES=0 python web_demo.py
+```
+
+If you run into a `sample_rate` problem, please install `pytorchvideo` from source as follows:
+```bash
+git clone https://github.com/facebookresearch/pytorchvideo
+cd pytorchvideo
+pip install --editable ./
+```
+
+****
+
+
+
+### 3. Train Your Own PandaGPT: [Back to Top]
+
+**Prerequisites:** Before training the model, make sure the environment is properly installed and the checkpoints of ImageBind and Vicuna are downloaded. You can refer to [here](https://github.com/yxuansu/PandaGPT#2-running-pandagpt-demo-back-to-top) for more information.
+
+
+
+#### 3.1. Data Preparation:
+
+**Disclaimer:** To ensure the reproducibility of our results, we have released our training dataset. The dataset must be used for research purposes only. The use of the dataset must comply with the licenses of the original sources, i.e. LLaVA and MiniGPT-4. These datasets may be taken down when requested by the original authors.
+
+|**Training Task**|**Dataset Address**|
+|:-------------:|:-------------:|
+|Visual Instruction-Following|[openllmplayground/pandagpt_visual_instruction_dataset](https://huggingface.co/datasets/openllmplayground/pandagpt_visual_instruction_dataset)|
+
+After downloading, put the downloaded files under the [./data/](./data/) directory and unzip them.
+
+> The directory should look like:
+
+ .
+ └── ./data/
+ ├── pandagpt4_visual_instruction_data.json
+ └── /images/
+ ├── 000000426538.jpg
+ ├── 000000306060.jpg
+ └── ...
+
+
+
+
+#### 3.2. Training Configurations:
+
+The table below shows the training hyperparameters used in our experiments. The hyperparameters are selected based on the constraints of our computational resources, i.e. 8 x A100 (40G) GPUs.
+
+|**Base Language Model**|**Training Task**|**Epoch Number**|**Batch Size**|**Learning Rate**|**Maximum Length**|
+|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|
+|7B|Visual Instruction|2|64|5e-4|1024|
+|13B|Visual Instruction|2|64|5e-4|400|
+
+
+
+
+
+
+#### 3.3. Training PandaGPT:
+
+To train PandaGPT, please run the following commands:
+```bash
+cd ./code/scripts/
+chmod +x train.sh
+cd ..
+./scripts/train.sh
+```
+
+The key arguments of the training script are as follows:
+* `--data_path`: The path to the JSON file `pandagpt4_visual_instruction_data.json`.
+* `--image_root_path`: The root path of the downloaded images.
+* `--imagebind_ckpt_path`: The directory that contains the ImageBind checkpoint `imagebind_huge.pth`.
+* `--vicuna_ckpt_path`: The directory that contains the pre-trained Vicuna checkpoints.
+* `--max_tgt_len`: The maximum sequence length of training instances.
+* `--save_path`: The directory where the trained delta weights are saved. This directory will be created automatically.
+
+Note that the number of epochs can be set via the `epochs` argument in the [./code/config/openllama_peft.yaml](./code/config/openllama_peft.yaml) file. The `train_micro_batch_size_per_gpu` and `gradient_accumulation_steps` arguments in [./code/dsconfig/openllama_peft_stage_1.json](./code/dsconfig/openllama_peft_stage_1.json) should be set to `2` and `4` for the 7B model, and to `1` and `8` for the 13B model.
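+
+For reference, these values keep the effective batch size equal to the `train_batch_size` of 64 defined in the DeepSpeed config when training on 8 GPUs (a minimal sketch of the arithmetic):
+```python
+# effective batch size = n_gpus * train_micro_batch_size_per_gpu * gradient_accumulation_steps
+n_gpus = 8
+print(n_gpus * 2 * 4)  # 7B setting  -> 64
+print(n_gpus * 1 * 8)  # 13B setting -> 64
+```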
+
+****
+
+
+
+### Usage and License Notices:
+
+PandaGPT is intended and licensed for research use only. The dataset is CC BY-NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes. The delta weights are also CC BY-NC 4.0 (allowing only non-commercial use).
+
+
+****
+
+
+
+### Citation:
+
+If you find PandaGPT useful in your research or applications, please cite it using the following BibTeX:
+```
+@article{su2023pandagpt,
+ title={PandaGPT: One Model To Instruction-Follow Them All},
+ author={Su, Yixuan and Lan, Tian and Li, Huayang and Xu, Jialu and Wang, Yan and Cai, Deng},
+ journal={arXiv preprint arXiv:2305.16355},
+ year={2023}
+}
+```
+
+
+****
+
+
+
+### Acknowledgments:
+
+
+This repo benefits from [OpenAlpaca](https://github.com/yxuansu/OpenAlpaca), [ImageBind](https://github.com/facebookresearch/ImageBind), [LLaVA](https://github.com/haotian-liu/LLaVA), and [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4). Thanks for their wonderful work!
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/code/__pycache__/header.cpython-310.pyc b/code/__pycache__/header.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cc69b89054348a3a17178b7a86f6b1740fe683e
Binary files /dev/null and b/code/__pycache__/header.cpython-310.pyc differ
diff --git a/code/assets/audios/bird_audio.wav b/code/assets/audios/bird_audio.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a98fc72b0df440fd10b3e54c87dfe0ffae0fa12e
Binary files /dev/null and b/code/assets/audios/bird_audio.wav differ
diff --git a/code/assets/audios/car_audio.wav b/code/assets/audios/car_audio.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b71b42a3a375b763521d08855f1a1eebb647a3d2
Binary files /dev/null and b/code/assets/audios/car_audio.wav differ
diff --git a/code/assets/audios/dog_audio.wav b/code/assets/audios/dog_audio.wav
new file mode 100644
index 0000000000000000000000000000000000000000..71d69c77e92039d5906ed766d9c3ca4b181f9ffd
Binary files /dev/null and b/code/assets/audios/dog_audio.wav differ
diff --git a/code/assets/images/bird_image.jpg b/code/assets/images/bird_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..78b10ab1fe76f42e3dda1dc515e69312f02713d9
Binary files /dev/null and b/code/assets/images/bird_image.jpg differ
diff --git a/code/assets/images/car_image.jpg b/code/assets/images/car_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e33288eb765882c594f479bfb35d941fd51a19b1
Binary files /dev/null and b/code/assets/images/car_image.jpg differ
diff --git a/code/assets/images/dog_image.jpg b/code/assets/images/dog_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a54bffa5c80869c6b96246ba29c9e2462c698e3b
Binary files /dev/null and b/code/assets/images/dog_image.jpg differ
diff --git a/code/assets/thermals/190662.jpg b/code/assets/thermals/190662.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bb690a9f7db7568e86077f1017174df551f3c306
Binary files /dev/null and b/code/assets/thermals/190662.jpg differ
diff --git a/code/assets/thermals/210009.jpg b/code/assets/thermals/210009.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f500c611eb76c7dc82d2865513100bee1df99949
Binary files /dev/null and b/code/assets/thermals/210009.jpg differ
diff --git a/code/assets/videos/a.mp4 b/code/assets/videos/a.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..95a61f6b4a753497d97f51c6a8f18727cef7d628
Binary files /dev/null and b/code/assets/videos/a.mp4 differ
diff --git a/code/assets/videos/world.mp4 b/code/assets/videos/world.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9bce44c33e275d6107240a1101032a7835fd8eed
--- /dev/null
+++ b/code/assets/videos/world.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71944d7430c461f0cd6e7fd10cee7eb72786352a3678fc7bc0ae3d410f72aece
+size 1570024
diff --git a/code/config/__init__.py b/code/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..826b6ef41067725c02ac33210e773bb1a8123896
--- /dev/null
+++ b/code/config/__init__.py
@@ -0,0 +1,37 @@
+import yaml
+
+def load_model_config(model, mode):
+ # load special config for each model
+ config_path = f'config/{model}.yaml'
+ print(f'[!] load configuration from {config_path}')
+ with open(config_path) as f:
+ configuration = yaml.load(f, Loader=yaml.FullLoader)
+ new_config = {}
+ for key, value in configuration.items():
+ if key in ['train', 'test', 'validation']:
+ if mode == key:
+ new_config.update(value)
+ else:
+ new_config[key] = value
+ configuration = new_config
+ return configuration
+
+def load_config(args):
+ '''the configuration of each model can rewrite the base configuration'''
+ # base config
+ base_configuration = load_base_config()
+
+ # load one model config
+ configuration = load_model_config(args['model'], args['mode'])
+
+ # update and append the special config for base config
+ base_configuration.update(configuration)
+ configuration = base_configuration
+ return configuration
+
+def load_base_config():
+ config_path = f'config/base.yaml'
+ with open(config_path) as f:
+ configuration = yaml.load(f, Loader=yaml.FullLoader)
+ print(f'[!] load base configuration: {config_path}')
+ return configuration
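+
+# Example usage (a minimal sketch; 'openllama_peft' and the 'train' mode come from
+# config/base.yaml and config/openllama_peft.yaml in this repo):
+#
+#   cfg = load_config({'model': 'openllama_peft', 'mode': 'train'})
+#   # base.yaml values are extended/overridden by the model-specific config,
+#   # and the 'train' section of openllama_peft.yaml is flattened into the result
+#   print(cfg['logging_step'], cfg['epochs'])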
diff --git a/code/config/base.yaml b/code/config/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3385ecadf2b64640bf46a19452c0b43342084cd
--- /dev/null
+++ b/code/config/base.yaml
@@ -0,0 +1,15 @@
+models:
+ openllama:
+ model_name: OpenLLAMAModel
+ agent_name: DeepSpeedAgent
+ stage1_train_dataset: SupervisedDataset
+ test_dataset: SelfInstructTestDataset
+ openllama_peft:
+ model_name: OpenLLAMAPEFTModel
+ agent_name: DeepSpeedAgent
+ stage1_train_dataset: SupervisedDataset
+ test_dataset: SelfInstructTestDataset
+
+# ========= Global configuration ========== #
+logging_step: 5
+# ========= Global configuration ========== #
diff --git a/code/config/openllama_peft.yaml b/code/config/openllama_peft.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1ea52542901ed39498a7ac0c4a2a1a02e950c5a6
--- /dev/null
+++ b/code/config/openllama_peft.yaml
@@ -0,0 +1,22 @@
+# generation hyper-parameters
+max_len: 512
+penalty_alpha: 0.6
+top_k: 10
+top_p: 0.7
+random_prefix_len: 5
+sample_num: 2
+decoding_method: sampling
+generate_len: 512
+
+# lora hyper-parameters
+lora_r: 32
+lora_alpha: 32
+lora_dropout: 0.1
+
+# some train configuration, more can be found under dsconfig folder
+train:
+ seed: 0
+ warmup_rate: 0.1
+ epochs: 2
+ max_length: 1024
+ max_shard_size: 10GB
diff --git a/code/datasets/__init__.py b/code/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66354326c55408308a65f7d9f2dbf56c4555541e
--- /dev/null
+++ b/code/datasets/__init__.py
@@ -0,0 +1,40 @@
+from header import *
+from .samplers import DistributedBatchSampler
+from .sft_dataset import *
+
+'''
+def get_tokenizer(model):
+ tokenizer = LlamaTokenizer.from_pretrained(model)
+ tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2
+ tokenizer.pad_token = tokenizer.eos_token
+ return tokenizer
+'''
+
+def load_sft_dataset(args):
+ '''
+ tokenizer = get_tokenizer(args['model_path'])
+ dataset_name = args['models'][args['model']]['stage1_train_dataset'] # SupervisedDataset, str
+ data_path = args["data_path"]
+ data = globals()[dataset_name](data_path, tokenizer, args['max_length']) #SupervisedDataset
+ '''
+ data = SupervisedDataset(args['data_path'], args['image_root_path'])
+
+ sampler = torch.utils.data.RandomSampler(data)
+ world_size = torch.distributed.get_world_size()
+ rank = torch.distributed.get_rank()
+ batch_size = args['world_size'] * args['dschf'].config['train_micro_batch_size_per_gpu']
+ batch_sampler = DistributedBatchSampler(
+ sampler,
+ batch_size,
+ True,
+ rank,
+ world_size
+ )
+ iter_ = DataLoader(
+ data,
+ batch_sampler=batch_sampler,
+ num_workers=1,
+ collate_fn=data.collate,
+ pin_memory=True
+ )
+ return data, iter_, sampler
diff --git a/code/datasets/samplers.py b/code/datasets/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3ce1e90b2177940acb911d31d1c5245d74a6119
--- /dev/null
+++ b/code/datasets/samplers.py
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""batch samplers that work with either random or sequential data samplers"""
+import math
+import os
+import sys
+
+import torch
+from torch.utils import data
+import numpy as np
+
+
+class RandomSampler(data.sampler.Sampler):
+ r"""
+ Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
+ but this class lets the user set an epoch like DistributedSampler
+ Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+ If with replacement, then user can specify ``num_samples`` to draw.
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ num_samples (int): number of samples to draw, default=len(dataset)
+ replacement (bool): samples are drawn with replacement if ``True``, default=False
+ """
+
+ def __init__(self, data_source, replacement=False, num_samples=None):
+ super(RandomSampler, self).__init__(data_source)
+ self.data_source = data_source
+ self.replacement = replacement
+ self._num_samples = num_samples
+ self.epoch = -1
+
+ if self._num_samples is not None and replacement is False:
+ raise ValueError("With replacement=False, num_samples should not be specified, "
+ "since a random permute will be performed.")
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError("num_samples should be a positive integer "
+ "value, but got num_samples={}".format(self.num_samples))
+ if not isinstance(self.replacement, bool):
+ raise ValueError("replacement should be a boolean value, but got "
+ "replacement={}".format(self.replacement))
+
+ @property
+ def num_samples(self):
+ # dataset size might change at runtime
+ if self._num_samples is None:
+ return len(self.data_source)
+ return self._num_samples
+
+ def __iter__(self):
+ n = len(self.data_source)
+ g = torch.Generator()
+ if self.epoch >= 0:
+ g.manual_seed(self.epoch)
+ if self.replacement:
+ for _ in range(self.num_samples // 32):
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
+ generator=g).tolist()
+ else:
+            yield from torch.randperm(n, generator=g).tolist()
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class DistributedSequentialSampler(data.sampler.Sampler):
+ def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2):
+ super().__init__(num_samples)
+ if rank == -1:
+ rank = 0
+ world_size = 1
+ self.num_samples = num_samples
+ self.rank = rank
+ self.world_size = world_size
+ self.start_iter = 0
+ self.train_iters = train_iters
+ self.batch_size = batch_size
+ self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)]
+
+ def __iter__(self):
+ for idx in range(self.start_iter, self.train_iters * 10):
+ batch = [(idx + bias) % self.num_samples for bias in self.batch_bias]
+ tbatch = self._batch(batch)
+ yield tbatch
+
+ def __len__(self):
+ return self.train_iters
+
+ def _batch(self, batch):
+ """extracts samples only pertaining to this worker's batch"""
+ start = self.rank*self.batch_size//self.world_size
+ end = (self.rank+1)*self.batch_size//self.world_size
+ return batch[start:end]
+
+
+class DistributedBatchSampler(data.sampler.BatchSampler):
+ """
+ similar to normal implementation of distributed sampler, except implementation is at the
+ batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
+ data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
+ """
+ def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None):
+ super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
+ if rank == -1:
+ assert False, 'should not be here'
+ self.rank = rank
+ self.world_size = world_size
+ self.sampler.wrap_around = 0
+ self.wrap_around = 0
+ self.wrap_last = wrap_last
+ self.start_iter = 0
+ self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps
+
+ def __iter__(self):
+ batch = []
+ i = 0
+ for idx in self.data_iterator(self.sampler, wrap_around=False):
+ batch.append(idx)
+ if len(batch) == self.batch_size:
+ tbatch = self._batch(batch)
+ if i >= self.start_iter * self.effective_batch_size:
+ yield tbatch
+ self.start_iter = 0
+ i += len(batch)
+ batch = []
+ batch_len = len(batch)
+ if batch_len > 0 and not self.drop_last:
+ if self.wrap_last:
+ self.sampler.wrap_around -= (self.batch_size)
+ self.wrap_around += (len(batch))
+ self.wrap_around %= self.batch_size
+ yield self._batch(batch)
+ if self.wrap_last:
+ self.sampler.wrap_around += self.batch_size
+
+ def data_iterator(self, _iter, wrap_around=False):
+ """iterates through data and handles wrap around"""
+ for i, idx in enumerate(_iter):
+ if i < self.wrap_around%self.batch_size:
+ continue
+ if wrap_around:
+ self.wrap_around += 1
+ self.wrap_around %= self.batch_size
+ yield idx
+
+ def _batch(self, batch):
+ """extracts samples only pertaining to this worker's batch"""
+ start = self.rank*self.batch_size//self.world_size
+ end = (self.rank+1)*self.batch_size//self.world_size
+ return batch[start:end]
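+
+# Example usage (a minimal sketch mirroring code/datasets/__init__.py): wrap any sampler so
+# that each rank only keeps its shard of every global batch.
+#
+#   sampler = torch.utils.data.RandomSampler(dataset)
+#   batch_sampler = DistributedBatchSampler(sampler, batch_size=64, drop_last=True,
+#                                           rank=torch.distributed.get_rank(),
+#                                           world_size=torch.distributed.get_world_size())
+#   loader = data.DataLoader(dataset, batch_sampler=batch_sampler)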
diff --git a/code/datasets/sft_dataset.py b/code/datasets/sft_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfc64dd37d912d539c0600e5965ad1e5c87a6c1d
--- /dev/null
+++ b/code/datasets/sft_dataset.py
@@ -0,0 +1,65 @@
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import json
+from tqdm import tqdm
+import ipdb
+import random
+from torch.nn.utils.rnn import pad_sequence
+from dataclasses import dataclass, field
+from typing import Callable, Dict, Sequence
+
+import torch
+import torch.distributed as dist
+import transformers
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+class SupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str, image_root_path: str):
+ super(SupervisedDataset, self).__init__()
+
+ with open(data_path, 'r') as f:
+ json_data = json.load(f)
+ # for debug:
+ #json_data = json_data[:100000]
+
+ self.image_path_list, self.caption_list = [], []
+ for item in json_data:
+ one_image_name, one_caption = item["image_name"], item["conversation"]
+ # TODO: stage 2 dataset format is invalid
+ if not one_image_name.endswith('.jpg'):
+ one_image_name += '.jpg'
+ one_image_path = image_root_path + '/{}'.format(one_image_name)
+ self.image_path_list.append(one_image_path)
+ self.caption_list.append(one_caption)
+ print(f'[!] collect {len(self.image_path_list)} samples for training')
+
+ def __len__(self): # number of instances
+ return len(self.image_path_list)
+
+    #def __getitem__(self, i) -> Dict[str, torch.Tensor]: # how to get item, i.e., fetch one sample
+ def __getitem__(self, i):
+ return dict(image_paths=self.image_path_list[i], output_texts=self.caption_list[i])
+
+ def collate(self, instances):
+ image_paths, output_texts = tuple([instance[key] for instance in instances] for key in ("image_paths", "output_texts"))
+ return dict(
+ image_paths=image_paths,
+ output_texts=output_texts
+ )
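+
+# Example usage (a minimal sketch; the paths follow the data layout described in the README):
+#
+#   dataset = SupervisedDataset('./data/pandagpt4_visual_instruction_data.json', './data/images')
+#   item = dataset[0]   # {'image_paths': './data/images/xxx.jpg', 'output_texts': [...conversation...]}
+#   batch = dataset.collate([dataset[0], dataset[1]])  # dict of lists: image_paths, output_texts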
diff --git a/code/dsconfig/openllama_peft_stage_1.json b/code/dsconfig/openllama_peft_stage_1.json
new file mode 100644
index 0000000000000000000000000000000000000000..ff78d81809d08e660c60732d06f27ec9ffea996a
--- /dev/null
+++ b/code/dsconfig/openllama_peft_stage_1.json
@@ -0,0 +1,54 @@
+{
+ "train_batch_size": 64,
+ "train_micro_batch_size_per_gpu": 1,
+ "gradient_accumulation_steps": 8,
+ "steps_per_print": 1,
+ "gradient_clipping": 1.0,
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "cpu"
+ },
+ "contiguous_gradients": true,
+ "allgather_bucket_size": 500000000,
+ "allgather_partitions": true
+ },
+ "fp16": {
+ "enabled": true,
+ "opt_level": "O2",
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enable": true
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 0.0005,
+ "betas": [
+ 0.9,
+ 0.95
+ ],
+ "eps": 1e-8,
+ "weight_decay": 0.001
+ }
+ },
+ "scheduler": {
+ "type": "WarmupDecayLR",
+ "params": {
+ "warmup_min_lr": 0,
+ "warmup_max_lr": 0.0005,
+ "warmup_num_steps": 10,
+ "total_num_steps": 10000
+ }
+ },
+ "activation_checkpointing": {
+ "partition_activations": true,
+ "cpu_checkpointing": true,
+ "contiguous_memory_optimization": false,
+ "number_checkpoints": null,
+ "synchronize_checkpoint_boundary": false,
+ "profile": false
+ }
+
+}
\ No newline at end of file
diff --git a/code/header.py b/code/header.py
new file mode 100644
index 0000000000000000000000000000000000000000..97338165d32d531838566ade9c9217182bb8ea67
--- /dev/null
+++ b/code/header.py
@@ -0,0 +1,35 @@
+import torch
+import datetime
+import types
+import deepspeed
+from transformers.deepspeed import HfDeepSpeedConfig
+import transformers
+import numpy as np
+from collections import OrderedDict
+from torch.utils.data import Dataset, DataLoader
+from torch.nn.utils import clip_grad_norm_
+from torch.cuda.amp import autocast, GradScaler
+from torch.nn import DataParallel
+from torch.optim import lr_scheduler
+import torch.optim as optim
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+import os
+import re
+import math
+import random
+import json
+import time
+import logging
+from copy import deepcopy
+import ipdb
+import argparse
+import data
+from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from torch.nn.utils.rnn import pad_sequence
+from peft import LoraConfig, TaskType, get_peft_model
+
+logging.getLogger("transformers").setLevel(logging.WARNING)
+logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
diff --git a/code/model/ImageBind/CODE_OF_CONDUCT.md b/code/model/ImageBind/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..f913b6a55a6c5ab6e1224e11fc039c3d4c3b6283
--- /dev/null
+++ b/code/model/ImageBind/CODE_OF_CONDUCT.md
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at . All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
\ No newline at end of file
diff --git a/code/model/ImageBind/CONTRIBUTING.md b/code/model/ImageBind/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..63d0b751e8a00b606ddff92e2524faa3c90a63b0
--- /dev/null
+++ b/code/model/ImageBind/CONTRIBUTING.md
@@ -0,0 +1,31 @@
+# Contributing to ImageBind
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Meta's open source projects.
+
+Complete your CLA here:
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to ImageBind, you agree that your contributions will be licensed
+under the [LICENSE](LICENSE) file in the root directory of this source tree.
diff --git a/code/model/ImageBind/LICENSE b/code/model/ImageBind/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..bfef380bf7d9cb74ec9ba533b37c3fbeef3bdc09
--- /dev/null
+++ b/code/model/ImageBind/LICENSE
@@ -0,0 +1,437 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the "Licensor." The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/code/model/ImageBind/README.md b/code/model/ImageBind/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..028fa988bb6cd9843aec9454636e1541b53680e7
--- /dev/null
+++ b/code/model/ImageBind/README.md
@@ -0,0 +1,155 @@
+# ImageBind: One Embedding Space To Bind Them All
+
+**[FAIR, Meta AI](https://ai.facebook.com/research/)**
+
+Rohit Girdhar*,
+Alaaeldin El-Nouby*,
+Zhuang Liu,
+Mannat Singh,
+Kalyan Vasudev Alwala,
+Armand Joulin,
+Ishan Misra*
+
+To appear at CVPR 2023 (*Highlighted paper*)
+
+[[`Paper`](https://facebookresearch.github.io/ImageBind/paper)] [[`Blog`](https://ai.facebook.com/blog/imagebind-six-modalities-binding-ai/)] [[`Demo`](https://imagebind.metademolab.com/)] [[`Supplementary Video`](https://dl.fbaipublicfiles.com/imagebind/imagebind_video.mp4)] [[`BibTex`](#citing-imagebind)]
+
+PyTorch implementation and pretrained models for ImageBind. For details, see the paper: **[ImageBind: One Embedding Space To Bind Them All](https://facebookresearch.github.io/ImageBind/paper)**.
+
+ImageBind learns a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. It enables novel emergent applications ‘out of the box’, including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection, and cross-modal generation.
+
+
+
+
+
+## ImageBind model
+
+Emergent zero-shot classification performance.
+
+| Model          | IN1k | K400 | NYU-D | ESC  | LLVIP | Ego4D | download |
+| :------------- | :--: | :--: | :---: | :--: | :---: | :---: | :------- |
+| imagebind_huge | 77.7 | 50.0 | 54.0  | 66.9 | 63.4  | 25.0  | [checkpoint](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) |
+
+## Usage
+
+Install PyTorch 1.13+ and other third-party dependencies.
+
+```shell
+conda create --name imagebind python=3.8 -y
+conda activate imagebind
+
+pip install -r requirements.txt
+```
+
+For Windows users, you might need to install `soundfile` for reading/writing audio files. (Thanks @congyue1977)
+
+```shell
+pip install soundfile
+```
+
+
+Extract and compare features across modalities (e.g., image, text, and audio).
+
+```python
+import data
+import torch
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+
+text_list=["A dog.", "A car", "A bird"]
+image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
+audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"]
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# Instantiate model
+model, embed_dim = imagebind_model.imagebind_huge(pretrained=True)  # imagebind_huge in this repo returns (model, embedding_dim)
+model.eval()
+model.to(device)
+
+# Load data
+inputs = {
+ ModalityType.TEXT: data.load_and_transform_text(text_list, device),
+ ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
+ ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
+}
+
+with torch.no_grad():
+ embeddings = model(inputs)
+
+print(
+ "Vision x Text: ",
+ torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1),
+)
+print(
+ "Audio x Text: ",
+ torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1),
+)
+print(
+ "Vision x Audio: ",
+ torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1),
+)
+
+# Expected output:
+#
+# Vision x Text:
+# tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
+# [3.3836e-05, 9.9994e-01, 2.4118e-05],
+# [4.7997e-05, 1.3496e-02, 9.8646e-01]])
+#
+# Audio x Text:
+# tensor([[1., 0., 0.],
+# [0., 1., 0.],
+# [0., 0., 1.]])
+#
+# Vision x Audio:
+# tensor([[0.8070, 0.1088, 0.0842],
+# [0.1036, 0.7884, 0.1079],
+# [0.0018, 0.0022, 0.9960]])
+
+```
+
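+The same embeddings support the ‘composing modalities with arithmetic’ use case mentioned above. Below is a minimal sketch that reuses `embeddings`, `text_list`, and `ModalityType` from the example; the equal weighting and the choice of indices are illustrative assumptions, not part of the official API.
+
+```python
+# Sketch: compose a vision embedding with an audio embedding and use the sum
+# as a query against the text embeddings (cosine similarity after L2 norm).
+import torch.nn.functional as F
+
+img_emb = F.normalize(embeddings[ModalityType.VISION][0], dim=-1)  # "dog" image
+aud_emb = F.normalize(embeddings[ModalityType.AUDIO][1], dim=-1)   # "car" audio
+query = F.normalize(img_emb + aud_emb, dim=-1)
+
+text_emb = F.normalize(embeddings[ModalityType.TEXT], dim=-1)
+scores = text_emb @ query  # similarity of the composed query to each caption
+print("Composed query is closest to:", text_list[scores.argmax().item()])
+```
+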
+## Model card
+Please see the [model card](model_card.md) for details.
+
+## License
+
+ImageBind code and model weights are released under the CC BY-NC-SA 4.0 license. See [LICENSE](LICENSE) for additional details.
+
+## Contributing
+
+See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
+
+## Citing ImageBind
+
+If you find this repository useful, please consider giving it a star :star: and a citation.
+
+```bibtex
+@inproceedings{girdhar2023imagebind,
+ title={ImageBind: One Embedding Space To Bind Them All},
+ author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang
+and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan},
+ booktitle={CVPR},
+ year={2023}
+}
+```
diff --git a/code/model/ImageBind/__init__.py b/code/model/ImageBind/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d872d0725710d6dde3af3b6e05382922f074338b
--- /dev/null
+++ b/code/model/ImageBind/__init__.py
@@ -0,0 +1,2 @@
+from .models import imagebind_model
+from .models.imagebind_model import ModalityType
diff --git a/code/model/ImageBind/__pycache__/__init__.cpython-310.pyc b/code/model/ImageBind/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bbd51163a7726130384b60b48a85addef2104eb
Binary files /dev/null and b/code/model/ImageBind/__pycache__/__init__.cpython-310.pyc differ
diff --git a/code/model/ImageBind/__pycache__/__init__.cpython-39.pyc b/code/model/ImageBind/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..844ac25ca7233169ebc8987eda73536149071e18
Binary files /dev/null and b/code/model/ImageBind/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/model/ImageBind/__pycache__/data.cpython-310.pyc b/code/model/ImageBind/__pycache__/data.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b1ab7d23bc7f1bbef6df5b98c2d0ab5076bb0e5
Binary files /dev/null and b/code/model/ImageBind/__pycache__/data.cpython-310.pyc differ
diff --git a/code/model/ImageBind/__pycache__/data.cpython-39.pyc b/code/model/ImageBind/__pycache__/data.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..782d5bd14e0545c5c7a8259888f21ade7ac1e13e
Binary files /dev/null and b/code/model/ImageBind/__pycache__/data.cpython-39.pyc differ
diff --git a/code/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz b/code/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/code/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/code/model/ImageBind/data.py b/code/model/ImageBind/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..aed592244741f7f5dd394c4eb461d483b95174e5
--- /dev/null
+++ b/code/model/ImageBind/data.py
@@ -0,0 +1,372 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+import torchaudio
+import logging
+
+from .models.multimodal_preprocessors import SimpleTokenizer
+from PIL import Image
+from pytorchvideo import transforms as pv_transforms
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.encoded_video import EncodedVideo
+
+from torchvision import transforms
+from torchvision.transforms._transforms_video import NormalizeVideo
+
+DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
+
+BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
+
+
+def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
+ # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
+ waveform -= waveform.mean()
+ fbank = torchaudio.compliance.kaldi.fbank(
+ waveform,
+ htk_compat=True,
+ sample_frequency=sample_rate,
+ use_energy=False,
+ window_type="hanning",
+ num_mel_bins=num_mel_bins,
+ dither=0.0,
+ frame_length=25,
+ frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
+ )
+ # Convert to [mel_bins, num_frames] shape
+ fbank = fbank.transpose(0, 1)
+ # Pad to target_length
+ n_frames = fbank.size(1)
+ p = target_length - n_frames
+    # if p is too large (say >20%), log a warning
+ if abs(p) / n_frames > 0.2:
+ logging.warning(
+ "Large gap between audio n_frames(%d) and "
+ "target_length (%d). Is the audio_target_length "
+ "setting correct?",
+ n_frames,
+ target_length,
+ )
+ # cut and pad
+ if p > 0:
+ fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
+ elif p < 0:
+ fbank = fbank[:, 0:target_length]
+ # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
+ # channel image
+ fbank = fbank.unsqueeze(0)
+ return fbank
+
+
+def get_clip_timepoints(clip_sampler, duration):
+ # Read out all clips in this video
+ all_clips_timepoints = []
+ is_last_clip = False
+ end = 0.0
+ while not is_last_clip:
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
+ all_clips_timepoints.append((start, end))
+ return all_clips_timepoints
+
+
+def load_and_transform_vision_data(image_paths, device):
+ if image_paths is None:
+ return None
+
+ image_ouputs = []
+ for image_path in image_paths:
+ data_transform = transforms.Compose(
+ [
+ transforms.Resize(
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
+ ),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+ ),
+ ]
+ )
+ with open(image_path, "rb") as fopen:
+ image = Image.open(fopen).convert("RGB")
+
+ image = data_transform(image).to(device)
+ image_ouputs.append(image)
+ return torch.stack(image_ouputs, dim=0)
+
+
+def load_and_transform_thermal_data(thermal_paths, device):
+ if thermal_paths is None:
+ return None
+
+ thermal_ouputs = []
+ for thermal_path in thermal_paths:
+ data_transform = transforms.Compose(
+ [
+ transforms.Resize(
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
+ ),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ ]
+ )
+ with open(thermal_path, "rb") as fopen:
+ thermal = Image.open(fopen).convert("L")
+ thermal = data_transform(thermal).to(device)
+ thermal_ouputs.append(thermal)
+ return torch.stack(thermal_ouputs, dim=0)
+
+
+def load_and_transform_text(text, device):
+ if text is None:
+ return None
+ tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
+ tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
+ tokens = torch.cat(tokens, dim=0)
+ return tokens
+
+
+def load_and_transform_audio_data(
+ audio_paths,
+ device,
+ num_mel_bins=128,
+ target_length=204,
+ sample_rate=16000,
+ clip_duration=2,
+ clips_per_video=3,
+ mean=-4.268,
+ std=9.138,
+):
+ if audio_paths is None:
+ return None
+
+ audio_outputs = []
+ clip_sampler = ConstantClipsPerVideoSampler(
+ clip_duration=clip_duration, clips_per_video=clips_per_video
+ )
+
+ for audio_path in audio_paths:
+ waveform, sr = torchaudio.load(audio_path)
+ if sample_rate != sr:
+ waveform = torchaudio.functional.resample(
+ waveform, orig_freq=sr, new_freq=sample_rate
+ )
+ all_clips_timepoints = get_clip_timepoints(
+ clip_sampler, waveform.size(1) / sample_rate
+ )
+ all_clips = []
+ for clip_timepoints in all_clips_timepoints:
+ waveform_clip = waveform[
+ :,
+ int(clip_timepoints[0] * sample_rate) : int(
+ clip_timepoints[1] * sample_rate
+ ),
+ ]
+ waveform_melspec = waveform2melspec(
+ waveform_clip, sample_rate, num_mel_bins, target_length
+ )
+ all_clips.append(waveform_melspec)
+
+ normalize = transforms.Normalize(mean=mean, std=std)
+ all_clips = [normalize(ac).to(device) for ac in all_clips]
+
+ all_clips = torch.stack(all_clips, dim=0)
+ audio_outputs.append(all_clips)
+
+ return torch.stack(audio_outputs, dim=0)
+
+
+def get_clip_timepoints(clip_sampler, duration):
+ # Read out all clips in this video
+ all_clips_timepoints = []
+ is_last_clip = False
+ end = 0.0
+ while not is_last_clip:
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
+ all_clips_timepoints.append((start, end))
+ return all_clips_timepoints
+
+
+def crop_boxes(boxes, x_offset, y_offset):
+ """
+    Perform crop on the bounding boxes given the offsets.
+    Args:
+        boxes (ndarray or None): bounding boxes to perform crop. The dimension
+ is `num boxes` x 4.
+ x_offset (int): cropping offset in the x axis.
+ y_offset (int): cropping offset in the y axis.
+ Returns:
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ cropped_boxes = boxes.copy()
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+ return cropped_boxes
+
+
+def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+        size (int): size of height and width to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+        scale_size (int): optional. If not None, resize the images to scale_size before
+ performing any crop.
+ Returns:
+ cropped (tensor): images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ assert spatial_idx in [0, 1, 2]
+ ndim = len(images.shape)
+ if ndim == 3:
+ images = images.unsqueeze(0)
+ height = images.shape[2]
+ width = images.shape[3]
+
+ if scale_size is not None:
+ if width <= height:
+ width, height = scale_size, int(height / width * scale_size)
+ else:
+ width, height = int(width / height * scale_size), scale_size
+ images = torch.nn.functional.interpolate(
+ images,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ y_offset = int(math.ceil((height - size) / 2))
+ x_offset = int(math.ceil((width - size) / 2))
+
+ if height > width:
+ if spatial_idx == 0:
+ y_offset = 0
+ elif spatial_idx == 2:
+ y_offset = height - size
+ else:
+ if spatial_idx == 0:
+ x_offset = 0
+ elif spatial_idx == 2:
+ x_offset = width - size
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ if ndim == 3:
+ cropped = cropped.squeeze(0)
+ return cropped, cropped_boxes
+
+
+class SpatialCrop(nn.Module):
+ """
+ Convert the video into 3 smaller clips spatially. Must be used after the
+ temporal crops to get spatial crops, and should be used with
+ -2 in the spatial crop at the slowfast augmentation stage (so full
+ frames are passed in here). Will return a larger list with the
+ 3x spatial crops as well.
+ """
+
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
+ super().__init__()
+ self.crop_size = crop_size
+ if num_crops == 3:
+ self.crops_to_ext = [0, 1, 2]
+ self.flipped_crops_to_ext = []
+ elif num_crops == 1:
+ self.crops_to_ext = [1]
+ self.flipped_crops_to_ext = []
+ else:
+ raise NotImplementedError("Nothing else supported yet")
+
+ def forward(self, videos):
+ """
+ Args:
+ videos: A list of C, T, H, W videos.
+ Returns:
+ videos: A list with 3x the number of elements. Each video converted
+ to C, T, H', W' by spatial cropping.
+ """
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
+ res = []
+ for video in videos:
+ for spatial_idx in self.crops_to_ext:
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
+ if not self.flipped_crops_to_ext:
+ continue
+ flipped_video = transforms.functional.hflip(video)
+ for spatial_idx in self.flipped_crops_to_ext:
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
+ return res
+
+
+def load_and_transform_video_data(
+ video_paths,
+ device,
+ clip_duration=2,
+ clips_per_video=5,
+ sample_rate=16000,
+):
+ if video_paths is None:
+ return None
+
+ video_outputs = []
+ video_transform = transforms.Compose(
+ [
+ pv_transforms.ShortSideScale(224),
+ NormalizeVideo(
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+ ),
+ ]
+ )
+
+ clip_sampler = ConstantClipsPerVideoSampler(
+ clip_duration=clip_duration, clips_per_video=clips_per_video
+ )
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
+
+ for video_path in video_paths:
+ video = EncodedVideo.from_path(
+ video_path,
+ decoder="decord",
+ decode_audio=False,
+ **{"sample_rate": sample_rate},
+ )
+
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
+
+ all_video = []
+ for clip_timepoints in all_clips_timepoints:
+ # Read the clip, get frames
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
+ if clip is None:
+ raise ValueError("No clip found")
+ video_clip = frame_sampler(clip["video"])
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
+
+ all_video.append(video_clip)
+
+ all_video = [video_transform(clip) for clip in all_video]
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
+
+ all_video = torch.stack(all_video, dim=0)
+ video_outputs.append(all_video)
+
+ return torch.stack(video_outputs, dim=0).to(device)
diff --git a/code/model/ImageBind/model_card.md b/code/model/ImageBind/model_card.md
new file mode 100644
index 0000000000000000000000000000000000000000..c7bb26500b6590b64ffa6350f37be80dc88612d8
--- /dev/null
+++ b/code/model/ImageBind/model_card.md
@@ -0,0 +1,94 @@
+# Model Card for ImageBind
+
+Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images.
+Input any of the six modalities and get the same-sized embedding that can be used for cross-modal and multimodal tasks.
+
+# Model Details
+
+## Model Description
+
+
+Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images
+
+- **Developed by:** Meta AI
+- **Model type:** Multimodal model
+- **Language(s) (NLP):** en
+- **License:** CC BY-NC-SA 4.0
+- **Resources for more information:**
+ - [GitHub Repo](https://github.com/facebookresearch/ImageBind)
+
+
+# Uses
+
+
+This model is intended only for research purposes. It provides a joint embedding space for different modalities -- image/video, text, audio, depth, IMU, and thermal images.
+We hope that these joint embeddings can be used for a variety of cross-modal research applications, e.g., cross-modal retrieval and combining embeddings from different modalities.
+
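+As an illustration of cross-modal retrieval with these joint embeddings, the following is a minimal, hypothetical sketch using plain PyTorch cosine similarity; the randomly generated tensors stand in for embeddings produced by the model and are not a prescribed API.
+
+```python
+# Hypothetical sketch: rank a gallery of audio embeddings against one image
+# embedding by cosine similarity (D stands for the joint embedding dimension).
+import torch
+import torch.nn.functional as F
+
+D = 1024
+image_embedding = F.normalize(torch.randn(D), dim=-1)        # one query image
+audio_embeddings = F.normalize(torch.randn(100, D), dim=-1)  # gallery of 100 clips
+
+scores = audio_embeddings @ image_embedding   # cosine similarities, shape [100]
+print(scores.topk(k=5).indices.tolist())      # indices of the 5 closest clips
+```
+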
+## Out-of-Scope Use
+
+
+
+
+This model is *NOT* intended to be used in any real-world application -- commercial or otherwise.
+It may produce harmful associations with different inputs.
+The model needs to be investigated and likely re-trained on specific data for any such application.
+The model is expected to work better on web-based visual data since it was trained on such data.
+The text encoder is likely to work only on English-language text because of the underlying training datasets.
+
+# Bias, Risks, and Limitations
+
+
+Open-domain joint embedding models are prone to producing specific biases; see, e.g., the study from [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md#bias-and-fairness).
+Since our model uses such models as initialization, it will exhibit such biases too.
+Moreover, for learning joint embeddings for other modalities such as audio, thermal, depth, and IMU, we leverage datasets that are relatively small. These joint embeddings are thus limited to the concepts present in those datasets. For example, the thermal datasets we used are limited to outdoor street scenes, while the depth datasets are limited to indoor scenes.
+
+
+
+# Training Details
+
+## Training Data
+
+
+
+ImageBind uses image-paired data for training -- (image, X) where X is one of text, audio, depth, IMU or thermal data.
+In particular, we initialize and freeze the image and text encoders using an OpenCLIP ViT-H encoder.
+We train audio embeddings using Audioset, depth embeddings using the SUN RGB-D dataset, IMU using the Ego4D dataset and thermal embeddings using the LLVIP dataset.
+We provide the exact training data details in the paper.
+
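+As a rough illustration of how such image-paired training is typically set up, the sketch below shows an InfoNCE-style contrastive loss aligning image embeddings with paired modality-X embeddings; the tensors, batch size, and temperature here are made up for illustration and this is not the released training code.
+
+```python
+# Simplified sketch of a symmetric InfoNCE-style contrastive loss between image
+# embeddings and paired modality-X embeddings (e.g. audio). The i-th pair in the
+# batch is the positive; every other combination acts as a negative.
+import torch
+import torch.nn.functional as F
+
+def infonce_loss(image_emb, x_emb, temperature=0.07):
+    image_emb = F.normalize(image_emb, dim=-1)
+    x_emb = F.normalize(x_emb, dim=-1)
+    logits = image_emb @ x_emb.t() / temperature   # [B, B] similarity matrix
+    targets = torch.arange(image_emb.size(0))      # i-th image pairs with i-th x
+    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
+
+loss = infonce_loss(torch.randn(8, 1024), torch.randn(8, 1024))
+```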
+
+## Training Procedure
+
+
+Please refer to the research paper and GitHub repo for the exact details.
+
+# Evaluation
+
+## Testing Data, Factors & Metrics
+
+We evaluate the model on a variety of different classification benchmarks for each modality.
+The evaluation details are presented in the paper.
+The model's performance is measured using standard classification metrics such as accuracy and mAP.
+
+# Citation
+
+
+
+**BibTeX:**
+```bibtex
+@inproceedings{girdhar2023imagebind,
+ title={ImageBind: One Embedding Space To Bind Them All},
+ author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang
+and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan},
+ booktitle={CVPR},
+ year={2023}
+}
+```
+
+
+# Model Card Contact
+
+Please reach out to the authors at: rgirdhar@meta.com, imisra@meta.com, alaaelnouby@gmail.com
+
+# How to Get Started with the Model
+
+Our GitHub repo provides a simple example of extracting embeddings from images, audio, etc.
diff --git a/code/model/ImageBind/models/__init__.py b/code/model/ImageBind/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/code/model/ImageBind/models/__pycache__/__init__.cpython-310.pyc b/code/model/ImageBind/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5791b4fed77d0dff0f8e83ab5141710a4983fe11
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/__init__.cpython-39.pyc b/code/model/ImageBind/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3de32cdbc7f15ded6813e73a695d430f6b07741
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/helpers.cpython-310.pyc b/code/model/ImageBind/models/__pycache__/helpers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdae1ede890b3942430efaf467dd5f009c45da25
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/helpers.cpython-310.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/helpers.cpython-39.pyc b/code/model/ImageBind/models/__pycache__/helpers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fcb880a4c136832fb6b9e81dc3627e28f4b4cef
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/helpers.cpython-39.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/imagebind_model.cpython-310.pyc b/code/model/ImageBind/models/__pycache__/imagebind_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02f7dae44da4c150e561869d0130d345da92d66c
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/imagebind_model.cpython-310.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/imagebind_model.cpython-39.pyc b/code/model/ImageBind/models/__pycache__/imagebind_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ac9da7166ecdf308cdb553e574cb2904c42e9c1
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/imagebind_model.cpython-39.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/multimodal_preprocessors.cpython-310.pyc b/code/model/ImageBind/models/__pycache__/multimodal_preprocessors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87da63b1c69c0cdfe55dac1e9e6130bf2e8dc1a6
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/multimodal_preprocessors.cpython-310.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/multimodal_preprocessors.cpython-39.pyc b/code/model/ImageBind/models/__pycache__/multimodal_preprocessors.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..112fb5ad9a9c2be9a9f06749fe21e25fb252b1df
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/multimodal_preprocessors.cpython-39.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/transformer.cpython-310.pyc b/code/model/ImageBind/models/__pycache__/transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f3175a8cb119c1ce83e0e0d2b44afde199beb3b
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/transformer.cpython-310.pyc differ
diff --git a/code/model/ImageBind/models/__pycache__/transformer.cpython-39.pyc b/code/model/ImageBind/models/__pycache__/transformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0674473b53734e7d6d59ea6a6e01b279c1df6a9c
Binary files /dev/null and b/code/model/ImageBind/models/__pycache__/transformer.cpython-39.pyc differ
diff --git a/code/model/ImageBind/models/helpers.py b/code/model/ImageBind/models/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..049e1f1b0580832e8574350991bf347b6da81482
--- /dev/null
+++ b/code/model/ImageBind/models/helpers.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import einops
+import numpy as np
+import torch
+
+import torch.nn as nn
+
+
+class Normalize(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ return torch.nn.functional.normalize(x, dim=self.dim, p=2)
+
+
+class LearnableLogitScaling(nn.Module):
+ def __init__(
+ self,
+ logit_scale_init: float = 1 / 0.07,
+ learnable: bool = True,
+ max_logit_scale: float = 100,
+ ) -> None:
+ super().__init__()
+ self.max_logit_scale = max_logit_scale
+ self.logit_scale_init = logit_scale_init
+ self.learnable = learnable
+ log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
+ if learnable:
+ self.log_logit_scale = nn.Parameter(log_logit_scale)
+ else:
+ self.register_buffer("log_logit_scale", log_logit_scale)
+
+ def forward(self, x):
+ return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
+
+ def extra_repr(self):
+        st = f"logit_scale_init={self.logit_scale_init}, learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
+ return st
+
+
+class EinOpsRearrange(nn.Module):
+ def __init__(self, rearrange_expr: str, **kwargs) -> None:
+ super().__init__()
+ self.rearrange_expr = rearrange_expr
+ self.kwargs = kwargs
+
+ def forward(self, x):
+ assert isinstance(x, torch.Tensor)
+ return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
+
+
+class VerboseNNModule(nn.Module):
+ """
+ Wrapper around nn.Module that prints registered buffers and parameter names.
+ """
+
+ @staticmethod
+ def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
+ st = (
+ "("
+ + name
+ + "): "
+ + "tensor("
+ + str(tuple(tensor[1].shape))
+ + ", requires_grad="
+ + str(tensor[1].requires_grad)
+ + ")\n"
+ )
+ return st
+
+ def extra_repr(self) -> str:
+ named_modules = set()
+ for p in self.named_modules():
+ named_modules.update([p[0]])
+ named_modules = list(named_modules)
+
+ string_repr = ""
+ for p in self.named_parameters():
+ name = p[0].split(".")[0]
+ if name not in named_modules:
+ string_repr += self.get_readable_tensor_repr(name, p)
+
+ for p in self.named_buffers():
+ name = p[0].split(".")[0]
+ string_repr += self.get_readable_tensor_repr(name, p)
+
+ return string_repr
+
+
+def cast_if_src_dtype(
+ tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
+):
+ updated = False
+ if tensor.dtype == src_dtype:
+ tensor = tensor.to(dtype=tgt_dtype)
+ updated = True
+ return tensor, updated
+
+
+class QuickGELU(nn.Module):
+ # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class SelectElement(nn.Module):
+ def __init__(self, index) -> None:
+ super().__init__()
+ self.index = index
+
+ def forward(self, x):
+ assert x.ndim >= 3
+ return x[:, self.index, ...]
+
+
+class SelectEOSAndProject(nn.Module):
+ """
+ Text Pooling used in OpenCLIP
+ """
+
+ def __init__(self, proj: nn.Module) -> None:
+ super().__init__()
+ self.proj = proj
+
+ def forward(self, x, seq_len):
+ assert x.ndim == 3
+ # x is of shape B x L x D
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), seq_len]
+ x = self.proj(x)
+ return x
diff --git a/code/model/ImageBind/models/imagebind_model.py b/code/model/ImageBind/models/imagebind_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1981e8790b98131e2a89388142a79c6de94628
--- /dev/null
+++ b/code/model/ImageBind/models/imagebind_model.py
@@ -0,0 +1,521 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+import urllib
+from functools import partial
+from types import SimpleNamespace
+
+import torch
+import torch.nn as nn
+
+from .helpers import (
+ EinOpsRearrange,
+ LearnableLogitScaling,
+ Normalize,
+ SelectElement,
+ SelectEOSAndProject,
+)
+from .multimodal_preprocessors import (
+ AudioPreprocessor,
+ IMUPreprocessor,
+ PadIm2Video,
+ PatchEmbedGeneric,
+ RGBDTPreprocessor,
+ SpatioTemporalPosEmbeddingHelper,
+ TextPreprocessor,
+ ThermalPreprocessor,
+)
+
+from .transformer import MultiheadAttention, SimpleTransformer
+
+
+ModalityType = SimpleNamespace(
+ VISION="vision",
+ TEXT="text",
+ AUDIO="audio",
+ THERMAL="thermal",
+ DEPTH="depth",
+ IMU="imu",
+)
+
+
+class ImageBindModel(nn.Module):
+ def __init__(
+ self,
+ video_frames=2,
+ kernel_size=(2, 14, 14),
+ audio_kernel_size=16,
+ audio_stride=10,
+ out_embed_dim=768,
+ vision_embed_dim=1024,
+ vision_num_blocks=24,
+ vision_num_heads=16,
+ audio_embed_dim=768,
+ audio_num_blocks=12,
+ audio_num_heads=12,
+ audio_num_mel_bins=128,
+ audio_target_len=204,
+ audio_drop_path=0.1,
+ text_embed_dim=768,
+ text_num_blocks=12,
+ text_num_heads=12,
+ depth_embed_dim=384,
+ depth_kernel_size=16,
+ depth_num_blocks=12,
+ depth_num_heads=8,
+ depth_drop_path=0.0,
+ thermal_embed_dim=768,
+ thermal_kernel_size=16,
+ thermal_num_blocks=12,
+ thermal_num_heads=12,
+ thermal_drop_path=0.0,
+ imu_embed_dim=512,
+ imu_kernel_size=8,
+ imu_num_blocks=6,
+ imu_num_heads=8,
+ imu_drop_path=0.7,
+ ):
+ super().__init__()
+
+ self.modality_preprocessors = self._create_modality_preprocessors(
+ video_frames,
+ vision_embed_dim,
+ kernel_size,
+ text_embed_dim,
+ audio_embed_dim,
+ audio_kernel_size,
+ audio_stride,
+ audio_num_mel_bins,
+ audio_target_len,
+ depth_embed_dim,
+ depth_kernel_size,
+ thermal_embed_dim,
+ thermal_kernel_size,
+ imu_embed_dim,
+ )
+
+ self.modality_trunks = self._create_modality_trunks(
+ vision_embed_dim,
+ vision_num_blocks,
+ vision_num_heads,
+ text_embed_dim,
+ text_num_blocks,
+ text_num_heads,
+ audio_embed_dim,
+ audio_num_blocks,
+ audio_num_heads,
+ audio_drop_path,
+ depth_embed_dim,
+ depth_num_blocks,
+ depth_num_heads,
+ depth_drop_path,
+ thermal_embed_dim,
+ thermal_num_blocks,
+ thermal_num_heads,
+ thermal_drop_path,
+ imu_embed_dim,
+ imu_num_blocks,
+ imu_num_heads,
+ imu_drop_path,
+ )
+
+ self.modality_heads = self._create_modality_heads(
+ out_embed_dim,
+ vision_embed_dim,
+ text_embed_dim,
+ audio_embed_dim,
+ depth_embed_dim,
+ thermal_embed_dim,
+ imu_embed_dim,
+ )
+
+ self.modality_postprocessors = self._create_modality_postprocessors(
+ out_embed_dim
+ )
+
+ def _create_modality_preprocessors(
+ self,
+ video_frames=2,
+ vision_embed_dim=1024,
+ kernel_size=(2, 14, 14),
+ text_embed_dim=768,
+ audio_embed_dim=768,
+ audio_kernel_size=16,
+ audio_stride=10,
+ audio_num_mel_bins=128,
+ audio_target_len=204,
+ depth_embed_dim=768,
+ depth_kernel_size=16,
+ thermal_embed_dim=768,
+ thermal_kernel_size=16,
+ imu_embed_dim=512,
+ ):
+ rgbt_stem = PatchEmbedGeneric(
+ proj_stem=[
+ PadIm2Video(pad_type="repeat", ntimes=2),
+ nn.Conv3d(
+ in_channels=3,
+ kernel_size=kernel_size,
+ out_channels=vision_embed_dim,
+ stride=kernel_size,
+ bias=False,
+ ),
+ ]
+ )
+ rgbt_preprocessor = RGBDTPreprocessor(
+ img_size=[3, video_frames, 224, 224],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ rgbt_stem=rgbt_stem,
+ depth_stem=None,
+ )
+
+ text_preprocessor = TextPreprocessor(
+ context_length=77,
+ vocab_size=49408,
+ embed_dim=text_embed_dim,
+ causal_masking=True,
+ )
+
+ audio_stem = PatchEmbedGeneric(
+ proj_stem=[
+ nn.Conv2d(
+ in_channels=1,
+ kernel_size=audio_kernel_size,
+ stride=audio_stride,
+ out_channels=audio_embed_dim,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
+ )
+ audio_preprocessor = AudioPreprocessor(
+ img_size=[1, audio_num_mel_bins, audio_target_len],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ audio_stem=audio_stem,
+ )
+
+ depth_stem = PatchEmbedGeneric(
+ [
+ nn.Conv2d(
+ kernel_size=depth_kernel_size,
+ in_channels=1,
+ out_channels=depth_embed_dim,
+ stride=depth_kernel_size,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
+ )
+
+ depth_preprocessor = RGBDTPreprocessor(
+ img_size=[1, 224, 224],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ rgbt_stem=None,
+ depth_stem=depth_stem,
+ )
+
+ thermal_stem = PatchEmbedGeneric(
+ [
+ nn.Conv2d(
+ kernel_size=thermal_kernel_size,
+ in_channels=1,
+ out_channels=thermal_embed_dim,
+ stride=thermal_kernel_size,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
+ )
+ thermal_preprocessor = ThermalPreprocessor(
+ img_size=[1, 224, 224],
+ num_cls_tokens=1,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ thermal_stem=thermal_stem,
+ )
+
+ imu_stem = PatchEmbedGeneric(
+ [
+ nn.Linear(
+ in_features=48,
+ out_features=imu_embed_dim,
+ bias=False,
+ ),
+ ],
+ norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
+ )
+
+ imu_preprocessor = IMUPreprocessor(
+ img_size=[6, 2000],
+ num_cls_tokens=1,
+ kernel_size=8,
+ embed_dim=imu_embed_dim,
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
+ imu_stem=imu_stem,
+ )
+
+ modality_preprocessors = {
+ ModalityType.VISION: rgbt_preprocessor,
+ ModalityType.TEXT: text_preprocessor,
+ ModalityType.AUDIO: audio_preprocessor,
+ ModalityType.DEPTH: depth_preprocessor,
+ ModalityType.THERMAL: thermal_preprocessor,
+ ModalityType.IMU: imu_preprocessor,
+ }
+
+ return nn.ModuleDict(modality_preprocessors)
+
+ def _create_modality_trunks(
+ self,
+ vision_embed_dim=1024,
+ vision_num_blocks=24,
+ vision_num_heads=16,
+ text_embed_dim=768,
+ text_num_blocks=12,
+ text_num_heads=12,
+ audio_embed_dim=768,
+ audio_num_blocks=12,
+ audio_num_heads=12,
+ audio_drop_path=0.0,
+ depth_embed_dim=768,
+ depth_num_blocks=12,
+ depth_num_heads=12,
+ depth_drop_path=0.0,
+ thermal_embed_dim=768,
+ thermal_num_blocks=12,
+ thermal_num_heads=12,
+ thermal_drop_path=0.0,
+ imu_embed_dim=512,
+ imu_num_blocks=6,
+ imu_num_heads=8,
+ imu_drop_path=0.7,
+ ):
+ def instantiate_trunk(
+ embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
+ ):
+ return SimpleTransformer(
+ embed_dim=embed_dim,
+ num_blocks=num_blocks,
+ ffn_dropout_rate=0.0,
+ drop_path_rate=drop_path,
+ attn_target=partial(
+ MultiheadAttention,
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ bias=True,
+ add_bias_kv=add_bias_kv,
+ ),
+ pre_transformer_layer=nn.Sequential(
+ nn.LayerNorm(embed_dim, eps=1e-6)
+ if pre_transformer_ln
+ else nn.Identity(),
+ EinOpsRearrange("b l d -> l b d"),
+ ),
+ post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
+ )
+
+ modality_trunks = {}
+ modality_trunks[ModalityType.VISION] = instantiate_trunk(
+ vision_embed_dim,
+ vision_num_blocks,
+ vision_num_heads,
+ pre_transformer_ln=True,
+ add_bias_kv=False,
+ drop_path=0.0,
+ )
+ modality_trunks[ModalityType.TEXT] = instantiate_trunk(
+ text_embed_dim,
+ text_num_blocks,
+ text_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=False,
+ drop_path=0.0,
+ )
+ modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
+ audio_embed_dim,
+ audio_num_blocks,
+ audio_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=audio_drop_path,
+ )
+ modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
+ depth_embed_dim,
+ depth_num_blocks,
+ depth_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=depth_drop_path,
+ )
+ modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
+ thermal_embed_dim,
+ thermal_num_blocks,
+ thermal_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=thermal_drop_path,
+ )
+ modality_trunks[ModalityType.IMU] = instantiate_trunk(
+ imu_embed_dim,
+ imu_num_blocks,
+ imu_num_heads,
+ pre_transformer_ln=False,
+ add_bias_kv=True,
+ drop_path=imu_drop_path,
+ )
+
+ return nn.ModuleDict(modality_trunks)
+
+ def _create_modality_heads(
+ self,
+ out_embed_dim,
+ vision_embed_dim,
+ text_embed_dim,
+ audio_embed_dim,
+ depth_embed_dim,
+ thermal_embed_dim,
+ imu_embed_dim,
+ ):
+ modality_heads = {}
+
+ modality_heads[ModalityType.VISION] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
+ proj=nn.Sequential(
+ nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
+ nn.Linear(text_embed_dim, out_embed_dim, bias=False),
+ )
+ )
+
+ modality_heads[ModalityType.AUDIO] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.DEPTH] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.THERMAL] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
+ )
+
+ modality_heads[ModalityType.IMU] = nn.Sequential(
+ nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
+ SelectElement(index=0),
+ nn.Dropout(p=0.5),
+ nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
+ )
+
+ return nn.ModuleDict(modality_heads)
+
+ def _create_modality_postprocessors(self, out_embed_dim):
+ modality_postprocessors = {}
+
+ modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
+ modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
+ Normalize(dim=-1), LearnableLogitScaling(learnable=True)
+ )
+ modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
+ )
+ modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
+ )
+ modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
+ )
+ modality_postprocessors[ModalityType.IMU] = nn.Sequential(
+ Normalize(dim=-1),
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
+ )
+ return nn.ModuleDict(modality_postprocessors)
+
+ def forward(self, inputs):
+ outputs = {}
+ for modality_key, modality_value in inputs.items():
+ reduce_list = (
+ modality_value.ndim >= 5
+ ) # Audio and Video inputs consist of multiple clips
+ if reduce_list:
+ B, S = modality_value.shape[:2]
+ modality_value = modality_value.reshape(
+ B * S, *modality_value.shape[2:]
+ )
+
+ if modality_value is not None:
+ modality_value = self.modality_preprocessors[modality_key](
+ **{modality_key: modality_value}
+ )
+ trunk_inputs = modality_value["trunk"]
+ head_inputs = modality_value["head"]
+ modality_value = self.modality_trunks[modality_key](**trunk_inputs)
+ modality_value = self.modality_heads[modality_key](
+ modality_value, **head_inputs
+ )
+ if modality_key in [ModalityType.AUDIO]:
+ modality_value = self.modality_postprocessors[modality_key][0](
+ modality_value
+ )
+ else:
+ modality_value = self.modality_postprocessors[modality_key](
+ modality_value
+ )
+
+ if reduce_list:
+ modality_value = modality_value.reshape(B, S, -1)
+ modality_value = modality_value.mean(dim=1)
+
+ outputs[modality_key] = modality_value
+
+ return outputs
+
+
+def imagebind_huge(pretrained=False, store_path=r'.checkpoints'):
+ model = ImageBindModel(
+ vision_embed_dim=1280,
+ vision_num_blocks=32,
+ vision_num_heads=16,
+ text_embed_dim=1024,
+ text_num_blocks=24,
+ text_num_heads=16,
+ out_embed_dim=1024,
+ audio_drop_path=0.1,
+ imu_drop_path=0.7,
+ )
+
+ if pretrained:
+ if not os.path.exists("{}/imagebind_huge.pth".format(store_path)):
+ print(
+ "Downloading imagebind weights to {}/imagebind_huge.pth ...".format(store_path)
+ )
+ os.makedirs(store_path, exist_ok=True)
+ torch.hub.download_url_to_file(
+ "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
+ "{}/imagebind_huge.pth".format(store_path),
+ progress=True,
+ )
+
+ model.load_state_dict(torch.load("{}/imagebind_huge.pth".format(store_path)))
+
+ return model, 1024
diff --git a/code/model/ImageBind/models/multimodal_preprocessors.py b/code/model/ImageBind/models/multimodal_preprocessors.py
new file mode 100644
index 0000000000000000000000000000000000000000..44de961053601fd288c5c92c56b799d5762b8b4c
--- /dev/null
+++ b/code/model/ImageBind/models/multimodal_preprocessors.py
@@ -0,0 +1,687 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import gzip
+import html
+import io
+import math
+from functools import lru_cache
+from typing import Callable, List, Optional
+
+import ftfy
+
+import numpy as np
+import regex as re
+import torch
+import torch.nn as nn
+from iopath.common.file_io import g_pathmgr
+from timm.models.layers import trunc_normal_
+
+from .helpers import cast_if_src_dtype, VerboseNNModule
+
+
+def get_sinusoid_encoding_table(n_position, d_hid):
+ """Sinusoid position encoding table"""
+
+ # TODO: make it with torch instead of numpy
+ def get_position_angle_vec(position):
+ return [
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+ for hid_j in range(d_hid)
+ ]
+
+ sinusoid_table = np.array(
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
+ )
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+
+def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
+ N = pos_embed.shape[1]
+ if N == target_spatial_size:
+ return pos_embed
+ dim = pos_embed.shape[-1]
+ # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
+ pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
+ pos_embed = nn.functional.interpolate(
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
+ 0, 3, 1, 2
+ ),
+ scale_factor=math.sqrt(target_spatial_size / N),
+ mode="bicubic",
+ )
+ if updated:
+ pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return pos_embed
+
+
+def interpolate_pos_encoding(
+ npatch_per_img,
+ pos_embed,
+ patches_layout,
+ input_shape=None,
+ first_patch_idx=1,
+):
+ assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
+ N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
+ if npatch_per_img == N:
+ return pos_embed
+
+ assert (
+ patches_layout[-1] == patches_layout[-2]
+ ), "Interpolation of pos embed not supported for non-square layouts"
+
+ class_emb = pos_embed[:, :first_patch_idx]
+ pos_embed = pos_embed[:, first_patch_idx:]
+
+ if input_shape is None or patches_layout[0] == 1:
+ # simple 2D pos embedding, no temporal component
+ pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
+ elif patches_layout[0] > 1:
+ # pos embed has a temporal component
+ assert len(input_shape) == 4, "temporal interpolation not supported"
+ # we only support 2D interpolation in this case
+ num_frames = patches_layout[0]
+ num_spatial_tokens = patches_layout[1] * patches_layout[2]
+ pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
+ # interpolate embedding for zeroth frame
+ pos_embed = interpolate_pos_encoding_2d(
+ npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
+ )
+ else:
+ raise ValueError("This type of interpolation isn't implemented")
+
+ return torch.cat((class_emb, pos_embed), dim=1)
+
+
+def _get_pos_embedding(
+ npatch_per_img,
+ pos_embed,
+ patches_layout,
+ input_shape,
+ first_patch_idx=1,
+):
+ pos_embed = interpolate_pos_encoding(
+ npatch_per_img,
+ pos_embed,
+ patches_layout,
+ input_shape=input_shape,
+ first_patch_idx=first_patch_idx,
+ )
+ return pos_embed
+
+
+class PatchEmbedGeneric(nn.Module):
+ """
+ PatchEmbed from Hydra
+ """
+
+ def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
+ super().__init__()
+
+ if len(proj_stem) > 1:
+ self.proj = nn.Sequential(*proj_stem)
+ else:
+ # Special case to be able to load pre-trained models that were
+ # trained with a standard stem
+ self.proj = proj_stem[0]
+ self.norm_layer = norm_layer
+
+ def get_patch_layout(self, img_size):
+ with torch.no_grad():
+ dummy_img = torch.zeros(
+ [
+ 1,
+ ]
+ + img_size
+ )
+ dummy_out = self.proj(dummy_img)
+ embed_dim = dummy_out.shape[1]
+ patches_layout = tuple(dummy_out.shape[2:])
+ num_patches = np.prod(patches_layout)
+ return patches_layout, num_patches, embed_dim
+
+ def forward(self, x):
+ x = self.proj(x)
+ # B C (T) H W -> B (T)HW C
+ x = x.flatten(2).transpose(1, 2)
+ if self.norm_layer is not None:
+ x = self.norm_layer(x)
+ return x
+
+
+class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
+ def __init__(
+ self,
+ patches_layout: List,
+ num_patches: int,
+ num_cls_tokens: int,
+ embed_dim: int,
+ learnable: bool,
+ ) -> None:
+ super().__init__()
+ self.num_cls_tokens = num_cls_tokens
+ self.patches_layout = patches_layout
+ self.num_patches = num_patches
+ self.num_tokens = num_cls_tokens + num_patches
+ self.learnable = learnable
+ if self.learnable:
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
+ trunc_normal_(self.pos_embed, std=0.02)
+ else:
+ self.register_buffer(
+ "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
+ )
+
+ def get_pos_embedding(self, vision_input, all_vision_tokens):
+ input_shape = vision_input.shape
+ pos_embed = _get_pos_embedding(
+ all_vision_tokens.size(1) - self.num_cls_tokens,
+ pos_embed=self.pos_embed,
+ patches_layout=self.patches_layout,
+ input_shape=input_shape,
+ first_patch_idx=self.num_cls_tokens,
+ )
+ return pos_embed
+
+
+class RGBDTPreprocessor(VerboseNNModule):
+ def __init__(
+ self,
+ rgbt_stem: PatchEmbedGeneric,
+ depth_stem: PatchEmbedGeneric,
+ img_size: List = (3, 224, 224),
+ num_cls_tokens: int = 1,
+ pos_embed_fn: Callable = None,
+ use_type_embed: bool = False,
+ init_param_style: str = "openclip",
+ ) -> None:
+ super().__init__()
+ stem = rgbt_stem if rgbt_stem is not None else depth_stem
+ (
+ self.patches_layout,
+ self.num_patches,
+ self.embed_dim,
+ ) = stem.get_patch_layout(img_size)
+ self.rgbt_stem = rgbt_stem
+ self.depth_stem = depth_stem
+ self.use_pos_embed = pos_embed_fn is not None
+ self.use_type_embed = use_type_embed
+ self.num_cls_tokens = num_cls_tokens
+
+ if self.use_pos_embed:
+ self.pos_embedding_helper = pos_embed_fn(
+ patches_layout=self.patches_layout,
+ num_cls_tokens=num_cls_tokens,
+ num_patches=self.num_patches,
+ embed_dim=self.embed_dim,
+ )
+ if self.num_cls_tokens > 0:
+ self.cls_token = nn.Parameter(
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
+ )
+ if self.use_type_embed:
+ self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+ self.init_parameters(init_param_style)
+
+ @torch.no_grad()
+ def init_parameters(self, init_param_style):
+ if init_param_style == "openclip":
+ # OpenCLIP style initialization
+ scale = self.embed_dim**-0.5
+ if self.use_pos_embed:
+ nn.init.normal_(self.pos_embedding_helper.pos_embed)
+ self.pos_embedding_helper.pos_embed *= scale
+
+ if self.num_cls_tokens > 0:
+ nn.init.normal_(self.cls_token)
+ self.cls_token *= scale
+ elif init_param_style == "vit":
+ self.cls_token.data.fill_(0)
+ else:
+ raise ValueError(f"Unknown init {init_param_style}")
+
+ if self.use_type_embed:
+ nn.init.normal_(self.type_embed)
+
+ def tokenize_input_and_cls_pos(self, input, stem, mask):
+ # tokens is of shape B x L x D
+ tokens = stem(input)
+ assert tokens.ndim == 3
+ assert tokens.shape[2] == self.embed_dim
+ B = tokens.shape[0]
+ if self.num_cls_tokens > 0:
+ class_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole class_tokens impl from Phil Wang, thanks
+ tokens = torch.cat((class_tokens, tokens), dim=1)
+ if self.use_pos_embed:
+ pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
+ tokens = tokens + pos_embed
+ if self.use_type_embed:
+ tokens = tokens + self.type_embed.expand(B, -1, -1)
+ return tokens
+
+ def forward(self, vision=None, depth=None, patch_mask=None):
+ if patch_mask is not None:
+ raise NotImplementedError()
+
+ if vision is not None:
+ vision_tokens = self.tokenize_input_and_cls_pos(
+ vision, self.rgbt_stem, patch_mask
+ )
+
+ if depth is not None:
+ depth_tokens = self.tokenize_input_and_cls_pos(
+ depth, self.depth_stem, patch_mask
+ )
+
+ # aggregate tokens
+ if vision is not None and depth is not None:
+ final_tokens = vision_tokens + depth_tokens
+ else:
+ final_tokens = vision_tokens if vision is not None else depth_tokens
+ return_dict = {
+ "trunk": {
+ "tokens": final_tokens,
+ },
+ "head": {},
+ }
+ return return_dict
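+
+    # Shape sketch (illustrative): forward(vision=x) or forward(depth=x) returns
+    # {'trunk': {'tokens': tensor of shape (B, num_cls_tokens + num_patches, embed_dim)}, 'head': {}},
+    # where num_patches and embed_dim come from the stem's get_patch_layout(img_size).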
+
+
+class AudioPreprocessor(RGBDTPreprocessor):
+ def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
+ super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
+
+ def forward(self, audio=None):
+ return super().forward(vision=audio)
+
+
+class ThermalPreprocessor(RGBDTPreprocessor):
+ def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
+ super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
+
+ def forward(self, thermal=None):
+ return super().forward(vision=thermal)
+
+
+def build_causal_attention_mask(context_length):
+    # lazily create the causal attention mask for the text tokens;
+    # pytorch uses an additive attention mask, so fill the upper triangle with -inf
+ mask = torch.empty(context_length, context_length, requires_grad=False)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
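+
+# Worked example (comment added for illustration): for context_length = 3 the mask is
+#   [[0., -inf, -inf],
+#    [0.,   0., -inf],
+#    [0.,   0.,   0.]]
+# so, once added to the attention logits, token i can only attend to tokens j <= i.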
+
+
+class TextPreprocessor(VerboseNNModule):
+ def __init__(
+ self,
+ vocab_size: int,
+ context_length: int,
+ embed_dim: int,
+ causal_masking: bool,
+ supply_seq_len_to_head: bool = True,
+ num_cls_tokens: int = 0,
+ init_param_style: str = "openclip",
+ ) -> None:
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.context_length = context_length
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+ self.pos_embed = nn.Parameter(
+ torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
+ )
+ self.causal_masking = causal_masking
+ if self.causal_masking:
+ mask = build_causal_attention_mask(self.context_length)
+ # register the mask as a buffer so it can be moved to the right device
+ self.register_buffer("mask", mask)
+
+ self.supply_seq_len_to_head = supply_seq_len_to_head
+ self.num_cls_tokens = num_cls_tokens
+ self.embed_dim = embed_dim
+ if num_cls_tokens > 0:
+ assert self.causal_masking is False, "Masking + CLS token isn't implemented"
+ self.cls_token = nn.Parameter(
+ torch.zeros(1, self.num_cls_tokens, embed_dim)
+ )
+
+ self.init_parameters(init_param_style)
+
+ @torch.no_grad()
+ def init_parameters(self, init_param_style="openclip"):
+ # OpenCLIP style initialization
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.pos_embed, std=0.01)
+
+ if init_param_style == "openclip":
+ # OpenCLIP style initialization
+ scale = self.embed_dim**-0.5
+ if self.num_cls_tokens > 0:
+ nn.init.normal_(self.cls_token)
+ self.cls_token *= scale
+ elif init_param_style == "vit":
+ self.cls_token.data.fill_(0)
+ else:
+ raise ValueError(f"Unknown init {init_param_style}")
+
+ def forward(self, text):
+ # text tokens are of shape B x L x D
+ text_tokens = self.token_embedding(text)
+ # concat CLS tokens if any
+ if self.num_cls_tokens > 0:
+ B = text_tokens.shape[0]
+ class_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole class_tokens impl from Phil Wang, thanks
+ text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
+ text_tokens = text_tokens + self.pos_embed
+ return_dict = {
+ "trunk": {
+ "tokens": text_tokens,
+ },
+ "head": {},
+ }
+        # Compute the per-sample sequence length: argmax over the token ids locates the
+        # EOT token, which has the highest id in the vocabulary
+ if self.supply_seq_len_to_head:
+ text_lengths = text.argmax(dim=-1)
+ return_dict["head"] = {
+ "seq_len": text_lengths,
+ }
+ if self.causal_masking:
+ return_dict["trunk"].update({"attn_mask": self.mask})
+ return return_dict
+
+
+class Im2Video(nn.Module):
+ """Convert an image into a trivial video."""
+
+ def __init__(self, time_dim=2):
+ super().__init__()
+ self.time_dim = time_dim
+
+ def forward(self, x):
+ if x.ndim == 4:
+ # B, C, H, W -> B, C, T, H, W
+ return x.unsqueeze(self.time_dim)
+ elif x.ndim == 5:
+ return x
+ else:
+ raise ValueError(f"Dimension incorrect {x.shape}")
+
+
+class PadIm2Video(Im2Video):
+ def __init__(self, ntimes, pad_type, time_dim=2):
+ super().__init__(time_dim=time_dim)
+ assert ntimes > 0
+ assert pad_type in ["zero", "repeat"]
+ self.ntimes = ntimes
+ self.pad_type = pad_type
+
+ def forward(self, x):
+ x = super().forward(x)
+ if x.shape[self.time_dim] == 1:
+ if self.pad_type == "repeat":
+ new_shape = [1] * len(x.shape)
+ new_shape[self.time_dim] = self.ntimes
+ x = x.repeat(new_shape)
+ elif self.pad_type == "zero":
+ padarg = [0, 0] * len(x.shape)
+ padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
+ x = nn.functional.pad(x, padarg)
+ return x
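+
+    # Usage sketch (illustrative): pad a single image to a 2-frame clip by repeating it
+    # along the time dimension.
+    #
+    #   pad = PadIm2Video(ntimes=2, pad_type="repeat")
+    #   clip = pad(torch.zeros(1, 3, 224, 224))   # -> shape (1, 3, 2, 224, 224)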
+
+
+# Modified from github.com/openai/CLIP
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str, context_length=77):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+
+ with g_pathmgr.open(bpe_path, "rb") as fh:
+ bpe_bytes = io.BytesIO(fh.read())
+ merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
+ merges = merges[1 : 49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + "</w>" for v in vocab]
+ for merge in merges:
+ vocab.append("".join(merge))
+ vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {
+ "<|startoftext|>": "<|startoftext|>",
+ "<|endoftext|>": "<|endoftext|>",
+ }
+ self.pat = re.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE,
+ )
+ self.context_length = context_length
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token + "</w>"
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+                except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+ )
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = "".join([self.decoder[token] for token in tokens])
+ text = (
+ bytearray([self.byte_decoder[c] for c in text])
+ .decode("utf-8", errors="replace")
+            .replace("</w>", " ")
+ )
+ return text
+
+ def __call__(self, texts, context_length=None):
+ if not context_length:
+ context_length = self.context_length
+
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self.encoder["<|startoftext|>"]
+ eot_token = self.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ tokens = tokens[:context_length]
+ result[i, : len(tokens)] = torch.tensor(tokens)
+
+ if len(result) == 1:
+ return result[0]
+ return result
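+
+    # Usage sketch (illustrative; the bpe_path below is a placeholder for the CLIP BPE
+    # vocab file shipped with ImageBind, not a path asserted by this module):
+    #
+    #   tokenizer = SimpleTokenizer(bpe_path="bpe_simple_vocab_16e6.txt.gz")
+    #   tokens = tokenizer(["a photo of a dog"], context_length=77)
+    #   # shape (77,) for a single string, (N, 77) for a list of N > 1 strings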
+
+
+class IMUPreprocessor(VerboseNNModule):
+ def __init__(
+ self,
+ kernel_size: int,
+ imu_stem: PatchEmbedGeneric,
+ embed_dim: int,
+ img_size: List = (6, 2000),
+ num_cls_tokens: int = 1,
+ pos_embed_fn: Callable = None,
+ init_param_style: str = "openclip",
+ ) -> None:
+ super().__init__()
+ stem = imu_stem
+ self.imu_stem = imu_stem
+ self.embed_dim = embed_dim
+ self.use_pos_embed = pos_embed_fn is not None
+ self.num_cls_tokens = num_cls_tokens
+ self.kernel_size = kernel_size
+ self.pos_embed = nn.Parameter(
+ torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
+ )
+
+ if self.num_cls_tokens > 0:
+ self.cls_token = nn.Parameter(
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
+ )
+
+ self.init_parameters(init_param_style)
+
+ @torch.no_grad()
+ def init_parameters(self, init_param_style):
+ nn.init.normal_(self.pos_embed, std=0.01)
+
+ if init_param_style == "openclip":
+ # OpenCLIP style initialization
+ scale = self.embed_dim**-0.5
+
+ if self.num_cls_tokens > 0:
+ nn.init.normal_(self.cls_token)
+ self.cls_token *= scale
+ elif init_param_style == "vit":
+ self.cls_token.data.fill_(0)
+ else:
+ raise ValueError(f"Unknown init {init_param_style}")
+
+ def tokenize_input_and_cls_pos(self, input, stem):
+ # tokens is of shape B x L x D
+ tokens = stem.norm_layer(stem.proj(input))
+ assert tokens.ndim == 3
+ assert tokens.shape[2] == self.embed_dim
+ B = tokens.shape[0]
+ if self.num_cls_tokens > 0:
+ class_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole class_tokens impl from Phil Wang, thanks
+ tokens = torch.cat((class_tokens, tokens), dim=1)
+ if self.use_pos_embed:
+ tokens = tokens + self.pos_embed
+ return tokens
+
+ def forward(self, imu):
+ # Patchify
+ imu = imu.unfold(
+ -1,
+ self.kernel_size,
+ self.kernel_size,
+ ).permute(0, 2, 1, 3)
+ imu = imu.reshape(imu.size(0), imu.size(1), -1)
+
+ imu_tokens = self.tokenize_input_and_cls_pos(
+ imu,
+ self.imu_stem,
+ )
+
+ return_dict = {
+ "trunk": {
+ "tokens": imu_tokens,
+ },
+ "head": {},
+ }
+ return return_dict
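+
+    # Shape sketch (illustrative; kernel_size=8 is an example value, not a constraint):
+    # with the default img_size=(6, 2000), an IMU clip of shape (B, 6, 2000) is unfolded
+    # into 2000 / 8 = 250 patches of 6 * 8 = 48 values each, so the stem receives a
+    # (B, 250, 48) tensor and this module returns (B, 250 + num_cls_tokens, embed_dim) tokens.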
diff --git a/code/model/ImageBind/models/transformer.py b/code/model/ImageBind/models/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..98902ac8f08868c486a7c74781e952bee444c2e6
--- /dev/null
+++ b/code/model/ImageBind/models/transformer.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Code modified from
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
+# https://github.com/facebookresearch/deit/blob/main/models.py
+# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
+
+
+import copy
+import fnmatch
+import logging
+from functools import partial
+from typing import Callable, List
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import DropPath, trunc_normal_
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version,
+ # can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class MultiheadAttention(nn.MultiheadAttention):
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+ return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+
+class ViTAttention(Attention):
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+ assert attn_mask is None
+ return super().forward(x)
+
+
+class BlockWithMasking(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ attn_target: Callable,
+ mlp_ratio: int = 4,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = nn.LayerNorm,
+ ffn_dropout_rate: float = 0.0,
+ drop_path: float = 0.0,
+ layer_scale_type: str = None,
+ layer_scale_init_value: float = 1e-4,
+ ):
+ super().__init__()
+
+ assert not isinstance(
+ attn_target, nn.Module
+ ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
+ self.attn = attn_target()
+ if drop_path > 0.0:
+ self.drop_path = DropPath(drop_path)
+ else:
+ self.drop_path = nn.Identity()
+ self.norm_1 = norm_layer(dim)
+ mlp_hidden_dim = int(mlp_ratio * dim)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=ffn_dropout_rate,
+ )
+ self.norm_2 = norm_layer(dim)
+ self.layer_scale_type = layer_scale_type
+ if self.layer_scale_type is not None:
+ assert self.layer_scale_type in [
+ "per_channel",
+ "scalar",
+ ], f"Found Layer scale type {self.layer_scale_type}"
+ if self.layer_scale_type == "per_channel":
+ # one gamma value per channel
+ gamma_shape = [1, 1, dim]
+ elif self.layer_scale_type == "scalar":
+ # single gamma value for all channels
+ gamma_shape = [1, 1, 1]
+ # two gammas: for each part of the fwd in the encoder
+ self.layer_scale_gamma1 = nn.Parameter(
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
+ requires_grad=True,
+ )
+ self.layer_scale_gamma2 = nn.Parameter(
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
+ requires_grad=True,
+ )
+
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+ if self.layer_scale_type is None:
+ x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
+ x = x + self.drop_path(self.mlp(self.norm_2(x)))
+ else:
+ x = (
+ x
+ + self.drop_path(self.attn(self.norm_1(x), attn_mask))
+ * self.layer_scale_gamma1
+ )
+ x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
+ return x
+
+
+_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
+
+
+class SimpleTransformer(nn.Module):
+ def __init__(
+ self,
+ attn_target: Callable,
+ embed_dim: int,
+ num_blocks: int,
+ block: Callable = BlockWithMasking,
+ pre_transformer_layer: Callable = None,
+ post_transformer_layer: Callable = None,
+ drop_path_rate: float = 0.0,
+ drop_path_type: str = "progressive",
+ norm_layer: Callable = _LAYER_NORM,
+ mlp_ratio: int = 4,
+ ffn_dropout_rate: float = 0.0,
+ layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar"
+ layer_scale_init_value: float = 1e-4, # from cait; float
+ weight_init_style: str = "jax", # possible values jax or pytorch
+ ):
+ """
+ Simple Transformer with the following features
+ 1. Supports masked attention
+ 2. Supports DropPath
+ 3. Supports LayerScale
+ 4. Supports Dropout in Attention and FFN
+ 5. Makes few assumptions about the input except that it is a Tensor
+ """
+ super().__init__()
+ self.pre_transformer_layer = pre_transformer_layer
+ if drop_path_type == "progressive":
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
+ elif drop_path_type == "uniform":
+ dpr = [drop_path_rate for i in range(num_blocks)]
+ else:
+ raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
+
+ self.blocks = nn.Sequential(
+ *[
+ block(
+ dim=embed_dim,
+ attn_target=attn_target,
+ mlp_ratio=mlp_ratio,
+ ffn_dropout_rate=ffn_dropout_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ layer_scale_type=layer_scale_type,
+ layer_scale_init_value=layer_scale_init_value,
+ )
+ for i in range(num_blocks)
+ ]
+ )
+ self.post_transformer_layer = post_transformer_layer
+ self.weight_init_style = weight_init_style
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ if self.weight_init_style == "jax":
+ # Based on MAE and official Jax ViT implementation
+ torch.nn.init.xavier_uniform_(m.weight)
+ elif self.weight_init_style == "pytorch":
+ # PyTorch ViT uses trunc_normal_
+ trunc_normal_(m.weight, std=0.02)
+
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, (nn.LayerNorm)):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ attn_mask: torch.Tensor = None,
+ use_checkpoint: bool = False,
+ checkpoint_every_n: int = 1,
+ checkpoint_blk_ids: List[int] = None,
+ ):
+ """
+ Inputs
+ - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
+        - attn_mask: additive attention mask of shape L x L
+
+ Output
+ - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
+ """
+ if self.pre_transformer_layer:
+ tokens = self.pre_transformer_layer(tokens)
+ if use_checkpoint and checkpoint_blk_ids is None:
+ checkpoint_blk_ids = [
+ blk_id
+ for blk_id in range(len(self.blocks))
+ if blk_id % checkpoint_every_n == 0
+ ]
+ if checkpoint_blk_ids:
+ checkpoint_blk_ids = set(checkpoint_blk_ids)
+ for blk_id, blk in enumerate(self.blocks):
+ if use_checkpoint and blk_id in checkpoint_blk_ids:
+ tokens = checkpoint.checkpoint(
+ blk, tokens, attn_mask, use_reentrant=False
+ )
+ else:
+ tokens = blk(tokens, attn_mask=attn_mask)
+ if self.post_transformer_layer:
+ tokens = self.post_transformer_layer(tokens)
+ return tokens
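+
+
+# Minimal smoke test (added for illustration; the embedding size, head count and sequence
+# length are arbitrary assumptions, not values used by ImageBind).
+if __name__ == "__main__":
+    attn_target = partial(MultiheadAttention, embed_dim=64, num_heads=4, batch_first=True)
+    model = SimpleTransformer(attn_target=attn_target, embed_dim=64, num_blocks=2)
+    out = model(torch.zeros(2, 10, 64))
+    print(out.shape)  # expected: torch.Size([2, 10, 64])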
diff --git a/code/model/ImageBind/requirements.txt b/code/model/ImageBind/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..572ae079a6cc3592552d93b8ca08c3ec7fd4efc9
--- /dev/null
+++ b/code/model/ImageBind/requirements.txt
@@ -0,0 +1,10 @@
+--extra-index-url https://download.pytorch.org/whl/cu113
+torchvision==0.14.0
+torchaudio==0.13.0
+pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d
+timm==0.6.7
+ftfy
+regex
+einops
+fvcore
+decord==0.6.0
diff --git a/code/model/__init__.py b/code/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..879752bcc2ad73a53bd786c665a995d722b4e56b
--- /dev/null
+++ b/code/model/__init__.py
@@ -0,0 +1,9 @@
+from .agent import DeepSpeedAgent
+from .openllama import OpenLLAMAPEFTModel
+
+def load_model(args):
+ agent_name = args['models'][args['model']]['agent_name']
+ model_name = args['models'][args['model']]['model_name']
+ model = globals()[model_name](**args)
+ agent = globals()[agent_name](model, args)
+ return agent
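+
+# Usage sketch (illustrative; the key names in `args` mirror the training configs, and the
+# model key 'openllama_peft' is an example, not a value asserted by this file):
+#
+#   args = {
+#       'model': 'openllama_peft',
+#       'models': {'openllama_peft': {'agent_name': 'DeepSpeedAgent',
+#                                     'model_name': 'OpenLLAMAPEFTModel'}},
+#       # ... plus the model/agent specific keys consumed downstream
+#   }
+#   agent = load_model(args)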
diff --git a/code/model/__pycache__/__init__.cpython-310.pyc b/code/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1173f59442c93b6379d83b0b521ba54161911f98
Binary files /dev/null and b/code/model/__pycache__/__init__.cpython-310.pyc differ
diff --git a/code/model/__pycache__/agent.cpython-310.pyc b/code/model/__pycache__/agent.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e7d9cec8e164d6fc0883ffa78aa8386411d7d35
Binary files /dev/null and b/code/model/__pycache__/agent.cpython-310.pyc differ
diff --git a/code/model/__pycache__/modeling_llama.cpython-310.pyc b/code/model/__pycache__/modeling_llama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e881f160c736b6c5b0b7b3387b21b4985428fc5
Binary files /dev/null and b/code/model/__pycache__/modeling_llama.cpython-310.pyc differ
diff --git a/code/model/__pycache__/openllama.cpython-310.pyc b/code/model/__pycache__/openllama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5096daa4a1a203fd47daa74366c8759354b8efe0
Binary files /dev/null and b/code/model/__pycache__/openllama.cpython-310.pyc differ
diff --git a/code/model/agent.py b/code/model/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..219001fda8cfa22bf6c1b720504c07763044a119
--- /dev/null
+++ b/code/model/agent.py
@@ -0,0 +1,68 @@
+from header import *
+
+class DeepSpeedAgent:
+
+ def __init__(self, model, args):
+ super(DeepSpeedAgent, self).__init__()
+ self.args = args
+ self.model = model
+ if args['stage'] == 2:
+ self.load_stage_1_parameters(args["delta_ckpt_path"])
+ print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}')
+
+ # load config parameters of deepspeed
+ ds_params = json.load(open(self.args['ds_config_path']))
+ ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps']
+ ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate']))
+ self.ds_engine, self.optimizer, _ , _ = deepspeed.initialize(
+ model=self.model,
+ model_parameters=self.model.parameters(),
+ config_params=ds_params,
+ dist_init_required=True,
+ args=types.SimpleNamespace(**args)
+ )
+
+ @torch.no_grad()
+ def predict(self, batch):
+ self.model.eval()
+ string = self.model.generate_one_sample(batch)
+ return string
+
+ def train_model(self, batch, current_step=0, pbar=None):
+ self.ds_engine.module.train()
+ loss, mle_acc = self.ds_engine(batch)
+
+ self.ds_engine.backward(loss)
+ self.ds_engine.step()
+ pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
+ pbar.update(1)
+ if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0:
+ elapsed = pbar.format_dict['elapsed']
+ rate = pbar.format_dict['rate']
+ remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0
+ remaining = str(datetime.timedelta(seconds=remaining))
+ logging.info(f'[!] progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
+
+ mle_acc *= 100
+ return mle_acc
+
+ def save_model(self, path, current_step):
+ # only save trainable model parameters
+ checkpoint = OrderedDict()
+ for k, v in self.ds_engine.module.named_parameters():
+ if v.requires_grad:
+ checkpoint[k] = v
+ torch.save(checkpoint, f'{path}/pytorch_model.pt')
+ # save tokenizer
+ self.model.llama_tokenizer.save_pretrained(path)
+ # save configuration
+ self.model.llama_model.config.save_pretrained(path)
+ print(f'[!] save model into {path}')
+
+ def load_stage_1_parameters(self, path):
+ delta_ckpt = torch.load(path, map_location=torch.device('cpu'))
+ self.model.load_state_dict(delta_ckpt, strict=False)
diff --git a/code/model/modeling_llama.py b/code/model/modeling_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d980e189d902fb1a6d9ea05dc3ca91959b1c8c
--- /dev/null
+++ b/code/model/modeling_llama.py
@@ -0,0 +1,755 @@
+# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+
+""" PyTorch LLaMA model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class LlamaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
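+
+# Illustration (comment added for clarity, not in the upstream file): rotate_half maps the
+# two halves (x1, x2) of the last dimension to (-x2, x1), so
+#   q_embed = q * cos + rotate_half(q) * sin
+# applies a position-dependent rotation to each (x1_i, x2_i) channel pair of q (and likewise
+# for k). q and k have shape [bs, num_heads, seq_len, head_dim]; cos/sin are gathered per
+# position_ids to [bs, 1, seq_len, head_dim] and broadcast over the heads.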
+
+
+class LlamaMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LlamaAttention(config=config)
+ self.mlp = LlamaMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlamaModel):
+ module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ query_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ if query_embeds is not None:
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
+ batch_size, seq_length, _ = inputs_embeds.shape
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ query_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ query_embeds=query_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+ query_embeds = None
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "query_embeds": query_embeds,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
diff --git a/code/model/openllama.py b/code/model/openllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8e235a95723942975ccdd88739bd7b5b79458e7
--- /dev/null
+++ b/code/model/openllama.py
@@ -0,0 +1,293 @@
+from header import *
+import torch.nn.functional as F
+from .ImageBind import *
+from .ImageBind import data
+from .modeling_llama import LlamaForCausalLM
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+import torch
+from torch.nn.utils import rnn
+
+class StoppingCriteriaSub(StoppingCriteria):
+
+ def __init__(self, stops = [], encounters=1):
+ super().__init__()
+ self.stops = stops
+ self.ENCOUNTERS = encounters
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ stop_count = 0
+ for stop in self.stops:
+ stop_count = (stop == input_ids[0]).sum().item()
+ if stop_count >= self.ENCOUNTERS:
+ return True
+ return False
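+
+# Usage sketch (illustrative; `llama_tokenizer`, `llama_model` and the '###' stop string are
+# assumptions for the example). The counting above compares each stop against every position
+# of the generated sequence, so each entry in `stops` should be a single token id:
+#
+#   stop_id = llama_tokenizer('###', add_special_tokens=False).input_ids[-1]
+#   stopping = StoppingCriteriaList([StoppingCriteriaSub(stops=[torch.tensor(stop_id)], encounters=1)])
+#   output_ids = llama_model.generate(..., stopping_criteria=stopping)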
+
+def build_one_instance(tokenizer, conversation):
+ text_list = []
+ turn_num = len(conversation)
+ input_ids, target_ids = [], []
+ for i in range(turn_num):
+ turn = conversation[i]
+ role = turn['from']
+ if i == 0: # the first human turn
+ assert role == 'human'
+ text = ' ' + turn['value'] + '\n### Assistant:'
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
+ input_ids += one_input_id
+            target_ids += [-100]*len(one_input_id) # do not compute the loss on the human prompt
+ else:
+ if role == 'human':
+ text = 'Human: ' + turn['value'] + '\n### Assistant:'
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
+ input_ids += one_input_id
+ target_ids += [-100]*len(one_input_id)
+ elif role == 'gpt':
+ text = turn['value'] + '\n###'
+ one_input_id = tokenizer(text, add_special_tokens=False).input_ids
+ input_ids += one_input_id
+ target_ids += one_input_id
+ else:
+ raise Exception('Wrong Role!!!')
+ text_list.append(text)
+ assert len(input_ids) == len(target_ids)
+ return text_list, input_ids, target_ids
+
+def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len):
+ batch_input_ids, batch_target_ids = [], []
+ for conversation in batch_of_conversations:
+ _, one_input_ids, one_target_ids = build_one_instance(tokenizer, conversation)
+ batch_input_ids.append(torch.LongTensor(one_input_ids))
+ batch_target_ids.append(torch.LongTensor(one_target_ids))
+ input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100)
+ assert input_ids.size() == target_ids.size()
+ input_ids = input_ids[:,:max_tgt_len]
+ target_ids = target_ids[:,:max_tgt_len]
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+ assert attention_mask.size() == input_ids.size()
+ return input_ids, target_ids, attention_mask.long()
+
+PROMPT_START = '### Human: '
+class OpenLLAMAPEFTModel(nn.Module):
+
+    '''LoRA adapter for the LLaMA (Vicuna) model'''
+
+ def __init__(self, **args):
+ super(OpenLLAMAPEFTModel, self).__init__()
+ self.args = args
+ imagebind_ckpt_path = args['imagebind_ckpt_path']
+ vicuna_ckpt_path = args['vicuna_ckpt_path']
+ max_tgt_len = args['max_tgt_len']
+ stage = args['stage']
+
+ print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
+ self.visual_encoder, self.visual_hidden_size = \
+ imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
+        # freeze the vision encoder
+ for name, param in self.visual_encoder.named_parameters():
+ param.requires_grad = False
+ self.visual_encoder.eval()
+ print ('Visual encoder initialized.')
+
+ print (f'Initializing language decoder from {vicuna_ckpt_path} ...')
+ # add the lora module
+ peft_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ inference_mode=False,
+ r=self.args['lora_r'],
+ lora_alpha=self.args['lora_alpha'],
+ lora_dropout=self.args['lora_dropout'],
+ target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
+ )
+
+ self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)
+
+ self.llama_model = get_peft_model(self.llama_model, peft_config)
+ self.llama_model.print_trainable_parameters()
+
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+ self.llama_tokenizer.padding_side = "right"
+ print ('Language decoder initialized.')
+
+ self.llama_proj = nn.Linear(
+ self.visual_hidden_size, self.llama_model.config.hidden_size
+ )
+
+ self.max_tgt_len = max_tgt_len
+ self.device = torch.cuda.current_device()
+
+ def encode_video(self, video_paths):
+ inputs = {ModalityType.VISION: data.load_and_transform_video_data(video_paths, self.device)}
+ # convert into visual dtype
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
+ with torch.no_grad():
+ embeddings = self.visual_encoder(inputs)
+ video_embeds = embeddings[ModalityType.VISION] # bsz x 1024
+ inputs_llama = self.llama_proj(video_embeds).unsqueeze(1) # bsz x 1 x llama_size
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
+ return inputs_llama, atts_llama
+
+ def encode_audio(self, audio_paths):
+ inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, self.device)}
+ # convert into visual dtype
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
+ with torch.no_grad():
+ embeddings = self.visual_encoder(inputs)
+ audio_embeds = embeddings[ModalityType.AUDIO] # bsz x 1024
+ inputs_llama = self.llama_proj(audio_embeds).unsqueeze(1) # bsz x 1 x llama_size
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
+ return inputs_llama, atts_llama
+
+ def encode_thermal(self, thermal_paths):
+ inputs = {ModalityType.THERMAL: data.load_and_transform_thermal_data(thermal_paths, self.device)}
+ # convert into visual dtype
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
+ with torch.no_grad():
+ embeddings = self.visual_encoder(inputs)
+ image_embeds = embeddings['thermal'] # bsz x 1024
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
+ return inputs_llama, atts_llama
+
+ def encode_image(self, image_paths):
+ inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
+ # convert into visual dtype
+ inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
+ with torch.no_grad():
+ embeddings = self.visual_encoder(inputs)
+ image_embeds = embeddings['vision'] # bsz x 1024
+ inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1
+ return inputs_llama, atts_llama
+
+ def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask):
+ '''
+ input_ids, target_ids, attention_mask: bsz x s2
+ '''
+ input_ids = input_ids.to(self.device) # bsz x s2
+ target_ids = target_ids.to(self.device) # bsz x s2
+ attention_mask = attention_mask.to(self.device) # bsz x s2
+
+ batch_size = img_embeds.shape[0]
+ p_before = PROMPT_START
+ p_before_tokens = self.llama_tokenizer(p_before,
+ return_tensors="pt", add_special_tokens=False).to(self.device)
+        # the PEFT-wrapped model needs a deeper attribute path to reach embed_tokens
+ p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
+ p_after_embeds = self.llama_model.model.model.embed_tokens(input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim
+ bos = torch.ones([batch_size, 1],
+ dtype=p_before_tokens.input_ids.dtype,
+ device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
+ bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
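+        # final input layout: [bos] ["### Human: "] [modality embedding] [conversation tokens]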
+ inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_after_embeds], dim=1) # bsz x (1+s1+1+s2) x embed_dim
+
+ # create targets
+ empty_targets = (
+ torch.ones([batch_size, 1+p_before_embeds.size()[1]+1], # 1 (bos) + s1 + 1 (image vector)
+ dtype=torch.long).to(self.device).fill_(-100)
+ ) # bsz x (1 + s1 + 1)
+ targets = torch.cat([empty_targets, target_ids], dim=1) # bsz x (1 + s1 + 1 + s2)
+ assert inputs_embeds.size()[1] == targets.size()[1]
+
+ atts_prefix = torch.ones([batch_size, 1+p_before_embeds.size()[1]+1], dtype=torch.long).to(self.device) # bsz x (1 + s1 +1)
+ attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)
+ assert attention_mask.size() == targets.size() # bsz x (1 + s1 + 1 + s2)
+ return inputs_embeds, targets, attention_mask
+
+ def forward(self, inputs):
+ image_paths = inputs['image_paths']
+ img_embeds, _ = self.encode_image(image_paths)
+
+ output_texts = inputs['output_texts']
+ input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len)
+ inputs_embeds, targets, attention_mask = self.prompt_wrap(img_embeds, input_ids, target_ids, attention_mask)
+
+ outputs = self.llama_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ )
+ loss = outputs.loss
+        # calculate the token-level accuracy
+ chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1] # [B, S-1]
+ labels = targets[:, 2:]
+ gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long) # [B*S]
+ valid_mask = (labels != -100).reshape(-1)
+ valid_tokens = gen_acc & valid_mask # [B*S]
+ gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()
+ return loss, gen_acc
+
+ def extract_multimodal_feature(self, inputs):
+ features = []
+ if inputs['image_paths']:
+ image_embeds, _ = self.encode_image(inputs['image_paths'])
+ features.append(image_embeds)
+ if inputs['audio_paths']:
+ audio_embeds, _ = self.encode_audio(inputs['audio_paths'])
+ features.append(audio_embeds)
+ if inputs['video_paths']:
+ video_embeds, _ = self.encode_video(inputs['video_paths'])
+ features.append(video_embeds)
+ if inputs['thermal_paths']:
+ thermal_embeds, _ = self.encode_thermal(inputs['thermal_paths'])
+ features.append(thermal_embeds)
+
+ feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0)
+ return feature_embeds
+
+ def prepare_generation_embedding(self, inputs):
+ prompt = inputs['prompt']
+ if len(inputs['modality_embeds']) == 1:
+ feature_embeds = inputs['modality_embeds'][0]
+ else:
+ feature_embeds = self.extract_multimodal_feature(inputs)
+ inputs['modality_embeds'].append(feature_embeds)
+
+ batch_size = feature_embeds.shape[0]
+ p_before = PROMPT_START
+ p_before_tokens = self.llama_tokenizer(p_before,
+ return_tensors="pt", add_special_tokens=False).to(self.device)
+ p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim
+ text = ' ' + prompt + '\n### Assistant:'
+ p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
+        p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim
+ bos = torch.ones([batch_size, 1],
+ dtype=p_before_tokens.input_ids.dtype,
+ device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1
+ bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim
+ inputs_embeds = torch.cat([bos_embeds, p_before_embeds, feature_embeds, p_after_embeds], dim=1) # bsz x (1+s1+1+s2) x embed_dim
+ return inputs_embeds
+
+ def generate(self, inputs):
+ '''
+ inputs = {
+ 'image_paths': optional,
+ 'audio_paths': optional
+ 'video_paths': optional
+ 'thermal_paths': optional
+ 'mode': generation mode,
+ 'prompt': human input prompt,
+ 'max_tgt_len': generation length,
+ 'top_p': top_p,
+ 'temperature': temperature
+ 'modality_embeds': None or torch.tensor
+ 'modality_cache': save the image cache
+ }
+ '''
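+        # illustrative call (all field values below are hypothetical):
+        #   model.generate({'prompt': 'Describe the image.', 'image_paths': ['example.jpg'],
+        #                   'audio_paths': [], 'video_paths': [], 'thermal_paths': [],
+        #                   'top_p': 0.9, 'temperature': 0.7, 'max_tgt_len': 256,
+        #                   'modality_embeds': [], 'modality_cache': None})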
+ input_embeds = self.prepare_generation_embedding(inputs)
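+        # stop decoding once the conversation-separator token id has been generated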
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277], encounters=1)])
+ outputs = self.llama_model.generate(
+ inputs_embeds=input_embeds,
+ max_new_tokens=inputs['max_tgt_len'],
+ top_p=inputs['top_p'],
+ temperature=inputs['temperature'],
+ do_sample=True,
+ use_cache=True,
+ stopping_criteria=stopping_criteria,
+ )
+ output_text = self.llama_tokenizer.decode(outputs[0][:-2], skip_special_tokens=True)
+ return output_text
+
diff --git a/code/pytorchvideo/.circleci/config.yml b/code/pytorchvideo/.circleci/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..df8133aa23db0742deb04692931b3460b51d30dc
--- /dev/null
+++ b/code/pytorchvideo/.circleci/config.yml
@@ -0,0 +1,205 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+# -------------------------------------------------------------------------------------
+# CircleCI configuration file.
+# Specifies automated environment setup and tests.
+#
+# See https://circleci.com/docs/2.0/language-python/ for more details
+# Available Machine Images:
+# https://circleci.com/docs/2.0/configuration-reference/#available-machine-images
+# -------------------------------------------------------------------------------------
+
+version: 2.1
+
+# -------------------------------------------------------------------------------------
+# Environments to run the jobs in
+# -------------------------------------------------------------------------------------
+cpu: &cpu
+ machine:
+ image: ubuntu-2004:202101-01
+
+gpu: &gpu
+ environment:
+ CUDA_VERSION: "10.2"
+ resource_class: gpu.nvidia.small.multi
+ machine:
+ image: ubuntu-2004:202101-01
+
+setup_cuda: &setup_cuda
+ run:
+ name: Setup CUDA
+ working_directory: ~/
+ command: |
+ # download and install nvidia drivers, cuda, etc
+ wget --no-verbose --no-clobber -P ~/nvidia-downloads https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run
+ sudo sh ~/nvidia-downloads/cuda_11.2.2_460.32.03_linux.run --silent
+ echo "Done installing CUDA."
+ nvidia-smi
+
+# -------------------------------------------------------------------------------------
+# Re-usable commands
+# -------------------------------------------------------------------------------------
+install_conda: &install_conda
+ run:
+ name: Setup Conda
+ working_directory: ~/
+ command: |
+ curl --retry 3 -o conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+ sh conda.sh -b -p $HOME/miniconda3
+
+setup_ptv_conda: &setup_ptv_conda
+ run:
+ name: Setup Conda Environment
+ command: |
+ pyenv versions
+ export PATH="$HOME/miniconda3/bin:$PATH"
+ conda update -y conda
+ conda init bash
+ source ~/.bashrc
+ conda create --name pytorchvideo python=3.7.9
+
+install_pytorch: &install_pytorch
+ - run:
+ name: Install Pytorch
+ command: |
+ export PATH="$HOME/miniconda3/bin:$PATH"
+ conda activate pytorchvideo
+ conda install pytorch torchvision -c pytorch
+ python -c 'import torch; print(torch.__version__)'
+ python -c 'import torch; print("CUDA:", torch.cuda.is_available())'
+ python -c 'import torchvision; print(torchvision.__version__)'
+
+install_pytorchvideo: &install_pytorchvideo
+ - run:
+ name: Install PyTorchVideo
+ command: |
+ export PATH="$HOME/miniconda3/bin:$PATH"
+ conda activate pytorchvideo
+ pip install -U --progress-bar off -e .[test]
+ python -c 'import pytorchvideo; print(pytorchvideo.__version__)'
+
+build_wheels: &build_wheels
+ - run:
+ name: Install PyTorchVideo
+ command: |
+ export PATH="$HOME/miniconda3/bin:$PATH"
+ conda activate pytorchvideo
+ python setup.py sdist
+
+ export BUILD_NIGHTLY="1"
+ python setup.py sdist
+
+run_unittests: &run_unittests
+ - run:
+ name: Run Unit Tests
+ command: |
+ export PATH="$HOME/miniconda3/bin:$PATH"
+ conda activate pytorchvideo
+ python -m unittest discover -v -s tests
+
+run_unittests_with_coverage: &run_unittests_with_coverage
+ - run:
+ name: Run Unit Tests
+ command: |
+ export PATH="$HOME/miniconda3/bin:$PATH"
+ conda activate pytorchvideo
+ coverage run -m unittest discover -v -s tests
+ bash <(curl -s https://codecov.io/bash)
+
+# -------------------------------------------------------------------------------------
+# Jobs to run
+# -------------------------------------------------------------------------------------
+jobs:
+ cpu_tests:
+ <<: *cpu
+ working_directory: ~/pytorchvideo
+ steps:
+ - checkout
+ - <<: *install_conda
+ - <<: *setup_ptv_conda
+ - <<: *install_pytorch
+ - <<: *install_pytorchvideo
+ - <<: *build_wheels
+ - <<: *run_unittests_with_coverage
+ - store_artifacts:
+ path: ~/pytorchvideo/dist
+ - persist_to_workspace:
+ root: ~/pytorchvideo/dist
+ paths:
+ - "*"
+
+ gpu_tests:
+ working_directory: ~/pytorchvideo
+ <<: *gpu
+ steps:
+ - checkout
+ - <<: *setup_cuda
+ - <<: *install_conda
+ - <<: *setup_ptv_conda
+ - <<: *install_pytorch
+ - <<: *install_pytorchvideo
+ - <<: *run_unittests
+
+ upload_wheel:
+ docker:
+ - image: circleci/python:3.7
+ auth:
+ username: $DOCKERHUB_USERNAME
+ password: $DOCKERHUB_TOKEN
+ working_directory: ~/pytorchvideo
+ steps:
+ - checkout
+ - attach_workspace:
+ at: ~/workspace
+ - run:
+ command: |
+ # no commits in the last 25 hours
+ if [[ -z $(git log --since="25 hours ago") ]]; then
+ echo "No commits in the last day."
+ exit 0
+ fi
+ pip install --progress-bar off --user twine
+ for pkg in ~/workspace/*.tar.gz; do
+ if [[ "$pkg" == *"nightly"* ]];
+ then
+ twine upload --verbose --skip-existing --username __token__ --password $PTV_NIGHTLY_PYPI_TOKEN $pkg
+ else
+ twine upload --verbose --skip-existing --username __token__ --password $PTV_PYPI_TOKEN $pkg
+ fi
+ done
+# -------------------------------------------------------------------------------------
+# Workflows to launch
+# -------------------------------------------------------------------------------------
+workflows:
+ version: 2
+ regular_test:
+ jobs:
+ - cpu_tests:
+ context:
+ - DOCKERHUB_TOKEN
+ - gpu_tests:
+ context:
+ - DOCKERHUB_TOKEN
+
+ nightly:
+ jobs:
+ # https://circleci.com/docs/2.0/contexts/#creating-and-using-a-context
+ - cpu_tests:
+ context:
+ - DOCKERHUB_TOKEN
+ - gpu_tests:
+ context:
+ - DOCKERHUB_TOKEN
+ - upload_wheel:
+ requires:
+ - cpu_tests
+ - gpu_tests
+ context:
+ - DOCKERHUB_TOKEN
+ triggers:
+ - schedule:
+ cron: "0 0 * * *"
+ filters:
+ branches:
+ only:
+ - main
diff --git a/code/pytorchvideo/.flake8 b/code/pytorchvideo/.flake8
new file mode 100644
index 0000000000000000000000000000000000000000..6c3b6d91f3dcf1baa1fc8e5f337fc469e0a9b0ae
--- /dev/null
+++ b/code/pytorchvideo/.flake8
@@ -0,0 +1,6 @@
+[flake8]
+ignore = E203, E266, E501, W503, E221
+max-line-length = 88
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
+exclude = build,__init__.py
diff --git a/code/pytorchvideo/.github/CODE_OF_CONDUCT.md b/code/pytorchvideo/.github/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..f049d4c53173cc44e0d0755b874d108891a5bfc5
--- /dev/null
+++ b/code/pytorchvideo/.github/CODE_OF_CONDUCT.md
@@ -0,0 +1,76 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at . All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/code/pytorchvideo/.github/CONTRIBUTING.md b/code/pytorchvideo/.github/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..877c3814bd25d4d63608e8c3cd2942ed6f3def4d
--- /dev/null
+++ b/code/pytorchvideo/.github/CONTRIBUTING.md
@@ -0,0 +1,55 @@
+# Contributing to PyTorchVideo
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+However, if you're adding any significant features, please make sure to have a corresponding issue to outline your proposal and motivation and allow time for us to give feedback, *before* you send a PR.
+We do not always accept new features, and we take the following factors into consideration:
+
+- Whether the same feature can be achieved without modifying PyTorchVideo directly. If any aspect of the API is not extensible, please highlight this in an issue so we can work on making this more extensible.
+- Whether the feature is potentially useful to a large audience, or only to a small portion of users.
+- Whether the proposed solution has a good design and interface.
+- Whether the proposed solution adds extra mental/practical overhead to users who don't need such a feature.
+- Whether the proposed solution breaks existing APIs.
+
+When sending a PR, please ensure you complete the following steps:
+
+1. Fork the repo and create your branch from `main`. Follow the instructions
+ in [INSTALL.md](../INSTALL.md) to build the repo.
+2. If you've added code that should be tested, add tests.
+3. If you've changed any APIs, please update the documentation.
+4. Ensure the test suite passes:
+ ```
+ cd pytorchvideo/tests
+ python -m unittest -v
+ ```
+5. Make sure your code lints by running `dev/linter.sh` from the project root.
+6. If a PR contains multiple orthogonal changes, split it into multiple separate PRs.
+7. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here:
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## Coding Style
+We follow these [python](http://google.github.io/styleguide/pyguide.html) and [C++](https://google.github.io/styleguide/cppguide.html) style guides.
+
+For the linter to work, you will need to install `black`, `flake8`, `isort` and `clang-format`, and
+they need to be fairly up to date.
+
+## License
+By contributing to PyTorchVideo, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
+
diff --git a/code/pytorchvideo/.github/ISSUE_TEMPLATE/bugs.md b/code/pytorchvideo/.github/ISSUE_TEMPLATE/bugs.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6ea6e9ffa4822b7c5a90eb25dc3123f6288939c
--- /dev/null
+++ b/code/pytorchvideo/.github/ISSUE_TEMPLATE/bugs.md
@@ -0,0 +1,30 @@
+---
+name: "🐛 Bugs / Unexpected behaviors"
+about: Please report unexpected behaviors or bugs in PyTorchVideo.
+
+---
+
+If you do not know the root cause of the problem / bug, and wish someone to help you, please
+post according to this template:
+
+## 🐛 Bugs / Unexpected behaviors
+
+
+NOTE: Please look at the existing list of Issues tagged with the label [`bug`](https://github.com/facebookresearch/pytorchvideo/issues?q=label%3Abug). **Only open a new issue if this bug has not already been reported. If an issue already exists, please comment there instead.**
+
+## Instructions To Reproduce the Issue:
+
+Please include the following (depending on what the issue is):
+
+1. Any changes you made (`git diff`) or code you wrote
+```
+
+```
+2. The exact command(s) you ran:
+3. What you observed (including the full logs):
+```
+
+```
+
+Please also simplify the steps as much as possible so they do not require additional resources to
+ run, such as a private dataset, models, etc.
diff --git a/code/pytorchvideo/.github/ISSUE_TEMPLATE/config.yml b/code/pytorchvideo/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3ba13e0cec6cbbfd462e9ebf529dd2093148cd69
--- /dev/null
+++ b/code/pytorchvideo/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1 @@
+blank_issues_enabled: false
diff --git a/code/pytorchvideo/.github/ISSUE_TEMPLATE/feature_request.md b/code/pytorchvideo/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..4390d86b39837b7ada1af970067ac330aefa9bda
--- /dev/null
+++ b/code/pytorchvideo/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,21 @@
+---
+name: "\U0001F680 Feature Request"
+about: Submit a proposal/request for a new PyTorchVideo feature
+
+---
+
+## 🚀 Feature
+
+
+NOTE: Please look at the existing list of Issues tagged with the label [`enhancement`](https://github.com/facebookresearch/pytorchvideo/issues?q=label%3Aenhancement). **Only open a new issue if you do not see your feature request there**.
+
+## Motivation
+
+
+
+## Pitch
+
+
+
+NOTE: we only consider adding new features if they are useful for many users.
diff --git a/code/pytorchvideo/.github/ISSUE_TEMPLATE/questions-help.md b/code/pytorchvideo/.github/ISSUE_TEMPLATE/questions-help.md
new file mode 100644
index 0000000000000000000000000000000000000000..76bc0d4db2580d0fd50f3275620bf1fca0b03879
--- /dev/null
+++ b/code/pytorchvideo/.github/ISSUE_TEMPLATE/questions-help.md
@@ -0,0 +1,21 @@
+---
+name: "❓ Questions"
+about: How do I do X with PyTorchVideo? How does PyTorchVideo do X?
+
+---
+
+## ❓ Questions on how to use PyTorchVideo
+
+
+
+
+NOTE: Please look at the existing list of Issues tagged with the label [`question`](https://github.com/facebookresearch/pytorchvideo/issues?q=label%3Aquestion) or [`how-to`](https://github.com/facebookresearch/pytorchvideo/issues?q=label%3A%22how+to%22). **Only open a new issue if you cannot find an answer there**.
+
+Also note the following:
+
+1. If you encountered any errors or unexpected issues while using PyTorchVideo and need help resolving them,
+ please use the "Bugs / Unexpected behaviors" issue template.
+
+2. We do not answer general machine learning / computer vision questions that are not specific to
+ PyTorchVideo, such as how a model works or what algorithm/methods can be
+ used to achieve X.
diff --git a/code/pytorchvideo/.github/PULL_REQUEST_TEMPLATE.md b/code/pytorchvideo/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6851e7bdcaae99ab931cc183cf22fa060d9920d
--- /dev/null
+++ b/code/pytorchvideo/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,30 @@
+## Motivation and Context
+
+
+
+
+
+## How Has This Been Tested
+
+
+
+## Types of changes
+
+
+- [ ] Docs change / refactoring / dependency upgrade
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to change)
+
+## Checklist
+
+
+
+- [ ] My code follows the code style of this project.
+- [ ] My change requires a change to the documentation.
+- [ ] I have updated the documentation accordingly.
+- [ ] I have read the **CONTRIBUTING** document.
+- [ ] I have completed my CLA (see **CONTRIBUTING**)
+- [ ] I have added tests to cover my changes.
+- [ ] All new and existing tests passed.
+
diff --git a/code/pytorchvideo/.github/media/ava_slowfast.gif b/code/pytorchvideo/.github/media/ava_slowfast.gif
new file mode 100644
index 0000000000000000000000000000000000000000..37d427d2730de52acaf1200e88d638c7ccccb05a
--- /dev/null
+++ b/code/pytorchvideo/.github/media/ava_slowfast.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a164af526b6323a2523f4c28b09408758a114865483a9b02c56d500068ffe97
+size 3455262
diff --git a/code/pytorchvideo/.github/media/logo_horizontal_color.png b/code/pytorchvideo/.github/media/logo_horizontal_color.png
new file mode 100644
index 0000000000000000000000000000000000000000..bcb951870adaad2228b8b7bf8f26b7d8dd635b3b
Binary files /dev/null and b/code/pytorchvideo/.github/media/logo_horizontal_color.png differ
diff --git a/code/pytorchvideo/.gitignore b/code/pytorchvideo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..776740c4b1b0e28142e31062ac1fe3245174dd7f
--- /dev/null
+++ b/code/pytorchvideo/.gitignore
@@ -0,0 +1,34 @@
+*.DS_Store
+
+build/
+_ext
+*.pyc
+*.pyd
+*.so
+*.dll
+*.egg-info/
+**/__pycache__/
+*-checkpoint.ipynb
+**/.ipynb_checkpoints
+**/.ipynb_checkpoints/**
+
+
+# Docusaurus site
+website/yarn.lock
+website/build/
+website/i18n/
+website/node_modules/*
+website/npm-debug.log
+
+## Generated for tutorials
+website/_tutorials/
+website/static/files/
+website/pages/tutorials/*
+!website/pages/tutorials/index.js
+
+
+## Conda and pip builds
+packaging/out/
+packaging/output_files/
+dist/
+wheels/
diff --git a/code/pytorchvideo/.readthedocs.yml b/code/pytorchvideo/.readthedocs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d27f49936d6674b58454fddba2544210a7f30d33
--- /dev/null
+++ b/code/pytorchvideo/.readthedocs.yml
@@ -0,0 +1,25 @@
+# .readthedocs.yml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+ builder: html
+ configuration: docs/source/conf.py
+
+# Build documentation with MkDocs
+#mkdocs:
+# configuration: mkdocs.yml
+
+# Optionally build your docs in additional formats such as PDF and ePub
+formats: all
+
+# Optionally set the version of Python and requirements required to build your docs
+python:
+ version: 3.7
+ system_packages: true
+ install:
+ - requirements: docs/requirements.txt
diff --git a/code/pytorchvideo/CONTRIBUTING.md b/code/pytorchvideo/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..23e2943257ed76d661f1be066dc687a51a38c1b2
--- /dev/null
+++ b/code/pytorchvideo/CONTRIBUTING.md
@@ -0,0 +1,41 @@
+# Contributing to PyTorchVideo
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Testing
+
+Please follow the instructions mentioned in [test-README](https://github.com/facebookresearch/pytorchvideo/blob/main/tests/README.md) to run the existing and your newly added tests.
+
+## Linting
+
+We provide a linting script to correctly format your code changes.
+Please follow the instructions mentioned in [dev-README](https://github.com/facebookresearch/pytorchvideo/blob/main/dev/README.md) to run the linter.
+
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here:
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to PyTorchVideo, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
diff --git a/code/pytorchvideo/INSTALL.md b/code/pytorchvideo/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..16a9e4a16693b27ec46a33553250ecf1994bf9c4
--- /dev/null
+++ b/code/pytorchvideo/INSTALL.md
@@ -0,0 +1,68 @@
+# Installation
+
+## Installing PytorchVideo
+
+
+### 1. Install from PyPI
+For stable release,
+```
+pip install pytorchvideo
+```
+
+For nightly builds,
+```
+pip install pytorchvideo-nightly
+```
+
+### 2. Install from GitHub using pip
+```
+pip install "git+https://github.com/facebookresearch/pytorchvideo.git"
+```
+To install using the code of the released version instead of the main branch, use the following:
+```
+pip install "git+https://github.com/facebookresearch/pytorchvideo.git@stable"
+```
+
+### 3. Install from a local clone
+```
+git clone https://github.com/facebookresearch/pytorchvideo.git
+cd pytorchvideo
+pip install -e .
+
+# For developing and testing
+pip install -e ".[test,dev]"
+```
+
+
+## Requirements
+
+### Core library
+
+- Python 3.7 or 3.8
+- PyTorch 1.8.0 or higher.
+- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
+- [fvcore](https://github.com/facebookresearch/fvcore) version 0.1.4 or higher
+- [ioPath](https://github.com/facebookresearch/iopath)
+- If CUDA is to be used, use a CUDA version that is supported by the corresponding PyTorch version, and at least version 10.2.
+
+We recommend setting up a conda environment with PyTorch and torchvision before installing PyTorchVideo.
+For instance, follow the instructions below to set up the conda environment:
+```
+conda create -n pytorchvideo python=3.7
+conda activate pytorchvideo
+conda install -c pytorch pytorch=1.8.0 torchvision cudatoolkit=10.2
+```
+
+## Testing
+
+Please follow the instructions mentioned in [test-README](https://github.com/facebookresearch/pytorchvideo/blob/main/tests/README.md) to run the provided tests.
+
+## Linting
+
+We also provide a linting script to correctly format your code edits.
+Please follow the instructions mentioned in [dev-README](https://github.com/facebookresearch/pytorchvideo/blob/main/dev/README.md) to run the linter.
diff --git a/code/pytorchvideo/LICENSE b/code/pytorchvideo/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5a90478a33957a01c8b1c16fb4d9bf8d0687affd
--- /dev/null
+++ b/code/pytorchvideo/LICENSE
@@ -0,0 +1,201 @@
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
+and charge a fee for, acceptance of support, warranty, indemnity,
+or other liability obligations and/or rights consistent with this
+License. However, in accepting such obligations, You may act only
+on Your own behalf and on Your sole responsibility, not on behalf
+of any other Contributor, and only if You agree to indemnify,
+defend, and hold each Contributor harmless for any liability
+incurred by, or claims asserted against, such Contributor by reason
+of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following
+boilerplate notice, with the fields enclosed by brackets "[]"
+replaced with your own identifying information. (Don't include
+the brackets!) The text should be enclosed in the appropriate
+comment syntax for the file format. We also recommend that a
+file or class name and description of purpose be included on the
+same "printed page" as the copyright notice for easier
+identification within third-party archives.
+
+Copyright 2019, Facebook, Inc
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/code/pytorchvideo/MANIFEST.in b/code/pytorchvideo/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..538a8f8e30199d1313436c54680fd6a18a53900b
--- /dev/null
+++ b/code/pytorchvideo/MANIFEST.in
@@ -0,0 +1,3 @@
+include LICENSE
+include CONTRIBUTING.md
+include requirements.txt
\ No newline at end of file
diff --git a/code/pytorchvideo/README.md b/code/pytorchvideo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4434abc75937c24cd7e5e6bdb98dc90b4090d20d
--- /dev/null
+++ b/code/pytorchvideo/README.md
@@ -0,0 +1,94 @@
+
+A deep learning library for video understanding research.
+
+Check the [website](https://pytorchvideo.org) for more information.
+
+| Accelerated X3D on mobile | SlowFast action detection |
+|:-------------------------------:|:--------------------------------------------------:|
+| A PyTorchVideo-accelerated X3D model running on a Samsung Galaxy S10 phone. The model runs ~8x faster than real time, requiring roughly 130 ms to process one second of video. | A PyTorchVideo-based SlowFast model performing video action detection. |
+
+## X3D model Web Demo
+Integrated into [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See the demo [here](https://huggingface.co/spaces/pytorch/X3D).
+
+## Introduction
+
+PyTorchVideo is a deep learning library with a focus on video understanding work. PyTorchVideo provides reusable, modular and efficient components needed to accelerate video understanding research. PyTorchVideo is developed using [PyTorch](https://pytorch.org) and supports different deep learning video components like video models, video datasets, and video-specific transforms.
+
+Key features include:
+
+- **Based on PyTorch:** Built using PyTorch. Makes it easy to use all of the PyTorch-ecosystem components.
+- **Reproducible Model Zoo:** A variety of state-of-the-art pretrained video models and their associated benchmarks that are ready to use.
+ Complementing the model zoo, PyTorchVideo comes with extensive data loaders supporting different datasets.
+- **Efficient Video Components:** Video-focused fast and efficient components that are easy to use. Supports accelerated inference on hardware.
+
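+For example, a pretrained model from the Model Zoo can be loaded through `torch.hub` (a minimal sketch; the `slow_r50` hub entry point is assumed here, see the Model Zoo for the available models):
+
+```python
+import torch
+
+# load a pretrained Slow R50 video classification model from the PyTorchVideo hub
+model = torch.hub.load("facebookresearch/pytorchvideo", "slow_r50", pretrained=True)
+model = model.eval()
+```
+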
+## Updates
+
+- Aug 2021: [Multiscale Vision Transformers](https://arxiv.org/abs/2104.11227) has been released in PyTorchVideo, details can be found from [here](https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/models/vision_transformers.py#L97).
+
+## Installation
+
+Install PyTorchVideo inside a conda environment (Python >= 3.7) with
+```shell
+pip install pytorchvideo
+```
+
+For detailed instructions please refer to [INSTALL.md](INSTALL.md).
+
+## License
+
+PyTorchVideo is released under the [Apache 2.0 License](LICENSE).
+
+## Tutorials
+
+Get started with PyTorchVideo by trying out one of our [tutorials](https://pytorchvideo.org/docs/tutorial_overview) or by running examples in the [tutorials folder](./tutorials).
+
+
+## Model Zoo and Baselines
+We provide a large set of baseline results and trained models available for download in the [PyTorchVideo Model Zoo](https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md).
+
+## Contributors
+
+Here is the growing list of PyTorchVideo contributors in alphabetical order (let us know if you would like to be added):
+[Aaron Adcock](https://www.linkedin.com/in/aaron-adcock-79855383/), [Amy Bearman](https://www.linkedin.com/in/amy-bearman/), [Bernard Nguyen](https://www.linkedin.com/in/mrbernardnguyen/), [Bo Xiong](https://www.cs.utexas.edu/~bxiong/), [Chengyuan Yan](https://www.linkedin.com/in/chengyuan-yan-4a804282/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/), [Dave Schnizlein](https://www.linkedin.com/in/david-schnizlein-96020136/), [Haoqi Fan](https://haoqifan.github.io/), [Heng Wang](https://hengcv.github.io/), [Jackson Hamburger](https://www.linkedin.com/in/jackson-hamburger-986a2873/), [Jitendra Malik](http://people.eecs.berkeley.edu/~malik/), [Kalyan Vasudev Alwala](https://www.linkedin.com/in/kalyan-vasudev-alwala-2a802b64/), [Matt Feiszli](https://www.linkedin.com/in/matt-feiszli-76b34b/), [Nikhila Ravi](https://www.linkedin.com/in/nikhilaravi/), [Ross Girshick](https://www.rossgirshick.info/), [Tullie Murrell](https://www.linkedin.com/in/tullie/), [Wan-Yen Lo](https://www.linkedin.com/in/wanyenlo/), [Weiyao Wang](https://www.linkedin.com/in/weiyaowang/?locale=en_US), [Xiaowen Lin](https://www.linkedin.com/in/xiaowen-lin-90542b34/), [Yanghao Li](https://lyttonhao.github.io/), [Yilei Li](https://liyilui.github.io/personal_page/), [Zhengxing Chen](http://czxttkl.github.io/), [Zhicheng Yan](https://www.linkedin.com/in/zhichengyan/).
+
+
+## Development
+
+We welcome new contributions to PyTorchVideo and we will be actively maintaining this library! Please refer to [`CONTRIBUTING.md`](./.github/CONTRIBUTING.md) for full instructions on how to run the code, tests and linter, and submit your pull requests.
+
+## Citing PyTorchVideo
+
+If you find PyTorchVideo useful in your work, please use the following BibTeX entry for citation.
+```BibTeX
+@inproceedings{fan2021pytorchvideo,
+ author = {Haoqi Fan and Tullie Murrell and Heng Wang and Kalyan Vasudev Alwala and Yanghao Li and Yilei Li and Bo Xiong and Nikhila Ravi and Meng Li and Haichuan Yang and Jitendra Malik and Ross Girshick and Matt Feiszli and Aaron Adcock and Wan-Yen Lo and Christoph Feichtenhofer},
+ title = {{PyTorchVideo}: A Deep Learning Library for Video Understanding},
+ booktitle = {Proceedings of the 29th ACM International Conference on Multimedia},
+ year = {2021},
+ note = {\url{https://pytorchvideo.org/}},
+}
+```
diff --git a/code/pytorchvideo/dev/README.md b/code/pytorchvideo/dev/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..027eeae3b5b36caf44af720219d2f090c3c16875
--- /dev/null
+++ b/code/pytorchvideo/dev/README.md
@@ -0,0 +1,11 @@
+## Running Linter
+
+
+Before running the linter, please ensure that you installed the necessary additional linter dependencies.
+If not installed, check the [install-README](https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md) on how to do it.
+
+After that, you can run the linter from the project root using:
+
+```
+./dev/linter.sh
+```
diff --git a/code/pytorchvideo/dev/linter.sh b/code/pytorchvideo/dev/linter.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eafbac0981be67410c21758eb8af8f72cf1214c5
--- /dev/null
+++ b/code/pytorchvideo/dev/linter.sh
@@ -0,0 +1,25 @@
+#!/bin/bash -ev
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+# Run this script at project root with "./dev/linter.sh" before you commit.
+
+echo "Running autoflake..."
+python -m autoflake --remove-all-unused-imports -i .
+
+echo "Running isort..."
+isort -y -sp .
+
+echo "Running black..."
+black .
+
+echo "Running flake8..."
+if [ -x "$(command -v flake8)" ]; then
+ flake8 .
+else
+ python3 -m flake8 .
+fi
+
+command -v arc > /dev/null && {
+ echo "Running arc lint ..."
+ arc lint
+}
diff --git a/code/pytorchvideo/docs/Makefile b/code/pytorchvideo/docs/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293
--- /dev/null
+++ b/code/pytorchvideo/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/code/pytorchvideo/docs/README.md b/code/pytorchvideo/docs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2fb037e2d83ef4e8e54414d44f4c448b4551bf0e
--- /dev/null
+++ b/code/pytorchvideo/docs/README.md
@@ -0,0 +1,65 @@
+
+## Setup
+
+### Install dependencies
+
+```
+pip install -U recommonmark mock sphinx sphinx_rtd_theme sphinx_markdown_tables
+```
+
+### Add symlink to the root README.md
+
+We want to include the root readme as an overview. Before generating the docs create a symlink to the root readme.
+
+```
+cd /docs
+ln -s ../README.md overview.md
+```
+
+For deployment, this is done in `conf.py` using `subprocess.call`.
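+
+A minimal sketch of what that call can look like (the exact path handling in `conf.py` may differ):
+
+```
+import subprocess
+
+# recreate the overview.md symlink at build time so Sphinx can include the root README
+subprocess.call(["ln", "-sf", "../README.md", "overview.md"])
+```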
+
+### Add a new file
+
+Add a new `.md` or `.rst` file and add the name to the doc tree in `index.rst` e.g
+
+```
+.. toctree::
+ :maxdepth: 1
+ :caption: Intro Documentation
+
+ overview
+```
+
+### Build
+
+From `pytorchvideo/docs` run:
+
+```
+> make html
+```
+
+The website is generated in `build/html`.
+
+### Common Issues
+
+Sphinx can be fussy, and sometimes about things you weren’t expecting. For example, you might encounter something like:
+
+WARNING: toctree contains reference to nonexisting document u'overview'
+...
+checking consistency...
+/docs/overview.rst::
+WARNING: document isn't included in any toctree
+
+You might have indented `overview` in the `.. toctree::` in `index.rst` with four spaces, when Sphinx is expecting three.
+
+
+### View
+
+Start a python simple server:
+
+```
+> python -m http.server
+```
+
+Navigate to: `http://0.0.0.0:8000/`
+
diff --git a/code/pytorchvideo/docs/make.bat b/code/pytorchvideo/docs/make.bat
new file mode 100644
index 0000000000000000000000000000000000000000..9534b018135ed7d5caed6298980c55e8b1d2ec82
--- /dev/null
+++ b/code/pytorchvideo/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/code/pytorchvideo/docs/requirements.txt b/code/pytorchvideo/docs/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..95e898e22933cf92dc195573361e860e2a5e9074
--- /dev/null
+++ b/code/pytorchvideo/docs/requirements.txt
@@ -0,0 +1,15 @@
+docutils==0.16
+# https://github.com/sphinx-doc/sphinx/commit/7acd3ada3f38076af7b2b5c9f3b60bb9c2587a3d
+sphinx==3.2.0
+recommonmark==0.6.0
+sphinx_markdown_tables
+mock
+numpy
+av
+torch
+torchvision
+opencv-python
+parameterized
+git+https://github.com/facebookresearch/fvcore.git
+git+https://github.com/facebookresearch/iopath.git
+git+https://github.com/kalyanvasudev/pytorch_sphinx_theme.git
diff --git a/code/pytorchvideo/docs/source/accelerator.md b/code/pytorchvideo/docs/source/accelerator.md
new file mode 100644
index 0000000000000000000000000000000000000000..8f948ac646b860194184e5ea3d17d9ad284059bc
--- /dev/null
+++ b/code/pytorchvideo/docs/source/accelerator.md
@@ -0,0 +1,60 @@
+
+# Overview
+
+Our vision for PyTorchVideo/Accelerator is to enable video understanding models to run efficiently on all tiers of hardware devices, from mobile phones to GPUs. PyTorchVideo/Accelerator (Accelerator) aims to accelerate video understanding models on various hardware devices, as well as the whole process of designing and deploying hardware-aware, efficient video understanding models. Specifically, Accelerator provides a complete environment which allows users to:
+
+* Design efficient models for target hardware with carefully tuned efficient blocks;
+* Fine-tune efficient models from the Model Zoo;
+* Optimize model kernel and graph for target device;
+* Deploy efficient model to target device.
+
+
+We benchmarked the latency of SOTA models ([X3D-XS and X3D-S](https://arxiv.org/abs/2004.04730)) on a mainstream mobile device (Samsung S9 International, released in 2018). With Accelerator, we not only observed a 4-6X latency reduction in fp32, but also enabled int8 operation, which is not supported in vanilla PyTorch. A table summarizing the latency comparison is shown below.
+
+|model |implementation |precision |latency per 1-s clip (ms) |speed up |
+|--- |------------------------- |--- |--- |--- |
+|X3D-XS |Vanilla PyTorch |fp32 |1067 |1.0X |
+|X3D-XS |PyTorchVideo/Accelerator |fp32 |233 |4.6X |
+|X3D-XS |PyTorchVideo/Accelerator |int8 |165 |6.5X |
+|X3D-S |Vanilla PyTorch |fp32 |4248 |1.0X |
+|X3D-S |PyTorchVideo/Accelerator |fp32 |763 |5.6X |
+|X3D-S |PyTorchVideo/Accelerator |int8 |503 |8.4X |
+
+## Components in PyTorchVideo/Accelerator
+
+### Efficient block library
+
+The efficient block library contains common building blocks (residual block, squeeze-excite, etc.) that can be mapped to the high-performance kernel operator library of the target device platform. The rationale behind having an efficient block library is that a high-performance kernel operator library generally supports only a small set of kernel operators; a randomly picked kernel might not be supported. Building a model from the efficient blocks in this library therefore guarantees that the model can be deployed with high efficiency on the target device.
+
+The efficient block library lives under `pytorchvideo/layers/accelerator/` (for simple layers) and `pytorchvideo/models/accelerator/` (for complex modules such as residual blocks). Please also check the [Build your model with PyTorchVideo/Accelerator](https://pytorchvideo.org/docs/tutorial_accelerator_build_your_model) tutorial for detailed examples.
+
+### Deployment
+
+The deployment flow includes kernel optimization as well as model export for the target backend. Kernel optimization can be an extremely important step that determines the performance of on-device model execution. Accelerator provides a set of useful deployment utilities under `pytorchvideo/accelerator/deployment`. Please also check the related tutorials ([Build your model with PyTorchVideo/Accelerator](https://pytorchvideo.org/docs/tutorial_accelerator_build_your_model), [Accelerate your model with model transmuter in PyTorchVideo/Accelerator](https://pytorchvideo.org/docs/tutorial_accelerator_use_model_transmuter)) for detailed examples.
+
+### Model zoo
+
+Accelerator provides an efficient model zoo for target devices, which includes model builders (under `pytorchvideo/models/accelerator/`) as well as pretrained checkpoints. Please also refer to [Use PyTorchVideo/Accelerator Model Zoo](https://pytorchvideo.org/docs/tutorial_accelerator_use_accelerator_model_zoo) for how to use the model zoo.
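+
+As a minimal, hedged sketch (the builder path `pytorchvideo.models.accelerator.mobile_cpu.efficient_x3d.EfficientX3d` and its `expansion` argument are taken from the Accelerator model zoo table in `model_zoo.md`; the clip resolution below is an illustrative assumption):
+
+```python
+import torch
+
+from pytorchvideo.models.accelerator.mobile_cpu.efficient_x3d import EfficientX3d
+
+# Build an efficient X3D-XS model for mobile CPU; weights are randomly initialized here,
+# pretrained checkpoints are linked from the model zoo page.
+model = EfficientX3d(expansion="XS")
+model.eval()
+
+# Run a dummy clip through the model (4 frames at 160x160 is assumed for X3D-XS).
+clip = torch.rand(1, 3, 4, 160, 160)
+logits = model(clip)
+```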
+
+
+## Supported devices
+
+Currently, mobile CPU (ARM-based CPUs on mobile phones) is supported. We will update this page once more target devices are supported.
+
+## Demo
+
+Check out our on-device video classification demos running on mobile phones!
+
+[Android demo](https://github.com/pytorch/android-demo-app/tree/master/TorchVideo)
+
+[iOS demo](https://github.com/pytorch/ios-demo-app/tree/master/TorchVideo)
+
+## Jumpstart
+
+Refer to the following tutorial pages to get started!
+
+[Build your model with PyTorchVideo/Accelerator](https://pytorchvideo.org/docs/tutorial_accelerator_build_your_model)
+
+[Use PyTorchVideo/Accelerator Model Zoo](https://pytorchvideo.org/docs/tutorial_accelerator_use_accelerator_model_zoo)
+
+[Accelerate your model with model transmuter in PyTorchVideo/Accelerator](https://pytorchvideo.org/docs/tutorial_accelerator_use_model_transmuter)
diff --git a/code/pytorchvideo/docs/source/api/data/data.rst b/code/pytorchvideo/docs/source/api/data/data.rst
new file mode 100644
index 0000000000000000000000000000000000000000..4c8784600cca74f64bd46681d186e3a616c4e562
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/data/data.rst
@@ -0,0 +1,8 @@
+pytorchvideo.data
+=================
+
+.. automodule:: pytorchvideo.data
+ :imported-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/code/pytorchvideo/docs/source/api/data/index.rst b/code/pytorchvideo/docs/source/api/data/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..7d5781ea587ad18c01302923271a952ea25c286d
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/data/index.rst
@@ -0,0 +1,7 @@
+Data API
+==================
+
+.. toctree::
+
+ data
+
diff --git a/code/pytorchvideo/docs/source/api/index.rst b/code/pytorchvideo/docs/source/api/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..ef7efd8f520c0fe6a52921c577550318dd1d7b24
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/index.rst
@@ -0,0 +1,9 @@
+API Documentation
+==================
+
+.. toctree::
+
+ models/index
+ data/index
+ layers/index
+ transforms/index
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/layers/index.rst b/code/pytorchvideo/docs/source/api/layers/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..31677003b05386c8c41d614fc93fedb546d1ac8e
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/layers/index.rst
@@ -0,0 +1,6 @@
+Layers API
+==================
+
+.. toctree::
+
+ layers
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/layers/layers.rst b/code/pytorchvideo/docs/source/api/layers/layers.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c988d0583f1ca8349f8e45356b76d10bfc0552f3
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/layers/layers.rst
@@ -0,0 +1,56 @@
+pytorchvideo.layers.batch_norm
+=================================
+
+
+.. automodule:: pytorchvideo.layers.batch_norm
+ :members:
+
+
+pytorchvideo.layers.convolutions
+=================================
+
+
+.. automodule:: pytorchvideo.layers.convolutions
+ :members:
+
+pytorchvideo.layers.fusion
+=================================
+
+
+.. automodule:: pytorchvideo.layers.fusion
+ :members:
+
+pytorchvideo.layers.mlp
+=================================
+
+
+.. automodule:: pytorchvideo.layers.mlp
+ :members:
+
+pytorchvideo.layers.nonlocal_net
+=================================
+
+
+.. automodule:: pytorchvideo.layers.nonlocal_net
+ :members:
+
+pytorchvideo.layers.positional_encoding
+========================================
+
+
+.. automodule:: pytorchvideo.layers.positional_encoding
+ :members:
+
+pytorchvideo.layers.swish
+=================================
+
+
+.. automodule:: pytorchvideo.layers.swish
+ :members:
+
+pytorchvideo.layers.squeeze_excitation
+========================================
+
+
+.. automodule:: pytorchvideo.layers.squeeze_excitation
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/byol.rst b/code/pytorchvideo/docs/source/api/models/byol.rst
new file mode 100644
index 0000000000000000000000000000000000000000..1337b5dbe26d01d16931b593582c6d628fa8896d
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/byol.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.byol
+=================================
+
+
+.. automodule:: pytorchvideo.models.byol
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/csn.rst b/code/pytorchvideo/docs/source/api/models/csn.rst
new file mode 100644
index 0000000000000000000000000000000000000000..a4880f292949296b67a1e90c167f116aa41a33b1
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/csn.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.csn
+=================================
+
+
+.. automodule:: pytorchvideo.models.csn
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/head.rst b/code/pytorchvideo/docs/source/api/models/head.rst
new file mode 100644
index 0000000000000000000000000000000000000000..46dafcf1000e6f76077d686b219c1063772c9d7b
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/head.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.head
+=================================
+
+
+.. automodule:: pytorchvideo.models.head
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/index.rst b/code/pytorchvideo/docs/source/api/models/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..14a758d0ef94cf46731c513c384982fa5c9df78f
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/index.rst
@@ -0,0 +1,17 @@
+Models API
+==================
+
+.. toctree::
+
+ resnet
+ net
+ head
+ stem
+ csn
+ x3d
+ slowfast
+ r2plus1d
+ simclr
+ byol
+ memory_bank
+ masked_multistream
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/masked_multistream.rst b/code/pytorchvideo/docs/source/api/models/masked_multistream.rst
new file mode 100644
index 0000000000000000000000000000000000000000..6a32afa17fe94364b78d0f53b18e64555b9416d1
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/masked_multistream.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.masked_multistream
+========================================
+
+
+.. automodule:: pytorchvideo.models.masked_multistream
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/memory_bank.rst b/code/pytorchvideo/docs/source/api/models/memory_bank.rst
new file mode 100644
index 0000000000000000000000000000000000000000..3334d47c67b4c1cbe821ddcb89699d63a91006ce
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/memory_bank.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.memory_bank
+=================================
+
+
+.. automodule:: pytorchvideo.models.memory_bank
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/net.rst b/code/pytorchvideo/docs/source/api/models/net.rst
new file mode 100644
index 0000000000000000000000000000000000000000..4ea90a7ccbbdb5b8db68eca43eda222998a9016a
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/net.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.net
+=================================
+
+
+.. automodule:: pytorchvideo.models.net
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/r2plus1d.rst b/code/pytorchvideo/docs/source/api/models/r2plus1d.rst
new file mode 100644
index 0000000000000000000000000000000000000000..377302f7f4a1d49e51e0664914e1dc4d5ef805e0
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/r2plus1d.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.r2plus1d
+=================================
+
+
+.. automodule:: pytorchvideo.models.r2plus1d
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/resnet.rst b/code/pytorchvideo/docs/source/api/models/resnet.rst
new file mode 100644
index 0000000000000000000000000000000000000000..0570a2187d0ba34d3225fd8877390576f7536aaf
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/resnet.rst
@@ -0,0 +1,7 @@
+pytorchvideo.models.resnet
+=================================
+
+Building blocks for ResNet and ResNet-like models.
+
+.. automodule:: pytorchvideo.models.resnet
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/simclr.rst b/code/pytorchvideo/docs/source/api/models/simclr.rst
new file mode 100644
index 0000000000000000000000000000000000000000..a34ff7ccfe52298b2df0d03ea60a56135d7254ac
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/simclr.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.simclr
+=================================
+
+
+.. automodule:: pytorchvideo.models.simclr
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/slowfast.rst b/code/pytorchvideo/docs/source/api/models/slowfast.rst
new file mode 100644
index 0000000000000000000000000000000000000000..1bed28adf31b20c463b139b1c5f90b39d207e4d9
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/slowfast.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.slowfast
+=================================
+
+
+.. automodule:: pytorchvideo.models.slowfast
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/stem.rst b/code/pytorchvideo/docs/source/api/models/stem.rst
new file mode 100644
index 0000000000000000000000000000000000000000..fbc17c7bbd78a68878885037fadf1ef734f9f24c
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/stem.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.stem
+=================================
+
+
+.. automodule:: pytorchvideo.models.stem
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/models/x3d.rst b/code/pytorchvideo/docs/source/api/models/x3d.rst
new file mode 100644
index 0000000000000000000000000000000000000000..fbe6814315ca35e4cb9224ba70e4ad6b76ac0152
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/models/x3d.rst
@@ -0,0 +1,6 @@
+pytorchvideo.models.x3d
+=================================
+
+
+.. automodule:: pytorchvideo.models.x3d
+ :members:
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/transforms/index.rst b/code/pytorchvideo/docs/source/api/transforms/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e009cef528bb3fa3e4178411dab86856a1b3cfd8
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/transforms/index.rst
@@ -0,0 +1,6 @@
+Transforms API
+==================
+
+.. toctree::
+
+ transforms
\ No newline at end of file
diff --git a/code/pytorchvideo/docs/source/api/transforms/transforms.rst b/code/pytorchvideo/docs/source/api/transforms/transforms.rst
new file mode 100644
index 0000000000000000000000000000000000000000..9ca47842b611eb8a1fa6851b7e90216b70c2c5b4
--- /dev/null
+++ b/code/pytorchvideo/docs/source/api/transforms/transforms.rst
@@ -0,0 +1,20 @@
+pytorchvideo.transforms
+==================================
+
+
+.. automodule:: pytorchvideo.transforms
+ :imported-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+pytorchvideo.transforms.functional
+==================================
+
+
+.. automodule:: pytorchvideo.transforms.functional
+ :imported-members:
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/code/pytorchvideo/docs/source/conf.py b/code/pytorchvideo/docs/source/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..692755736f7fe198ea54aad5970d78c3c041e6de
--- /dev/null
+++ b/code/pytorchvideo/docs/source/conf.py
@@ -0,0 +1,190 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+# flake8: noqa
+import os
+import sys
+
+import unittest.mock as mock
+
+# -- Project information -----------------------------------------------------
+import pytorch_sphinx_theme
+from recommonmark.parser import CommonMarkParser
+from recommonmark.transform import AutoStructify
+
+
+# -- Path setup --------------------------------------------------------------
+sys.path.insert(0, os.path.abspath("../"))
+sys.path.insert(0, os.path.abspath("../pytorchvideo"))
+sys.path.insert(0, os.path.abspath("../../"))
+
+
+# The full version, including alpha/beta/rc tags
+try:
+ import torch # noqa
+except ImportError:
+ for m in [
+ "torch",
+ "torchvision",
+ "torch.nn",
+ "torch.autograd",
+ "torch.autograd.function",
+ "torch.nn.modules",
+ "torch.nn.modules.utils",
+ "torch.utils",
+ "torch.utils.data",
+ "torchvision",
+ "torchvision.ops",
+ "torchvision.datasets",
+ "torchvision.datasets.folder",
+ "torch.utils.data.IterableDataset",
+ ]:
+ sys.modules[m] = mock.Mock(name=m)
+
+
+project = "PyTorchVideo"
+copyright = "2021, PyTorchVideo contributors"
+author = "PyTorchVideo contributors"
+
+
+# -- General configuration ---------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+needs_sphinx = "3.0"
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "recommonmark",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.todo",
+ "sphinx.ext.coverage",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.githubpages",
+ "sphinx.ext.doctest",
+ "sphinx.ext.ifconfig",
+ "sphinx_markdown_tables",
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# -- Configurations for plugins ------------
+napoleon_google_docstring = True
+napoleon_include_init_with_doc = True
+napoleon_include_special_with_doc = True
+napoleon_numpy_docstring = False
+napoleon_use_rtype = False
+autodoc_inherit_docstrings = False
+autodoc_member_order = "bysource"
+
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3.6", None),
+ "numpy": ("https://docs.scipy.org/doc/numpy/", None),
+ "torch": ("https://pytorch.org/docs/master/", None),
+}
+# -------------------------
+
+source_parsers = {".md": CommonMarkParser}
+
+# Add any paths that contain templates here, relative to this directory.
+# templates_path = ["_templates"]
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+source_suffix = [".rst", ".md"]
+
+# The master toctree document.
+master_doc = "index"
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+autodoc_typehints = "description"
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "build", "README.md"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = True
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "pytorch_sphinx_theme"
+html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+
+html_theme_options = {
+ "includehidden": False,
+ "canonical_url": "https://pytorchvideo.org/api/",
+ "pytorch_project": "docs",
+}
+
+html_baseurl = "/"
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+
+html_favicon = "../../website/website/static/img/favicon.png"
+
+# -- Options for HTMLHelp output ------------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = "pytorchvideodoc"
+
+# -- Options for manual page output ---------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [(master_doc, "pytorchvideo", "PyTorchVideo Documentation", [author], 1)]
+
+
+# -- Options for Texinfo output -------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+# dir menu entry, description, category)
+texinfo_documents = [
+ (
+ master_doc,
+ "PyTorchVideo",
+ "PyTorchVideo Documentation",
+ author,
+ "PyTorchVideo",
+ "One line description of project.",
+ "Miscellaneous",
+ )
+]
+
+github_doc_root = "https://github.com/facebookresearch/pytorchvideo/tree/main"
+
+
+def setup(app):
+ app.add_config_value(
+ "recommonmark_config",
+ {
+ "url_resolver": lambda url: github_doc_root + url,
+ "auto_toc_tree_section": "Contents",
+ },
+ True,
+ )
+ app.add_transform(AutoStructify)
diff --git a/code/pytorchvideo/docs/source/data.md b/code/pytorchvideo/docs/source/data.md
new file mode 100644
index 0000000000000000000000000000000000000000..038214fb657dceebb8f1fadcb0b3969427e456a0
--- /dev/null
+++ b/code/pytorchvideo/docs/source/data.md
@@ -0,0 +1,48 @@
+# Overview
+
+PyTorchVideo datasets are subclasses of either [```torch.utils.data.Dataset```](https://pytorch.org/docs/stable/data.html#map-style-datasets) or [```torch.utils.data.IterableDataset```](https://pytorch.org/docs/stable/data.html#iterable-style-datasets). As such, they can all be used with a [```torch.utils.data.DataLoader```](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), which can load multiple samples in parallel using [```torch.multiprocessing```](https://pytorch.org/docs/stable/multiprocessing.html) workers. For example:
+
+```python
+dataset = pytorchvideo.data.Kinetics(
+ data_path="path/to/kinetics_root/train.csv",
+ clip_sampler=pytorchvideo.data.make_clip_sampler("random", duration=2),
+)
+data_loader = torch.utils.data.DataLoader(dataset, batch_size=8)
+```
+
+## How do PyTorchVideo datasets work?
+
+Although there isn't a strict interface governing how PyTorchVideo datasets work, they all share a common design as follows:
+
+1. Each dataset starts by taking a list of video paths and labels in some form. For example, Kinetics can take a file with each row containing a video path and label, or a directory with a ```<label>/<video_name>.mp4```-like file structure. Each dataset documents the exact structure it expects for the given data path.
+
+2. At each iteration a video sampler is used to determine which video-label pair is going to be sampled from the list of videos from the previous point. For some datasets this is required to be a random sampler, others reuse the [```torch.utils.data.Sampler```](https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler) interface for more flexibility.
+
+3. A clip sampler is then used to determine which frames to sample from the selected video. For example, your application may want to sample 2-second clips at random from the selected video at each iteration. Some datasets, like Kinetics, make use of the [```pytorchvideo.data.clip_sampling```](https://pytorchvideo.readthedocs.io/en/latest/api/data/extra.html#pytorchvideo-data-clip-sampling) interface to provide flexibility in how these clips are defined. Other datasets simply require you to specify an enum for common clip sampling configurations.
+
+4. Depending on whether the underlying videos are stored as encoded videos (e.g. mp4) or frame videos (i.e. a folder of images containing each decoded frame), the video clip is then selectively read or decoded into the canonical video tensor with shape ```(C, T, H, W)``` and audio tensor with shape ```(S)```. We provide two options for decoding: PyAV or TorchVision, which can be chosen in the interface of the datasets that support encoded videos.
+
+5. The next step of a PyTorchVideo dataset is creating a clip dictionary containing the video modalities, label and metadata ready to be returned. An example clip dictionary might look like this:
+ ```
+ {
+     'video': <video_tensor>,     # Shape: (C, T, H, W)
+     'audio': <audio_tensor>,     # Shape: (S)
+     'label': <action_label>,     # Integer defining class annotation
+     'video_name': <video_path>,  # Video file path stem
+     'video_index': <video_id>,   # index of video used by sampler
+     'clip_index': <clip_id>      # index of clip sampled within video
+ }
+ ```
+   All datasets share the same canonical modality tensor shapes and dtypes, which align with the tensor types of other domain-specific libraries (e.g. TorchVision, TorchAudio).
+
+6. The final step before returning a clip is feeding it into a transform callable that can be defined for all PyTorchVideo datasets. This callable is used to allow custom data processing or augmentations to be applied before batch collation in the [```torch.utils.data.DataLoader```](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). PyTorchVideo provides common [```pytorchvideo.transforms```](https://pytorchvideo.readthedocs.io/en/latest/transforms.html) that are useful for this callable, but users can easily define their own too (see the sketch after this list).
+
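+A minimal, hedged sketch of such a transform callable (the dictionary keys follow the clip dictionary described above, and the `transform` argument mirrors the usage shown in the transforms documentation):
+
+```python
+import pytorchvideo.data
+
+
+def my_transform(clip):
+    # The callable receives the clip dictionary and returns a (possibly modified) dictionary.
+    # As an example, keep only the first 8 frames of the (C, T, H, W) video tensor.
+    clip["video"] = clip["video"][:, :8]
+    return clip
+
+
+dataset = pytorchvideo.data.Kinetics(
+    data_path="path/to/kinetics_root/train.csv",
+    clip_sampler=pytorchvideo.data.make_clip_sampler("random", 2),
+    transform=my_transform,
+)
+```
+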
+## Available datasets:
+
+* Charades
+* Domsev
+* EpicKitchen
+* HMDB51
+* Kinetics
+* SSV2
+* UCF101
diff --git a/code/pytorchvideo/docs/source/data_preparation.md b/code/pytorchvideo/docs/source/data_preparation.md
new file mode 100644
index 0000000000000000000000000000000000000000..d4756eba98b513fd6d396408a4a2db3128b959a6
--- /dev/null
+++ b/code/pytorchvideo/docs/source/data_preparation.md
@@ -0,0 +1,164 @@
+## Data Preparation
+
+### Kinetics
+
+For more information about the Kinetics dataset, please refer to the official [website](https://deepmind.com/research/open-source/kinetics). You can take the following steps to prepare the dataset:
+
+1. Download the videos via the official [scripts](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics).
+
+2. Preprocess the downloaded videos by resizing them so that the short edge is 256 pixels.
+
+3. Prepare the csv files for training, validation, and testing set as `train.csv`, `val.csv`, `test.csv`. The format of the csv file is:
+
+```
+path_to_video_1 label_1
+path_to_video_2 label_2
+path_to_video_3 label_3
+...
+path_to_video_N label_N
+```
+
+All the Kinetics models in the Model Zoo are trained and tested with the same data as [Non-local Network](https://github.com/facebookresearch/video-nonlocal-net/blob/main/DATASET.md) and [PySlowFast](https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/DATASET.md). For dataset specific issues, please reach out to the [dataset provider](https://deepmind.com/research/open-source/kinetics).
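+
+For step 3, a minimal, hedged sketch of generating such a csv (the `<label>/<video>.mp4` directory layout and the alphabetical label-id assignment are illustrative assumptions, not part of the official preparation scripts):
+
+```python
+import csv
+from pathlib import Path
+
+root = Path("path/to/kinetics_root/train")
+
+# Map each label directory to an integer id, assigned alphabetically.
+labels = sorted(d.name for d in root.iterdir() if d.is_dir())
+label_to_id = {name: i for i, name in enumerate(labels)}
+
+# Write "path_to_video label" rows, matching the space-separated format above.
+with open("train.csv", "w", newline="") as f:
+    writer = csv.writer(f, delimiter=" ")
+    for label in labels:
+        for video in sorted((root / label).glob("*.mp4")):
+            writer.writerow([str(video), label_to_id[label]])
+```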
+
+
+### Charades
+
+We follow [PySlowFast](https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/DATASET.md) to prepare the Charades dataset as follows:
+
+1. Download the Charades RGB frames from [official website](http://ai2-website.s3.amazonaws.com/data/Charades_v1_rgb.tar).
+
+2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/charades/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/charades/frame_lists/val.csv)).
+
+
+### Something-Something V2
+
+We follow [PySlowFast](https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/DATASET.md) to prepare the Something-Something V2 dataset as follows:
+
+1. Download the dataset and annotations from [official website](https://20bn.com/datasets/something-something).
+
+2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/val.csv)).
+
+3. Extract the frames from the downloaded videos at 30 FPS. We used ffmpeg-4.1.3 with the command:
+ ```
+ ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"
+ ```
+4. The extracted frames should be organized to be consistent with the paths in frame lists.
+
+
+### AVA (Actions V2.2)
+
+The AVA dataset can be downloaded from the [official site](https://research.google.com/ava/download.html#ava_actions_download).
+
+We followed the same [downloading and preprocessing procedure](https://github.com/facebookresearch/video-long-term-feature-banks/blob/main/DATASET.md) as [Long-Term Feature Banks for Detailed Video Understanding](https://arxiv.org/abs/1812.05038).
+
+You can follow these steps to download and preprocess the data:
+
+1. Download videos
+
+```
+DATA_DIR="../../data/ava/videos"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+ echo "${DATA_DIR} doesn't exist. Creating it.";
+ mkdir -p ${DATA_DIR}
+fi
+
+wget https://s3.amazonaws.com/ava-dataset/annotations/ava_file_names_trainval_v2.1.txt
+
+for line in $(cat ava_file_names_trainval_v2.1.txt)
+do
+ wget https://s3.amazonaws.com/ava-dataset/trainval/$line -P ${DATA_DIR}
+done
+```
+
+2. Cut each video from its 15th to 30th minute. AVA has valid annotations only in this range.
+
+```
+IN_DATA_DIR="../../data/ava/videos"
+OUT_DATA_DIR="../../data/ava/videos_15min"
+
+if [[ ! -d "${OUT_DATA_DIR}" ]]; then
+ echo "${OUT_DATA_DIR} doesn't exist. Creating it.";
+ mkdir -p ${OUT_DATA_DIR}
+fi
+
+for video in $(ls -A1 -U ${IN_DATA_DIR}/*)
+do
+ out_name="${OUT_DATA_DIR}/${video##*/}"
+ if [ ! -f "${out_name}" ]; then
+ ffmpeg -ss 900 -t 901 -i "${video}" "${out_name}"
+ fi
+done
+```
+
+3. Extract frames
+
+```
+IN_DATA_DIR="../../data/ava/videos_15min"
+OUT_DATA_DIR="../../data/ava/frames"
+
+if [[ ! -d "${OUT_DATA_DIR}" ]]; then
+ echo "${OUT_DATA_DIR} doesn't exist. Creating it.";
+ mkdir -p ${OUT_DATA_DIR}
+fi
+
+for video in $(ls -A1 -U ${IN_DATA_DIR}/*)
+do
+ video_name=${video##*/}
+
+ if [[ $video_name = *".webm" ]]; then
+ video_name=${video_name::-5}
+ else
+ video_name=${video_name::-4}
+ fi
+
+ out_video_dir=${OUT_DATA_DIR}/${video_name}/
+ mkdir -p "${out_video_dir}"
+
+ out_name="${out_video_dir}/${video_name}_%06d.jpg"
+
+ ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"
+done
+```
+
+4. Download annotations
+
+```
+DATA_DIR="../../data/ava/annotations"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+ echo "${DATA_DIR} doesn't exist. Creating it.";
+ mkdir -p ${DATA_DIR}
+fi
+
+wget https://research.google.com/ava/download/ava_v2.2.zip -P ${DATA_DIR}
+unzip -q ${DATA_DIR}/ava_v2.2.zip -d ${DATA_DIR}
+```
+
+5. Download "frame lists" ([train](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/frame_lists/val.csv)) and put them in
+the `frame_lists` folder (see the structure below).
+
+6. Download person boxes that are generated using a person detector trained on AVA ([train](https://dl.fbaipublicfiles.com/pytorchvideo/data/ava/ava_detection_test.csv), [val](https://dl.fbaipublicfiles.com/pytorchvideo/data/ava/ava_detection_val.csv), [test](https://dl.fbaipublicfiles.com/pytorchvideo/data/ava/ava_detection_test.csv)) and put them in the `annotations` folder (see the structure below), i.e. the annotations directory created in step 4.
+If you prefer to use your own person detector, please generate detection prediction files in the same format.
+
+After the steps above, the AVA dataset should be organized in the following structure:
+
+```
+ava
+|_ frames
+| |_ [video name 0]
+| | |_ [video name 0]_000001.jpg
+| | |_ [video name 0]_000002.jpg
+| | |_ ...
+| |_ [video name 1]
+| |_ [video name 1]_000001.jpg
+| |_ [video name 1]_000002.jpg
+| |_ ...
+|_ frame_lists
+| |_ train.csv
+| |_ val.csv
+|_ annotations
+ |_ [official AVA annotation files]
+ |_ ava_train_predicted_boxes.csv
+ |_ ava_val_predicted_boxes.csv
+```
diff --git a/code/pytorchvideo/docs/source/index.rst b/code/pytorchvideo/docs/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..adaac612f4184b86b1811751aa23949719903576
--- /dev/null
+++ b/code/pytorchvideo/docs/source/index.rst
@@ -0,0 +1,47 @@
+.. pytorchvideo documentation master file, created by
+ sphinx-quickstart on Tue Feb 23 17:19:36 2021.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+:github_url: https://github.com/facebookresearch/pytorchvideo/
+
+
+PyTorchVideo Documentation
+========================================
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Models
+
+ models
+ model_zoo
+ api/models/index
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Data
+
+ data
+ data_preparation
+ api/data/index
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Transforms
+
+ transforms
+ api/transforms/index
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Layers
+
+ layers
+ api/layers/index
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Accelerator
+
+ accelerator
+
diff --git a/code/pytorchvideo/docs/source/layers.md b/code/pytorchvideo/docs/source/layers.md
new file mode 100644
index 0000000000000000000000000000000000000000..884870bb25cd33fc12e18fc5fd23af61c879bf3f
--- /dev/null
+++ b/code/pytorchvideo/docs/source/layers.md
@@ -0,0 +1,55 @@
+# Overview
+
+
+PyTorchVideo is an open source video understanding library that provides up-to-date builders for state-of-the-art video understanding backbones, layers, heads, and losses addressing different tasks, including acoustic event detection, action recognition (video classification), action detection (video detection), multimodal understanding (acoustic-visual classification), and self-supervised learning.
+
+The layers subpackage contains definitions for the following layers and activations:
+
+
+* Layer
+ * [BatchNorm](https://arxiv.org/abs/1502.03167)
+ * [2+1 Conv](https://arxiv.org/abs/1711.11248)
+ * ConCat
+ * MLP
+ * [Nonlocal Net](https://arxiv.org/abs/1711.07971)
+ * Positional Encoding
+ * [Squeeze and Excitation](https://arxiv.org/abs/1709.01507)
+ * [Swish](https://arxiv.org/abs/1710.05941)
+
+## Build standard models
+
+PyTorchVideo provides default builders to construct state-of-the-art video understanding layers and activations.
+
+
+### Layers
+
+You can construct a layer with random weights by calling its constructor:
+
+```
+import pytorchvideo.layers as layers
+
+nonlocal_net = layers.create_nonlocal(dim_in=256, dim_inner=128)  # "nonlocal" is a reserved Python keyword, so use a different name
+swish = layers.Swish()
+conv_2plus1d = layers.create_conv_2plus1d(in_channels=256, out_channels=512)
+```
+
+You can verify whether you have built the model successfully by:
+
+```
+import torch
+
+import pytorchvideo.layers as layers
+
+nonlocal_net = layers.create_nonlocal(dim_in=256, dim_inner=128)
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = nonlocal_net(input_tensor)
+
+swish = layers.Swish()
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = swish(input_tensor)
+
+conv_2plus1d = layers.create_conv_2plus1d(in_channels=256, out_channels=512)
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = conv_2plus1d(input_tensor)
+```
diff --git a/code/pytorchvideo/docs/source/model_zoo.md b/code/pytorchvideo/docs/source/model_zoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..f1b79c650a27fd61a8b000fdfa4fb1fedd6348bf
--- /dev/null
+++ b/code/pytorchvideo/docs/source/model_zoo.md
@@ -0,0 +1,81 @@
+
+
+
+## Model Zoo and Benchmarks
+
+PyTorchVideo provides reference implementations of a large number of video understanding approaches. In this document, we also provide comprehensive benchmarks to evaluate the supported models on different datasets using a standard evaluation setup. All the models can be downloaded from the provided links.
+
+### Kinetics-400
+
+arch | depth | pretrain | frame length x sample rate | top 1 | top 5 | Flops (G) x views | Params (M) | Model
+-------- | ----- | -------- | -------------------------- | ----- | ----- | ----------------- | ---------- | --------------------------------------------------------------------------------------------------
+C2D | R50 | \- | 8x8 | 71.46 | 89.68 | 25.89 x 3 x 10 | 24.33 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/C2D\_8x8\_R50.pyth)
+I3D | R50 | \- | 8x8 | 73.27 | 90.70 | 37.53 x 3 x 10 | 28.04 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/I3D\_8x8\_R50.pyth)
+Slow | R50 | \- | 4x16 | 72.40 | 90.18 | 27.55 x 3 x 10 | 32.45 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/SLOW\_4x16\_R50.pyth)
+Slow | R50 | \- | 8x8 | 74.58 | 91.63 | 54.52 x 3 x 10 | 32.45 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/SLOW\_8x8\_R50.pyth)
+SlowFast | R50 | \- | 4x16 | 75.34 | 91.89 | 36.69 x 3 x 10 | 34.48 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/SLOWFAST\_4x16\_R50.pyth)
+SlowFast | R50 | \- | 8x8 | 76.94 | 92.69 | 65.71 x 3 x 10 | 34.57 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/SLOWFAST\_8x8\_R50.pyth)
+SlowFast | R101 | \- | 8x8 | 77.90 | 93.27 | 127.20 x 3 x 10 | 62.83 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/SLOWFAST\_8x8\_R101.pyth)
+SlowFast | R101 | \- | 16x8 | 78.70 | 93.61 | 215.61 x 3 x 10 | 53.77 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/SLOWFAST\_16x8\_R101_50_50.pyth)
+CSN | R101 | \- | 32x2 | 77.00 | 92.90 | 75.62 x 3 x 10 | 22.21 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/CSN\_32x2\_R101.pyth)
+R(2+1)D | R50 | \- | 16x4 | 76.01 | 92.23 | 76.45 x 3 x 10 | 28.11 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/R2PLUS1D\_16x4\_R50.pyth)
+X3D | XS | \- | 4x12 | 69.12 | 88.63 | 0.91 x 3 x 10 | 3.79 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/X3D\_XS.pyth)
+X3D | S | \- | 13x6 | 73.33 | 91.27 | 2.96 x 3 x 10 | 3.79 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/X3D\_S.pyth)
+X3D | M | \- | 16x5 | 75.94 | 92.72 | 6.72 x 3 x 10 | 3.79 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/X3D\_M.pyth)
+X3D | L | \- | 16x5 | 77.44 | 93.31 | 26.64 x 3 x 10 | 6.15 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/X3D\_L.pyth)
+MViT | B | \- | 16x4 | 78.85 | 93.85 | 70.80 x 1 x 5 | 36.61 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/MVIT\_B\_16x4.pyth)
+MViT | B | \- | 32x3 | 80.30 | 94.69 | 170.37 x 1 x 5 | 36.61 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/MVIT\_B\_32x3\_f294077834.pyth)
+
+### Something-Something V2
+
+| arch | depth | pretrain | frame length x sample rate | top 1 | top 5 | Flops (G) x views | Params (M) | Model |
+| -------- | ----- | ------------ | -------------------------- | ----- | ----- | ----------------- | ---------- | ----- |
+| Slow | R50 | Kinetics 400 | 8x8 | 60.04 | 85.19 | 55.10 x 3 x 1 | 31.96 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/ssv2/SLOW\_8x8\_R50.pyth) |
+| SlowFast | R50 | Kinetics 400 | 8x8 | 61.68 | 86.92 | 66.60 x 3 x 1 | 34.04 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/ssv2/SLOWFAST\_8x8\_R50.pyth) |
+
+
+### Charades
+
+| arch | depth | pretrain | frame length x sample rate | MAP | Flops (G) x views | Params (M) | Model |
+| -------- | ----- | ------------ | -------------------------- | ----- | ----------------- | ---------- | ----- |
+| Slow | R50 | Kinetics 400 | 8x8 | 34.72 | 55.10 x 3 x 10 | 31.96 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/charades/SLOW\_8x8\_R50.pyth) |
+| SlowFast | R50 | Kinetics 400 | 8x8 | 37.24 | 66.60 x 3 x 10 | 34.00 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/charades/SLOWFAST\_8x8\_R50.pyth) |
+
+
+### AVA (V2.2)
+
+| arch | depth | pretrain | frame length x sample rate | MAP | Params (M) | Model |
+| -------- | ----- | ------------ | -------------------------- | ----- | ---------- | ----- |
+| Slow | R50 | Kinetics 400 | 4x16 | 19.5 | 31.78 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/ava/SLOW\_4x16\_R50\_DETECTION.pyth) |
+| SlowFast | R50 | Kinetics 400 | 8x8 | 24.67 | 33.82 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/ava/SLOWFAST\_8x8\_R50\_DETECTION.pyth) |
+
+
+### Using PyTorchVideo model zoo
+We provide several different ways to use PyTorchVideo model zoo.
+* The models have been integrated into TorchHub, so they can be loaded with or without pre-trained weights (see the sketch after this list). Additionally, we provide a [tutorial](https://pytorchvideo.org/docs/tutorial_torchhub_inference) which goes over the steps needed to load models from TorchHub and perform inference.
+* PyTorchVideo models/datasets are also supported in PySlowFast. You can use [PySlowFast workflow](https://github.com/facebookresearch/SlowFast/) to train or test PyTorchVideo models/datasets.
+* You can also use [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) to build training/test pipeline for PyTorchVideo models and datasets. Please check this [tutorial](https://pytorchvideo.org/docs/tutorial_classification) for more information.
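+
+As a minimal, hedged sketch of the TorchHub route (the entry point name `slow_r50` comes from this repository's `hubconf.py`; the `pretrained` flag follows the usual TorchHub convention):
+
+```python
+import torch
+
+# Load a Slow R50 model pre-trained on Kinetics-400 via TorchHub.
+model = torch.hub.load("facebookresearch/pytorchvideo", "slow_r50", pretrained=True)
+model.eval()
+```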
+
+
+Notes:
+* The above benchmarks are conducted by [PySlowFast workflow](https://github.com/facebookresearch/SlowFast/) using PyTorchVideo datasets and models.
+* For more details on the data preparation, you can refer to [PyTorchVideo Data Preparation](data_preparation.md).
+* For the `Flops x views` column, we report the inference cost of a single "view" × the number of views (FLOPs × space_views × time_views). For example, we take 3 spatial crops for 10 temporal clips on Kinetics.
+
+
+
+### PyTorchVideo Accelerator Model Zoo
+The Accelerator model zoo provides a set of efficient models for target devices with pretrained checkpoints. To learn more about how to build models, load checkpoints, and deploy them, please refer to [Use PyTorchVideo/Accelerator Model Zoo](https://pytorchvideo.org/docs/tutorial_accelerator_use_accelerator_model_zoo).
+
+**Efficient Models for mobile CPU**
+All top1/top5 accuracies are measured with 10-clip evaluation. Latency is benchmarked on a Samsung S8 phone with a 1s input clip length.
+
+| model | model builder | top 1 | top 5 | latency (ms) | params (M) | checkpoint |
+|--------------|--------------------------------------------------------------------------|-------|-------|--------------|----------------|---------------------|
+| X3D_XS (fp32)| models. accelerator. mobile_cpu. efficient_x3d. EfficientX3d (expansion="XS") | 68.5 | 88.0 | 233 | 3.8 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/efficient_x3d_xs_original_form.pyth) |
+| X3D_XS (int8)| N/A (Use the TorchScript file in checkpoint link directly) | 66.9 | 87.2 | 165 | 3.8 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/x3d_xs_efficient_converted_qnnpack.pt) |
+| X3D_S (fp32) | models. accelerator. mobile_cpu. efficient_x3d. EfficientX3d (expansion="S") | 73.0 | 90.6 | 764 | 3.8 | [link](https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/efficient_x3d_s_original_form.pyth) |
+
+
+### TorchHub models
+We provide a large set of [TorchHub](https://pytorch.org/hub/) models for the above video models with pre-trained weights, so it's easy to construct the networks and load pre-trained weights. Please refer to [PyTorchVideo TorchHub models](https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/models/hub/README.md) for more details.
diff --git a/code/pytorchvideo/docs/source/models.md b/code/pytorchvideo/docs/source/models.md
new file mode 100644
index 0000000000000000000000000000000000000000..3905092fb07f50da2bf9259e6c2ef739bd0ba2b4
--- /dev/null
+++ b/code/pytorchvideo/docs/source/models.md
@@ -0,0 +1,181 @@
+# Overview
+
+
+PyTorchVideo is an open source video understanding library that provides up-to-date builders for state-of-the-art video understanding backbones, layers, heads, and losses addressing different tasks, including acoustic event detection, action recognition (video classification), action detection (video detection), multimodal understanding (acoustic-visual classification), and self-supervised learning.
+
+The models subpackage contains definitions for the following model architectures and layers:
+
+
+* Acoustic Backbone
+ * Acoustic ResNet
+* Visual Backbone
+ * [I3D](https://arxiv.org/pdf/1705.07750.pdf)
+ * [C2D](https://arxiv.org/pdf/1711.07971.pdf)
+ * [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/1709.01507.pdf)
+ * [Nonlocal Networks](https://arxiv.org/pdf/1711.07971.pdf)
+ * [R2+1D](https://openaccess.thecvf.com/content_cvpr_2018/papers/Tran_A_Closer_Look_CVPR_2018_paper.pdf)
+ * CSN
+ * [SlowFast](https://arxiv.org/pdf/1812.03982.pdf)
+ * [Audiovisual SlowFast](https://arxiv.org/pdf/2001.08740.pdf)
+ * [X3D](https://arxiv.org/pdf/2004.04730.pdf)
+* Self-Supervised Learning
+ * [SimCLR](https://arxiv.org/pdf/2002.05709.pdf)
+ * [Bootstrap Your Own Latent](https://arxiv.org/pdf/2006.07733.pdf)
+ * [Non-Parametric Instance Discrimination](https://openaccess.thecvf.com/content_cvpr_2018/CameraReady/0801.pdf)
+
+
+## Build standard models
+
+PyTorchVideo provides default builders to construct state-of-the-art video understanding models, layers, heads, and losses.
+
+### Models
+
+You can construct a model with random weights by calling its constructor:
+
+```
+import pytorchvideo.models as models
+
+resnet = models.create_resnet()
+acoustic_resnet = models.create_acoustic_resnet()
+slowfast = models.create_slowfast()
+x3d = models.create_x3d()
+r2plus1d = models.create_r2plus1d()
+csn = models.create_csn()
+```
+
+You can verify whether you have built the model successfully by:
+
+```
+import torch
+
+import pytorchvideo.models as models
+
+resnet = models.create_resnet()
+B, C, T, H, W = 2, 3, 8, 224, 224
+input_tensor = torch.zeros(B, C, T, H, W)
+output = resnet(input_tensor)
+```
+
+### Layers
+
+You can construct a layer with random weights by calling its constructor:
+
+```
+import pytorchvideo.layers as layers
+
+nonlocal_net = layers.create_nonlocal(dim_in=256, dim_inner=128)  # "nonlocal" is a reserved Python keyword, so use a different name
+swish = layers.Swish()
+conv_2plus1d = layers.create_conv_2plus1d(in_channels=256, out_channels=512)
+```
+
+You can verify whether you have built the model successfully by:
+
+```
+import torch
+
+import pytorchvideo.layers as layers
+
+nonlocal_net = layers.create_nonlocal(dim_in=256, dim_inner=128)
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = nonlocal_net(input_tensor)
+
+swish = layers.Swish()
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = swish(input_tensor)
+
+conv_2plus1d = layers.create_conv_2plus1d(in_channels=256, out_channels=512)
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = conv_2plus1d(input_tensor)
+```
+
+### Heads
+
+You can construct a head with random weights by calling its constructor:
+
+```
+import pytorchvideo.models as models
+
+res_head = models.head.create_res_basic_head(in_features=256, out_features=400)  # example dimensions
+x3d_head = models.x3d.create_x3d_head(dim_in=1024, dim_inner=512, dim_out=2048, num_classes=400)
+```
+
+You can verify whether you have built the head successfully by:
+
+```
+import torch
+
+import pytorchvideo.models as models
+
+res_head = models.head.create_res_basic_head(in_features=256, out_features=400)  # in_features matches C below
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = res_head(input_tensor)
+
+x3d_head = models.x3d.create_x3d_head(dim_in=1024, dim_inner=512, dim_out=2048, num_classes=400)
+B, C, T, H, W = 2, 256, 4, 14, 14
+input_tensor = torch.zeros(B, C, T, H, W)
+output = x3d_head(input_tensor)
+```
+
+### Losses
+
+You can construct a loss by calling its constructor:
+
+```
+import pytorchvideo.models as models
+
+simclr_loss = models.SimCLR()
+```
+
+You can verify whether you have built the loss successfully by:
+
+```
+import torch
+
+import pytorchvideo.models as models
+import pytorchvideo.layers as layers
+
+resnet = models.create_resnet()
+mlp = layers.make_multilayer_perceptron(fully_connected_dims=(2048, 1024, 2048))
+simclr_loss = models.SimCLR(mlp=mlp, backbone=resnet)
+B, C, T, H, W = 2, 256, 4, 14, 14
+view1, view2 = torch.zeros(B, C, T, H, W), torch.zeros(B, C, T, H, W)
+loss = simclr_loss(view1, view2)
+```
+
+## Build customized models
+
+PyTorchVideo also supports building models with customized components, which is an important feature for video understanding research. Here we take a standard stem model as an example, show how to build each ResNet component (head, backbone, stem) separately, and show how to use your customized components to replace the standard ones.
+
+
+```
+from pytorchvideo.models.stem import create_res_basic_stem
+
+
+# Create standard stem layer.
+stem = create_res_basic_stem(in_channels=3, out_channels=64)
+
+# Create customized stem layer with YourFancyNorm
+stem = create_res_basic_stem(
+ in_channels=3,
+ out_channels=64,
+ norm=YourFancyNorm, # GhostNorm for example
+)
+
+# Create customized stem layer with YourFancyConv
+stem = create_res_basic_stem(
+ in_channels=3,
+ out_channels=64,
+ conv=YourFancyConv, # OctConv for example
+)
+
+# Create customized stem layer with YourFancyAct
+stem = create_res_basic_stem(
+ in_channels=3,
+ out_channels=64,
+ activation=YourFancyAct, # Swish for example
+)
+
+# Create customized stem layer with YourFancyPool
+stem = create_res_basic_stem(
+ in_channels=3,
+ out_channels=64,
+ pool=YourFancyPool, # MinPool for example
+)
+
+```
diff --git a/code/pytorchvideo/docs/source/transforms.md b/code/pytorchvideo/docs/source/transforms.md
new file mode 100644
index 0000000000000000000000000000000000000000..e107e5c099b93edfdf40fc6e8af2faaf11c7597c
--- /dev/null
+++ b/code/pytorchvideo/docs/source/transforms.md
@@ -0,0 +1,33 @@
+# Overview
+
+The PyTorchVideo transforms package contains common video algorithms used for preprocessing and/or augmenting video data. The package also contains helper dictionary transforms that are useful for interoperability between PyTorchVideo [dataset's clip outputs](https://pytorchvideo.readthedocs.io/en/latest/data.html) and domain-specific transforms. For example, here is a standard transform pipeline for a video model that could be used with a PyTorchVideo dataset:
+
+```python
+transform = torchvision.transforms.Compose([
+ pytorchvideo.transforms.ApplyTransformToKey(
+ key="video",
+ transform=torchvision.transforms.Compose([
+ pytorchvideo.transforms.UniformTemporalSubsample(8),
+ pytorchvideo.transforms.Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
+ pytorchvideo.transforms.RandomShortSideScale(min_size=256, max_size=320),
+ torchvision.transforms.RandomCrop(244),
+ torchvision.transforms.RandomHorizontalFlip(p=0.5),
+    ]),
+ )
+])
+dataset = pytorchvideo.data.Kinetics(
+ data_path="path/to/kinetics_root/train.csv",
+ clip_sampler=pytorchvideo.data.make_clip_sampler("random", duration=2),
+ transform=transform
+)
+```
+
+Notice how the example also includes transforms from TorchVision? PyTorchVideo uses the same canonical tensor shape as TorchVision for video and TorchAudio for audio. This allows the frameworks to be used together freely.
+
+## Transform vs Functional interface
+
+The example above demonstrated the [```pytorchvideo.transforms```](https://pytorchvideo.readthedocs.io/en/latest/api/transforms/transforms.html) interface. These transforms are [```torch.nn.Module```](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) callable classes that can be strung together in a declarative way. PyTorchVideo also provides a [```pytorchvideo.transforms.functional```](https://pytorchvideo.readthedocs.io/en/latest/api/transforms/transforms.html#pytorchvideo-transforms-functional) interface, which contains the functions that the transform API uses. These allow more fine-grained control over the transformations and may be more suitable for use outside the dataset preprocessing use case.
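+
+A minimal, hedged sketch of the functional interface (assuming `uniform_temporal_subsample` and `short_side_scale` with the signatures used below):
+
+```python
+import torch
+from pytorchvideo.transforms.functional import (
+    short_side_scale,
+    uniform_temporal_subsample,
+)
+
+video = torch.rand(3, 16, 240, 320)                       # (C, T, H, W)
+video = uniform_temporal_subsample(video, num_samples=8)  # keep 8 evenly spaced frames
+video = short_side_scale(video, size=256)                 # scale so the short side is 256
+```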
+
+## Scriptable transforms
+
+All non-OpenCV transforms are TorchScriptable, as described in the [TorchVision docs](https://pytorch.org/vision/stable/transforms.html#scriptable-transforms). To script the transforms together, please use [```torch.nn.Sequential```](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) instead of [```torchvision.transforms.Compose```](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Compose).
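+
+A hedged sketch of what that might look like (it assumes the two transforms from the pipeline above are indeed scriptable, as stated):
+
+```python
+import torch
+import pytorchvideo.transforms
+
+transform = torch.nn.Sequential(
+    pytorchvideo.transforms.UniformTemporalSubsample(8),
+    pytorchvideo.transforms.RandomShortSideScale(min_size=256, max_size=320),
+)
+scripted_transform = torch.jit.script(transform)  # TorchScript-compiled transform module
+```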
diff --git a/code/pytorchvideo/hubconf.py b/code/pytorchvideo/hubconf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22fdb99e810ac34c05f185d259fff3ea31c3911
--- /dev/null
+++ b/code/pytorchvideo/hubconf.py
@@ -0,0 +1,24 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+dependencies = ["torch"]
+from pytorchvideo.models.hub import ( # noqa: F401, E402
+ c2d_r50,
+ csn_r101,
+ efficient_x3d_s,
+ efficient_x3d_xs,
+ i3d_r50,
+ mvit_base_16,
+ mvit_base_16x4,
+ mvit_base_32x3,
+ r2plus1d_r50,
+ slow_r50,
+ slow_r50_detection,
+ slowfast_16x8_r101_50_50,
+ slowfast_r101,
+ slowfast_r50,
+ slowfast_r50_detection,
+ x3d_l,
+ x3d_m,
+ x3d_s,
+ x3d_xs,
+)
diff --git a/code/pytorchvideo/projects/video_nerf/README.md b/code/pytorchvideo/projects/video_nerf/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..abfb2a3865ca59ce0338d679dfe8f3611d2b9555
--- /dev/null
+++ b/code/pytorchvideo/projects/video_nerf/README.md
@@ -0,0 +1,136 @@
+# Train a NeRF model with PyTorchVideo and PyTorch3D
+
+This project demonstrates how to use the video decoder from PyTorchVideo to load frames from a video of an object from the [Objectron dataset](https://github.com/google-research-datasets/Objectron), and use this to train a NeRF [1] model with [PyTorch3D](https://github.com/facebookresearch/pytorch3d). Instead of decoding and storing all the video frames as images, PyTorchVideo offers an easy alternative to load and access frames on the fly. For this project we will be using the [NeRF implementation from PyTorch3D](https://github.com/facebookresearch/pytorch3d/tree/main/projects/nerf).
+
+### Set up
+
+#### Installation
+
+Install PyTorch3D
+
+```bash
+# Create new conda environment
+conda create -n 3ddemo
+conda activate 3ddemo
+
+# Install PyTorch3D
+conda install -c pytorch pytorch=1.7.1 torchvision cudatoolkit=10.1
+conda install -c conda-forge -c fvcore -c iopath fvcore iopath
+conda install pytorch3d -c pytorch3d-nightly
+```
+
+Install PyTorchVideo if you haven't installed it already (assuming you have cloned the repo locally):
+
+```bash
+cd pytorchvideo
+python -m pip install -e .
+```
+
+Install some extra libraries needed for NeRF:
+
+```bash
+pip install visdom Pillow matplotlib tqdm plotly
+pip install hydra-core --upgrade
+```
+
+#### Set up NeRF Model
+
+We will be using the PyTorch3D NeRF implementation. We have already installed the PyTorch3D conda packages, so now we only need to clone the NeRF implementation:
+
+```bash
+cd pytorchvideo/tutorials/video_nerf
+git clone https://github.com/facebookresearch/pytorch3d.git
+cp -r pytorch3d/projects/nerf .
+
+# Remove the rest of the PyTorch3D repo
+rm -r pytorch3d
+```
+
+#### Dataset
+
+###### Download the Objectron repo
+
+The repo contains helper functions for reading the metadata files. Clone it to the path `pytorchvideo/tutorials/video_nerf/Objectron`.
+
+```bash
+git clone https://github.com/google-research-datasets/Objectron.git
+
+# Also install protobuf for parsing the metadata
+pip install protobuf
+```
+
+###### Download an example video
+
+For this demo we will be using a short video of a chair from the [Objectron dataset](https://github.com/google-research-datasets/Objectron). Each video is accompanied by metadata with the camera parameters for each frame. You can download an example video for a chair and the associated metadata by running the following script:
+
+```bash
+python download_objectron_data.py
+```
+
+The data files will be downloaded to the path: `pytorchvideo/tutorials/video_nerf/nerf/data/objectron`. Within the script, you can change the video index to obtain a different chair video. We will create and save a random train/val/test split when the video is first loaded by the NeRF model training script.
+
+Most of the videos are recorded in landscape mode with image size (H, W) = [1440, 1920].
+
+
+#### Set up new configs
+
+For this dataset we need a new config file and data loader to use it with the PyTorch3D NeRF implementation. Copy the relevant dataset and config files into the `nerf` folder and replace the original files:
+
+```bash
+# Make sure you are at the path: pytorchvideo/tutorials/video_nerf
+# Rename the current dataset file
+mv nerf/nerf/dataset.py nerf/nerf/nerf_dataset.py
+
+# Move the new objectron specific files into the nerf folder
+mv dataset.py nerf/nerf/dataset.py
+mv dataset_utils.py nerf/nerf/dataset_utils.py
+mv objectron.yaml nerf/configs
+```
+
+In the new `dataset.py` file we use the PyTorchVideo `EncodedVideo` class to load the `.MOV` video file, decode it into frames, and access the frames by index.
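+
+As an illustrative sketch (the path and clip length here are assumptions matching the demo setup), on-the-fly frame access looks roughly like this:
+
+```python
+from pytorchvideo.data.encoded_video import EncodedVideo
+
+video = EncodedVideo.from_path("data/objectron/video.MOV")
+
+# get_clip decodes only the requested time range; "video" is a (C, T, H, W) tensor.
+clip = video.get_clip(start_sec=0, end_sec=2.0)
+frames = clip["video"].permute(1, 2, 3, 0)  # (T, H, W, C)
+first_frame = frames[0]
+```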
+
+#### Train model
+
+Run the model training:
+
+```bash
+cd nerf
+python ./train_nerf.py --config-name objectron
+```
+
+#### Visualize predictions
+
+Predictions and metrics will be logged to Visdom. Before training starts, launch the Visdom server:
+
+```bash
+python -m visdom.server
+```
+
+Navigate to `http://localhost:8097` to view the logs and visualizations.
+
+After training, you can generate predictions on the test set:
+
+```bash
+python test_nerf.py --config-name objectron test.mode='export_video' data.image_size="[96,128]"
+```
+
+For a higher-resolution video you can increase the image size, e.g. to [192, 256] (note that this will slow down inference).
+
+You will need to specify the `scene_center` for the video in the `objectron.yaml` file. This is set for the demo video specified in `download_objectron_data.py`. For a different video you can calculate the scene center inside [`eval_video_utils.py`](https://github.com/facebookresearch/pytorch3d/blob/main/projects/nerf/nerf/eval_video_utils.py#L99). After line 99 you can add the following code to compute the center:
+
+```python
+# traj is the circular camera trajectory on the camera mean plane.
+# We want the camera to always point towards the center of this trajectory.
+x_center = traj[..., 0].mean().item()
+z_center = traj[..., 2].mean().item()
+y_center = traj[0, ..., 1]
+scene_center = [x_center, y_center, z_center]
+```
+You can also point the camera down/up relative to the camera mean plane, e.g. `y_center -= 0.5`.
+
+Here is an example of a video reconstruction generated using a trained NeRF model. NOTE: the quality of the reconstruction is highly dependent on the range and accuracy of the camera poses in the annotations - try training a model on a few different chairs in the dataset to see which one gives the best results.
+
+
+
+##### References
+[1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng, NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, ECCV2020
diff --git a/code/pytorchvideo/projects/video_nerf/dataset.py b/code/pytorchvideo/projects/video_nerf/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e687d2c6d77f9a4d503dcab25a6aa7c4a64d0348
--- /dev/null
+++ b/code/pytorchvideo/projects/video_nerf/dataset.py
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+import os
+from typing import Tuple
+
+import numpy as np
+import torch
+import tqdm
+
+# Imports from PyTorchVideo and PyTorch3D
+from pytorch3d.renderer import PerspectiveCameras
+from pytorchvideo.data.encoded_video import EncodedVideo
+from torch.utils.data import Dataset
+
+from .dataset_utils import (
+ generate_splits,
+ get_geometry_data,
+ objectron_to_pytorch3d,
+ resize_images,
+)
+from .nerf_dataset import ListDataset
+
+
+DEFAULT_DATA_ROOT = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "..", "data", "objectron"
+)
+
+
+def trivial_collate(batch):
+ """
+ A trivial collate function that merely returns the uncollated batch.
+ """
+ return batch
+
+
+def get_nerf_datasets(
+ dataset_name: str,
+ image_size: Tuple[int, int],
+ data_root: str = DEFAULT_DATA_ROOT,
+ **kwargs,
+) -> Tuple[Dataset, Dataset, Dataset]:
+ """
+ Obtains the training and validation dataset object for a dataset specified
+ with the `dataset_name` argument.
+
+ Args:
+ dataset_name: The name of the dataset to load.
+ image_size: A tuple (height, width) denoting the sizes of the loaded dataset images.
+ data_root: The root folder at which the data is stored.
+
+ Returns:
+ train_dataset: The training dataset object.
+ val_dataset: The validation dataset object.
+ test_dataset: The testing dataset object.
+ """
+ print(f"Loading dataset {dataset_name}, image size={str(image_size)} ...")
+
+ if dataset_name != "objectron":
+ raise ValueError("This data loader is only for the objectron dataset")
+
+ # Use the bundle adjusted camera parameters
+ sequence_geometry = get_geometry_data(os.path.join(data_root, "sfm_arframe.pbdata"))
+ num_frames = len(sequence_geometry)
+
+ # Check if splits are present else generate them on the first instance:
+ splits_path = os.path.join(data_root, "splits.pt")
+ if os.path.exists(splits_path):
+ print("Loading splits...")
+ splits = torch.load(splits_path)
+ train_idx, val_idx, test_idx = splits
+ else:
+ print("Generating splits...")
+ index_options = np.arange(num_frames)
+ train_idx, val_idx, test_idx = generate_splits(index_options)
+ torch.save([train_idx, val_idx, test_idx], splits_path)
+
+ print("Loading video...")
+ video_path = os.path.join(data_root, "video.MOV")
+ # Load the video using the PyTorchVideo video class
+ video = EncodedVideo.from_path(video_path)
+ FPS = 30
+
+ print("Loading all images and cameras...")
+ # Load all the video frames
+ frame_data = video.get_clip(start_sec=0, end_sec=(num_frames - 1) * 1.0 / FPS)
+ frame_data = frame_data["video"].permute(1, 2, 3, 0)
+ images = resize_images(frame_data, image_size)
+ cameras = []
+
+ for frame_id in tqdm.tqdm(range(num_frames)):
+ I, P = sequence_geometry[frame_id]
+ R = P[0:3, 0:3]
+ T = P[0:3, 3]
+
+ # Convert conventions
+ R = R.transpose(0, 1)
+ R, T = objectron_to_pytorch3d(R, T)
+
+ # Get intrinsic parameters
+ fx = I[0, 0]
+ fy = I[1, 1]
+ px = I[0, 2]
+ py = I[1, 2]
+
+ # Initialize the Perspective Camera
+ scene_cam = PerspectiveCameras(
+ R=R[None, ...],
+ T=T[None, ...],
+ focal_length=((fx, fy),),
+ principal_point=((px, py),),
+ ).to("cpu")
+
+ cameras.append(scene_cam)
+
+ train_dataset, val_dataset, test_dataset = [
+ ListDataset(
+ [
+ {"image": images[i], "camera": cameras[i], "camera_idx": int(i)}
+ for i in idx
+ ]
+ )
+ for idx in [train_idx, val_idx, test_idx]
+ ]
+
+ return train_dataset, val_dataset, test_dataset
diff --git a/code/pytorchvideo/projects/video_nerf/dataset_utils.py b/code/pytorchvideo/projects/video_nerf/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d81737afaf15d93979c5485a7d9e2759c8ca2e70
--- /dev/null
+++ b/code/pytorchvideo/projects/video_nerf/dataset_utils.py
@@ -0,0 +1,117 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+import struct
+import sys
+from typing import List, Tuple
+
+import numpy as np
+import torch
+
+from PIL import Image
+from pytorch3d.transforms import Rotate, RotateAxisAngle, Translate
+
+
+# Imports from Objectron: make the locally cloned Objectron repo importable
+# before importing its protobuf schema.
+module_path = os.path.abspath(os.path.join("..."))
+if module_path not in sys.path:
+    sys.path.append("../Objectron")
+
+# The AR Metadata captured with each frame in the video
+from objectron.schema import (  # noqa: E402
+    a_r_capture_metadata_pb2 as ar_metadata_protocol,
+)
+
+
+def objectron_to_pytorch3d(
+ R: torch.Tensor, T: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Transforms the R and T matrices from the Objectron world coordinate
+ system to the PyTorch3d world system.
+ Objectron cameras live in +X right, +Y Up, +Z from screen to us.
+ Pytorch3d world is +X left, +Y up, +Z from us to screen.
+ """
+ rotation = Rotate(R=R)
+ conversion = RotateAxisAngle(axis="y", angle=180)
+ composed_transform = rotation.compose(conversion).get_matrix()
+ composed_R = composed_transform[0, 0:3, 0:3]
+
+ translation = Translate(x=T[None, ...])
+ t_matrix = translation.compose(conversion).get_matrix()
+ flipped_T = t_matrix[0, 3, :3]
+ return composed_R, flipped_T
+
+
+def generate_splits(
+ index_options: List[int], train_fraction: float = 0.8
+) -> List[List[int]]:
+ """
+ Get indices for train, val, test splits.
+ """
+ num_images = len(index_options)
+ np.random.shuffle(index_options)
+ train_index = int(train_fraction * num_images)
+ val_index = train_index + ((num_images - train_index) // 2)
+ train_indices = index_options[:train_index]
+ val_indices = index_options[train_index:val_index]
+ test_indices = index_options[val_index:]
+ split_indices = [train_indices, val_indices, test_indices]
+ return split_indices
+
+
+def get_geometry_data(geometry_filename: str) -> List[List[torch.Tensor]]:
+ """
+    Utility function for parsing metadata files from the Objectron GitHub repo:
+ https://github.com/google-research-datasets/Objectron/blob/master/notebooks/objectron-geometry-tutorial.ipynb # noqa: B950
+ """
+ sequence_geometry = []
+ with open(geometry_filename, "rb") as pb:
+ proto_buf = pb.read()
+
+    i = 0
+    while i < len(proto_buf):
+        # Read the first four bytes in little endian '<' unsigned integer 'I'
+        # format, indicating the length of the current message.
+        msg_len = struct.unpack("<I", proto_buf[i : i + 4])[0]
+        i += 4
+        message_buf = proto_buf[i : i + msg_len]
+        i += msg_len
+
+        # Parse the AR metadata for this frame and keep the 3x3 camera
+        # intrinsics and the 4x4 camera transform as float tensors.
+        frame_data = ar_metadata_protocol.ARFrame()
+        frame_data.ParseFromString(message_buf)
+        intrinsics = torch.from_numpy(
+            np.array(frame_data.camera.intrinsics, dtype=np.float32).reshape(3, 3)
+        )
+        transform = torch.from_numpy(
+            np.array(frame_data.camera.transform, dtype=np.float32).reshape(4, 4)
+        )
+        sequence_geometry.append([intrinsics, transform])
+
+    return sequence_geometry
+
+
+def resize_images(frames: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
+ """
+    Utility function to resize images.
+ """
+ _image_max_image_pixels = Image.MAX_IMAGE_PIXELS
+ Image.MAX_IMAGE_PIXELS = None # The dataset image is very large ...
+ images = torch.FloatTensor(frames) / 255.0
+ Image.MAX_IMAGE_PIXELS = _image_max_image_pixels
+
+ scale_factors = [s_new / s for s, s_new in zip(images.shape[1:3], image_size)]
+
+ if abs(scale_factors[0] - scale_factors[1]) > 1e-3:
+ raise ValueError(
+ "Non-isotropic scaling is not allowed. Consider changing the 'image_size' argument."
+ )
+ scale_factor = sum(scale_factors) * 0.5
+
+ if scale_factor != 1.0:
+ print(f"Rescaling dataset (factor={scale_factor})")
+ images = torch.nn.functional.interpolate(
+ images.permute(0, 3, 1, 2),
+ size=tuple(image_size),
+ mode="bilinear",
+ ).permute(0, 2, 3, 1)
+
+ return images
diff --git a/code/pytorchvideo/projects/video_nerf/download_objectron_data.py b/code/pytorchvideo/projects/video_nerf/download_objectron_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c201ba2ba928bc1e5fce5e4c6325e6495907d650
--- /dev/null
+++ b/code/pytorchvideo/projects/video_nerf/download_objectron_data.py
@@ -0,0 +1,58 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+
+import requests
+
+
+# URLs for downloading the Objectron dataset
+public_url = "https://storage.googleapis.com/objectron"
+blob_path = public_url + "/v1/index/chair_annotations_train"
+video_ids = requests.get(blob_path).text
+video_ids = video_ids.split("\n")
+
+DATA_PATH = "./nerf/data/objectron"
+
+os.makedirs(DATA_PATH, exist_ok=True)
+
+# Download a video of a chair.
+for i in range(3, 4):
+ video_filename = public_url + "/videos/" + video_ids[i] + "/video.MOV"
+ metadata_filename = public_url + "/videos/" + video_ids[i] + "/geometry.pbdata"
+ annotation_filename = public_url + "/annotations/" + video_ids[i] + ".pbdata"
+
+ # This file contains the bundle adjusted cameras
+ sfm_filename = public_url + "/videos/" + video_ids[i] + "/sfm_arframe.pbdata"
+
+ # video.content contains the video file.
+ video = requests.get(video_filename)
+ metadata = requests.get(metadata_filename)
+
+ # Please refer to Parse Annotation tutorial to see how to parse the annotation files.
+ annotation = requests.get(annotation_filename)
+
+ sfm = requests.get(sfm_filename)
+
+    video_path = os.path.join(DATA_PATH, "video.MOV")
+    print("Writing video to %s" % video_path)
+    with open(video_path, "wb") as f:
+        f.write(video.content)
+
+    geometry_path = os.path.join(DATA_PATH, "geometry.pbdata")
+    print("Writing geometry data to %s" % geometry_path)
+    with open(geometry_path, "wb") as f:
+        f.write(metadata.content)
+
+    annotation_path = os.path.join(DATA_PATH, "annotation.pbdata")
+    print("Writing annotation data to %s" % annotation_path)
+    with open(annotation_path, "wb") as f:
+        f.write(annotation.content)
+
+    sfm_arframe_path = os.path.join(DATA_PATH, "sfm_arframe.pbdata")
+    print("Writing bundle adjusted camera data to %s" % sfm_arframe_path)
+    with open(sfm_arframe_path, "wb") as f:
+        f.write(sfm.content)
diff --git a/code/pytorchvideo/projects/video_nerf/objectron.yaml b/code/pytorchvideo/projects/video_nerf/objectron.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3bd1cdad7c874be910f98ddfd037af1f47a9f22a
--- /dev/null
+++ b/code/pytorchvideo/projects/video_nerf/objectron.yaml
@@ -0,0 +1,45 @@
+seed: 3
+resume: True
+stats_print_interval: 10
+validation_epoch_interval: 5
+checkpoint_epoch_interval: 30
+checkpoint_path: 'checkpoints/objectron.pth'
+data:
+ dataset_name: 'objectron'
+ image_size: [1440, 1920] # [height, width]
+ precache_rays: True
+test:
+ mode: 'evaluation'
+ trajectory_type: 'circular'
+ up: [0.0, 1.0, 0.0]
+ scene_center: [-0.5365, -1.05, 7.6191]
+ n_frames: 50
+ fps: 1
+ trajectory_scale: 0.2
+optimizer:
+ max_epochs: 20000
+ lr: 0.0005
+ lr_scheduler_step_size: 5000
+ lr_scheduler_gamma: 0.1
+visualization:
+ history_size: 10
+ visdom: True
+ visdom_server: 'localhost'
+ visdom_port: 8097
+ visdom_env: 'objectron'
+raysampler:
+ n_pts_per_ray: 64
+ n_pts_per_ray_fine: 64
+ n_rays_per_image: 1024
+ min_depth: 0.1
+ max_depth: 100.0
+ stratified: True
+ stratified_test: False
+ chunk_size_test: 6000
+implicit_function:
+ n_harmonic_functions_xyz: 10
+ n_harmonic_functions_dir: 4
+ n_hidden_neurons_xyz: 256
+ n_hidden_neurons_dir: 128
+ density_noise_std: 0.0
+ n_layers_xyz: 8
diff --git a/code/pytorchvideo/pytorchvideo.egg-info/PKG-INFO b/code/pytorchvideo/pytorchvideo.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..b65f333e0da43ab520d7cecc6017ce4c9a1a746d
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo.egg-info/PKG-INFO
@@ -0,0 +1,12 @@
+Metadata-Version: 2.1
+Name: pytorchvideo
+Version: 0.1.5
+Summary: A video understanding deep learning library.
+Home-page: https://github.com/facebookresearch/pytorchvideo
+Author: Facebook AI
+License: Apache 2.0
+Requires-Python: >=3.7
+Provides-Extra: test
+Provides-Extra: dev
+Provides-Extra: opencv-python
+License-File: LICENSE
diff --git a/code/pytorchvideo/pytorchvideo.egg-info/SOURCES.txt b/code/pytorchvideo/pytorchvideo.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c5ce4832f23e8e267af1699927afbb576944081d
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo.egg-info/SOURCES.txt
@@ -0,0 +1,162 @@
+CONTRIBUTING.md
+LICENSE
+MANIFEST.in
+README.md
+setup.cfg
+setup.py
+pytorchvideo/__init__.py
+pytorchvideo.egg-info/PKG-INFO
+pytorchvideo.egg-info/SOURCES.txt
+pytorchvideo.egg-info/dependency_links.txt
+pytorchvideo.egg-info/requires.txt
+pytorchvideo.egg-info/top_level.txt
+pytorchvideo/accelerator/__init__.py
+pytorchvideo/accelerator/deployment/__init__.py
+pytorchvideo/accelerator/deployment/common/__init__.py
+pytorchvideo/accelerator/deployment/common/model_transmuter.py
+pytorchvideo/accelerator/deployment/mobile_cpu/__init__.py
+pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/__init__.py
+pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py
+pytorchvideo/accelerator/deployment/mobile_cpu/utils/__init__.py
+pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py
+pytorchvideo/accelerator/efficient_blocks/__init__.py
+pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py
+pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py
+pytorchvideo/data/__init__.py
+pytorchvideo/data/ava.py
+pytorchvideo/data/charades.py
+pytorchvideo/data/clip_sampling.py
+pytorchvideo/data/dataset_manifest_utils.py
+pytorchvideo/data/decoder.py
+pytorchvideo/data/domsev.py
+pytorchvideo/data/encoded_video.py
+pytorchvideo/data/encoded_video_decord.py
+pytorchvideo/data/encoded_video_pyav.py
+pytorchvideo/data/encoded_video_torchvision.py
+pytorchvideo/data/epic_kitchen_forecasting.py
+pytorchvideo/data/epic_kitchen_recognition.py
+pytorchvideo/data/frame_video.py
+pytorchvideo/data/hmdb51.py
+pytorchvideo/data/json_dataset.py
+pytorchvideo/data/kinetics.py
+pytorchvideo/data/labeled_video_dataset.py
+pytorchvideo/data/labeled_video_paths.py
+pytorchvideo/data/ssv2.py
+pytorchvideo/data/ucf101.py
+pytorchvideo/data/utils.py
+pytorchvideo/data/video.py
+pytorchvideo/data/ego4d/__init__.py
+pytorchvideo/data/ego4d/ego4d_dataset.py
+pytorchvideo/data/ego4d/utils.py
+pytorchvideo/data/epic_kitchen/__init__.py
+pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py
+pytorchvideo/data/epic_kitchen/utils.py
+pytorchvideo/layers/__init__.py
+pytorchvideo/layers/attention.py
+pytorchvideo/layers/attention_torchscript.py
+pytorchvideo/layers/batch_norm.py
+pytorchvideo/layers/convolutions.py
+pytorchvideo/layers/distributed.py
+pytorchvideo/layers/drop_path.py
+pytorchvideo/layers/fusion.py
+pytorchvideo/layers/mlp.py
+pytorchvideo/layers/nonlocal_net.py
+pytorchvideo/layers/positional_encoding.py
+pytorchvideo/layers/positional_encoding_torchscript.py
+pytorchvideo/layers/squeeze_excitation.py
+pytorchvideo/layers/swish.py
+pytorchvideo/layers/utils.py
+pytorchvideo/layers/accelerator/__init__.py
+pytorchvideo/layers/accelerator/mobile_cpu/__init__.py
+pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py
+pytorchvideo/layers/accelerator/mobile_cpu/attention.py
+pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py
+pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py
+pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py
+pytorchvideo/layers/accelerator/mobile_cpu/pool.py
+pytorchvideo/losses/__init__.py
+pytorchvideo/losses/soft_target_cross_entropy.py
+pytorchvideo/models/__init__.py
+pytorchvideo/models/audio_visual_slowfast.py
+pytorchvideo/models/byol.py
+pytorchvideo/models/csn.py
+pytorchvideo/models/head.py
+pytorchvideo/models/masked_multistream.py
+pytorchvideo/models/memory_bank.py
+pytorchvideo/models/net.py
+pytorchvideo/models/r2plus1d.py
+pytorchvideo/models/resnet.py
+pytorchvideo/models/simclr.py
+pytorchvideo/models/slowfast.py
+pytorchvideo/models/stem.py
+pytorchvideo/models/vision_transformers.py
+pytorchvideo/models/weight_init.py
+pytorchvideo/models/x3d.py
+pytorchvideo/models/accelerator/__init__.py
+pytorchvideo/models/accelerator/mobile_cpu/__init__.py
+pytorchvideo/models/accelerator/mobile_cpu/efficient_x3d.py
+pytorchvideo/models/accelerator/mobile_cpu/residual_blocks.py
+pytorchvideo/models/hub/__init__.py
+pytorchvideo/models/hub/csn.py
+pytorchvideo/models/hub/efficient_x3d_mobile_cpu.py
+pytorchvideo/models/hub/r2plus1d.py
+pytorchvideo/models/hub/resnet.py
+pytorchvideo/models/hub/slowfast.py
+pytorchvideo/models/hub/utils.py
+pytorchvideo/models/hub/vision_transformers.py
+pytorchvideo/models/hub/x3d.py
+pytorchvideo/transforms/__init__.py
+pytorchvideo/transforms/augmentations.py
+pytorchvideo/transforms/augmix.py
+pytorchvideo/transforms/functional.py
+pytorchvideo/transforms/mix.py
+pytorchvideo/transforms/rand_augment.py
+pytorchvideo/transforms/transforms.py
+pytorchvideo/transforms/transforms_factory.py
+tests/test_accelerator_deployment_mobile_cpu_model_conversion.py
+tests/test_accelerator_deployment_model_transmuter.py
+tests/test_accelerator_efficient_blocks_mobile_cpu_activation_attention.py
+tests/test_accelerator_efficient_blocks_mobile_cpu_conv3d.py
+tests/test_accelerator_efficient_blocks_mobile_cpu_head_layer.py
+tests/test_accelerator_efficient_blocks_mobile_cpu_residual_block.py
+tests/test_accelerator_models_efficient_x3d.py
+tests/test_data_ava_dataset.py
+tests/test_data_charades_dataset.py
+tests/test_data_dataset_manifest_utils.py
+tests/test_data_domsev_dataset.py
+tests/test_data_encoded_video.py
+tests/test_data_epic_kitchen_dataset.py
+tests/test_data_epic_kitchen_forecasting.py
+tests/test_data_epic_kitchen_recognition.py
+tests/test_data_epic_kitchen_utils.py
+tests/test_data_frame_video.py
+tests/test_data_json_dataset.py
+tests/test_data_labeled_video_dataset.py
+tests/test_data_ssv2_dataset.py
+tests/test_data_utils.py
+tests/test_fuse_bn.py
+tests/test_layers_attention.py
+tests/test_layers_convolutions.py
+tests/test_layers_drop_path.py
+tests/test_layers_fusion.py
+tests/test_layers_mlp.py
+tests/test_layers_nonlocal_net.py
+tests/test_layers_positional_encoding.py
+tests/test_layers_squeeze_excitation.py
+tests/test_losses_soft_target_cross_entropy.py
+tests/test_models_audio_visual_slowfast.py
+tests/test_models_byol.py
+tests/test_models_csn.py
+tests/test_models_head.py
+tests/test_models_hub_vision_transformers.py
+tests/test_models_masked_multistream.py
+tests/test_models_memory_bank.py
+tests/test_models_r2plus1d.py
+tests/test_models_resnet.py
+tests/test_models_slowfast.py
+tests/test_models_stem.py
+tests/test_models_vision_transformers.py
+tests/test_models_x3d.py
+tests/test_simclr.py
+tests/test_transforms.py
+tests/test_uniform_clip_sampler.py
\ No newline at end of file
diff --git a/code/pytorchvideo/pytorchvideo.egg-info/dependency_links.txt b/code/pytorchvideo/pytorchvideo.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/code/pytorchvideo/pytorchvideo.egg-info/requires.txt b/code/pytorchvideo/pytorchvideo.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7b42a71885c62345079daac67c4f225f1e8c4928
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo.egg-info/requires.txt
@@ -0,0 +1,28 @@
+fvcore
+av
+parameterized
+iopath
+networkx
+
+[dev]
+opencv-python
+decord
+black==20.8b1
+sphinx
+isort==4.3.21
+flake8==3.8.1
+flake8-bugbear
+flake8-comprehensions
+pre-commit
+nbconvert
+bs4
+autoflake==1.4
+
+[opencv-python]
+opencv-python
+
+[test]
+coverage
+pytest
+opencv-python
+decord
diff --git a/code/pytorchvideo/pytorchvideo.egg-info/top_level.txt b/code/pytorchvideo/pytorchvideo.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0b1daa503b36a6d4ec04f865849bc1738f7d12a9
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo.egg-info/top_level.txt
@@ -0,0 +1 @@
+pytorchvideo
diff --git a/code/pytorchvideo/pytorchvideo/__init__.py b/code/pytorchvideo/pytorchvideo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2b2f870710dc40c0e8f1910bcd0a089a8ecf018
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+__version__ = "0.1.5"
diff --git a/code/pytorchvideo/pytorchvideo/__pycache__/__init__.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35f3c4cb422de0e9bde61274435508cf01891cc3
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/__pycache__/__init__.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/__init__.py b/code/pytorchvideo/pytorchvideo/accelerator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/__init__.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/common/__init__.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/common/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/common/model_transmuter.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/common/model_transmuter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1593528b3212b1f4bb926d002e0e9c64c113b8e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/common/model_transmuter.py
@@ -0,0 +1,86 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+from typing import List
+
+import torch.nn as nn
+
+
+"""
+This file contains top-level transmuter to convert user input model (nn.Module) into
+an equivalent model composed of efficientBlocks for target device.
+Specifically, each target device has a transmuter list, which contains transmuter
+functions to convert module into equivalent efficientBlock. Each transmuter list is
+registered in EFFICIENT_BLOCK_TRANSMUTER_REGISTRY to be accessed by top-level transmuter.
+"""
+EFFICIENT_BLOCK_TRANSMUTER_REGISTRY = {}
+
+
+def _find_equivalent_efficient_module(
+ module_input: nn.Module,
+ efficient_block_transmuter_list: List,
+ module_name: str = "",
+):
+ """
+ Given module_input, search through efficient_block_registry to see whether the
+ module_input can be replaced with equivalent efficientBlock. Returns None if no
+ equivalent efficientBlock is found, else returns an instance of equivalent
+ efficientBlock.
+ Args:
+ module_input (nn.Module): module to be replaced by equivalent efficientBlock
+ efficient_block_transmuter_list (list): a transmuter list that contains transmuter
+ functions for available efficientBlocks
+ module_name (str): name of module_input in original model
+ """
+ eq_module_hit_list = []
+ for iter_func in efficient_block_transmuter_list:
+ eq_module = iter_func(module_input)
+ if eq_module is not None:
+ eq_module_hit_list.append(eq_module)
+ if len(eq_module_hit_list) > 0:
+ # Check for multiple matches.
+ if len(eq_module_hit_list) > 1:
+ logging.warning(f"{module_name} has multiple matches:")
+ for iter_match in eq_module_hit_list:
+ logging.warning(f"{iter_match.__class__.__name__} is a match.")
+ logging.warning(
+ f"Will use {eq_module_hit_list[0]} as it has highest priority."
+ )
+ return eq_module_hit_list[0]
+ return None
+
+
+def transmute_model(
+ model: nn.Module,
+ target_device: str = "mobile_cpu",
+ prefix: str = "",
+):
+ """
+ Recursively goes through user input model and replace module in place with available
+ equivalent efficientBlock for target device.
+ Args:
+ model (nn.Module): user input model to be transmuted
+ target_device (str): name of target device, used to access transmuter list in
+ EFFICIENT_BLOCK_TRANSMUTER_REGISTRY
+ prefix (str): name of current hierarchy in user model
+ """
+ assert (
+ target_device in EFFICIENT_BLOCK_TRANSMUTER_REGISTRY
+ ), f"{target_device} not registered in EFFICIENT_BLOCK_TRANSMUTER_REGISTRY!"
+ transmuter_list = EFFICIENT_BLOCK_TRANSMUTER_REGISTRY[target_device]
+ for name, child in model.named_children():
+ equivalent_module = _find_equivalent_efficient_module(
+ child, transmuter_list, module_name=f"{prefix}.{name}"
+ )
+ if equivalent_module is not None:
+ model._modules[name] = equivalent_module
+ logging.info(
+ f"Replacing {prefix}.{name} ({child.__class__.__name__}) with "
+ f"{equivalent_module.__class__.__name__}"
+ )
+ else:
+ transmute_model(
+ child,
+ target_device=target_device,
+ prefix=f"{prefix}.{name}",
+ )
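+
+
+# Example (illustrative): importing the mobile_cpu transmuter package registers its
+# transmuter list, after which a model can be transmuted in place:
+#   from pytorchvideo.accelerator.deployment.mobile_cpu import transmuter  # noqa: F401
+#   transmute_model(model, target_device="mobile_cpu")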
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/__init__.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/__init__.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c573dc2f395c8a20625ee32398c503aefe69d02
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/__init__.py
@@ -0,0 +1,10 @@
+from pytorchvideo.accelerator.deployment.common.model_transmuter import (
+ EFFICIENT_BLOCK_TRANSMUTER_REGISTRY,
+)
+
+from .transmuter_mobile_cpu import EFFICIENT_BLOCK_TRANSMUTER_MOBILE_CPU
+
+
+EFFICIENT_BLOCK_TRANSMUTER_REGISTRY[
+ "mobile_cpu"
+] = EFFICIENT_BLOCK_TRANSMUTER_MOBILE_CPU
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfaee8a8c7dc141282b64aad7c59039c47484eec
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py
@@ -0,0 +1,204 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import torch.nn as nn
+from pytorchvideo.layers.accelerator.mobile_cpu.convolutions import (
+ Conv3d3x1x1BnAct,
+ Conv3d3x3x3DwBnAct,
+ Conv3d5x1x1BnAct,
+ Conv3dPwBnAct,
+ Conv3dTemporalKernel1BnAct,
+)
+
+
+def transmute_Conv3dPwBnAct(input_module: nn.Module):
+ """
+    Given an input_module, transmutes it into an equivalent Conv3dPwBnAct. Returns None
+ if no equivalent Conv3dPwBnAct is found, else returns an instance of equivalent
+ Conv3dPwBnAct.
+ Args:
+ input_module (nn.Module): input module to find an equivalent Conv3dPwBnAct
+ """
+ if not isinstance(input_module, nn.Conv3d):
+ return None
+ if (
+ input_module.kernel_size == (1, 1, 1)
+ and input_module.groups == 1
+ and input_module.stride == (1, 1, 1)
+ and input_module.padding == (0, 0, 0)
+ and input_module.dilation == (1, 1, 1)
+ ):
+ module = Conv3dPwBnAct(
+ in_channels=input_module.in_channels,
+ out_channels=input_module.out_channels,
+ bias=False if input_module.bias is None else True,
+ activation="identity",
+ use_bn=False,
+ )
+ module.kernel.conv.load_state_dict(input_module.state_dict())
+ return module
+ else:
+ return None
+
+
+def transmute_Conv3d3x3x3DwBnAct(input_module: nn.Module):
+ """
+    Given an input_module, transmutes it into an equivalent Conv3d3x3x3DwBnAct. Returns
+ None if no equivalent Conv3d3x3x3DwBnAct is found, else returns an instance of
+ equivalent Conv3d3x3x3DwBnAct.
+ Args:
+ input_module (nn.Module): input module to find an equivalent Conv3d3x3x3DwBnAct
+ """
+ if not isinstance(input_module, nn.Conv3d):
+ return None
+ if (
+ input_module.kernel_size == (3, 3, 3)
+ and input_module.in_channels == input_module.out_channels
+ and input_module.groups == input_module.out_channels
+ and input_module.stride[0] == 1
+ and input_module.stride[1] == input_module.stride[2]
+ and input_module.padding == (1, 1, 1)
+ and input_module.padding_mode == "zeros"
+ and input_module.dilation == (1, 1, 1)
+ ):
+ spatial_stride = input_module.stride[1]
+ module = Conv3d3x3x3DwBnAct(
+ in_channels=input_module.in_channels,
+ spatial_stride=spatial_stride,
+ bias=False if input_module.bias is None else True,
+ activation="identity",
+ use_bn=False,
+ )
+ module.kernel.conv.load_state_dict(input_module.state_dict())
+ return module
+ else:
+ return None
+
+
+def transmute_Conv3dTemporalKernel1BnAct(input_module: nn.Module):
+ """
+    Given an input_module, transmutes it into an equivalent Conv3dTemporalKernel1BnAct.
+ Returns None if no equivalent Conv3dTemporalKernel1BnAct is found, else returns
+ an instance of equivalent Conv3dTemporalKernel1BnAct.
+ Args:
+ input_module (nn.Module): input module to find an equivalent Conv3dTemporalKernel1BnAct
+ """
+ if not isinstance(input_module, nn.Conv3d):
+ return None
+ """
+ If the input_module can be replaced by Conv3dPwBnAct, don't use
+ Conv3dTemporalKernel1BnAct.
+ """
+ if (
+ input_module.kernel_size == (1, 1, 1)
+ and input_module.groups == 1
+ and input_module.stride == (1, 1, 1)
+ and input_module.padding == (0, 0, 0)
+ and input_module.dilation == (1, 1, 1)
+ ):
+ return None
+
+ if (
+ input_module.kernel_size[0] == 1
+ and input_module.kernel_size[1] == input_module.kernel_size[2]
+ and input_module.stride[0] == 1
+ and input_module.stride[1] == input_module.stride[2]
+ and input_module.padding[0] == 0
+ and input_module.dilation[0] == 1
+ ):
+ spatial_stride = input_module.stride[1]
+ spatial_kernel = input_module.kernel_size[1]
+ spatial_padding = input_module.padding[1]
+ spatial_dilation = input_module.dilation[1]
+ module = Conv3dTemporalKernel1BnAct(
+ in_channels=input_module.in_channels,
+ out_channels=input_module.out_channels,
+ bias=False if input_module.bias is None else True,
+ groups=input_module.groups,
+ spatial_kernel=spatial_kernel,
+ spatial_stride=spatial_stride,
+ spatial_padding=spatial_padding,
+ spatial_dilation=spatial_dilation,
+ activation="identity",
+ use_bn=False,
+ )
+ module.kernel.conv.load_state_dict(input_module.state_dict())
+ return module
+ else:
+ return None
+
+
+def transmute_Conv3d3x1x1BnAct(input_module: nn.Module):
+ """
+    Given an input_module, transmutes it into an equivalent Conv3d3x1x1BnAct.
+ Returns None if no equivalent Conv3d3x1x1BnAct is found, else returns
+ an instance of equivalent Conv3d3x1x1BnAct.
+ Args:
+ input_module (nn.Module): input module to find an equivalent Conv3d3x1x1BnAct
+ """
+ if not isinstance(input_module, nn.Conv3d):
+ return None
+
+ if (
+ input_module.kernel_size == (3, 1, 1)
+ and input_module.stride == (1, 1, 1)
+ and input_module.padding == (1, 0, 0)
+ and input_module.dilation == (1, 1, 1)
+ and input_module.padding_mode == "zeros"
+ ):
+ module = Conv3d3x1x1BnAct(
+ in_channels=input_module.in_channels,
+ out_channels=input_module.out_channels,
+ bias=False if input_module.bias is None else True,
+ groups=input_module.groups,
+ activation="identity",
+ use_bn=False,
+ )
+ module.kernel.conv.load_state_dict(input_module.state_dict())
+ return module
+ else:
+ return None
+
+
+def transmute_Conv3d5x1x1BnAct(input_module: nn.Module):
+ """
+    Given an input_module, transmutes it into an equivalent Conv3d5x1x1BnAct.
+ Returns None if no equivalent Conv3d5x1x1BnAct is found, else returns
+ an instance of equivalent Conv3d5x1x1BnAct.
+ Args:
+ input_module (nn.Module): input module to find an equivalent Conv3d5x1x1BnAct
+ """
+ if not isinstance(input_module, nn.Conv3d):
+ return None
+
+ if (
+ input_module.kernel_size == (5, 1, 1)
+ and input_module.stride == (1, 1, 1)
+ and input_module.padding == (2, 0, 0)
+ and input_module.dilation == (1, 1, 1)
+ and input_module.padding_mode == "zeros"
+ ):
+ module = Conv3d5x1x1BnAct(
+ in_channels=input_module.in_channels,
+ out_channels=input_module.out_channels,
+ bias=False if input_module.bias is None else True,
+ groups=input_module.groups,
+ activation="identity",
+ use_bn=False,
+ )
+ module.kernel.conv.load_state_dict(input_module.state_dict())
+ return module
+ else:
+ return None
+
+
+"""
+List of efficient_block transmuters for mobile_cpu. If one module matches multiple
+transmuters, the first matched transmuter in list will be used.
+"""
+EFFICIENT_BLOCK_TRANSMUTER_MOBILE_CPU = [
+ transmute_Conv3dPwBnAct,
+ transmute_Conv3d3x3x3DwBnAct,
+ transmute_Conv3dTemporalKernel1BnAct,
+ transmute_Conv3d3x1x1BnAct,
+ transmute_Conv3d5x1x1BnAct,
+]
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/utils/__init__.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/utils/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b36f54e2430fba23f5979a45d4a99fe86b102a2
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from copy import deepcopy
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+
+
+def _add_input_tensor_size_lut_hook(
+ module: nn.Module,
+ input_tensor_size_lut: Dict,
+ hook_handle_list: List,
+ base_name: str = "",
+) -> None:
+ """
+ This helper function recursively goes through all modules in a network, registers
+ forward hook function to each module. The hook function records the input tensor
+ size in forward in input_tensor_size_lut[base_name].
+ Args:
+ module (nn.Module): input module to add hook recursively.
+ input_tensor_size_lut (dict): lut to record input tensor size for hook function.
+ hook_handle_list (list): a list to contain hook handles.
+ base_name (str): name for module input.
+ """
+
+ def hook_fn(_, _in, _out):
+ if isinstance(_in[0], torch.Tensor):
+ input_tensor_size_lut[base_name] = tuple(_in[0].size())
+ return
+
+ handle = module.register_forward_hook(hook_fn)
+ hook_handle_list.append(handle)
+ for name, child in module.named_children():
+ _add_input_tensor_size_lut_hook(
+ child,
+ input_tensor_size_lut,
+ hook_handle_list,
+ base_name=f"{base_name}.{name}",
+ )
+
+
+def _convert_module(
+ module: nn.Module,
+ input_tensor_size_lut: Dict,
+ base_name: str = "",
+ convert_for_quantize: bool = False,
+ native_conv3d_op_qnnpack: bool = False,
+) -> None:
+ """
+    This helper function recursively goes through sub-modules in a network. If the current
+    module is an efficient block (instance of EfficientBlockBase) with a convert() method,
+ its convert() method will be called, and the input tensor size (needed by efficient
+ blocks for mobile cpu) will be provided by matching module name in
+ input_tensor_size_lut.
+ Otherwise if the input module is a non efficient block, this function will try to go
+ through child modules of input module to look for any efficient block in lower
+ hierarchy.
+ Args:
+ module (nn.Module): input module for convert.
+ input_tensor_size_lut (dict): input tensor size look-up table.
+ base_name (str): module name for input module.
+ convert_for_quantize (bool): whether this module is intended to be quantized.
+ native_conv3d_op_qnnpack (bool): whether the QNNPACK version has native int8
+ Conv3d.
+ """
+ if isinstance(module, EfficientBlockBase):
+ module.convert(
+ input_tensor_size_lut[base_name],
+ convert_for_quantize=convert_for_quantize,
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ else:
+ for name, child in module.named_children():
+ _convert_module(
+ child,
+ input_tensor_size_lut,
+ base_name=f"{base_name}.{name}",
+ convert_for_quantize=convert_for_quantize,
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+
+
+def convert_to_deployable_form(
+ model: nn.Module,
+ input_tensor: torch.Tensor,
+ convert_for_quantize: bool = False,
+ native_conv3d_op_qnnpack: bool = False,
+) -> nn.Module:
+ """
+ This function takes an input model, and returns a deployable model copy.
+ Args:
+ model (nn.Module): input model for conversion. The model can include a mix of
+ efficient blocks (instances of EfficientBlockBase) and non efficient blocks.
+ The efficient blocks will be converted by calling its convert() method, while
+ other blocks will stay unchanged.
+ input_tensor (torch.Tensor): input tensor for model. Note current conversion for
+ deployable form in mobile cpu only works for single input tensor size (i.e.,
+ the future input tensor to converted model should have the same size as
+ input_tensor specified here).
+ convert_for_quantize (bool): whether this module is intended to be quantized.
+ native_conv3d_op_qnnpack (bool): whether the QNNPACK version has native int8
+ Conv3d.
+ """
+ input_tensor_size_lut = {}
+ hook_handle_list = []
+ _add_input_tensor_size_lut_hook(model, input_tensor_size_lut, hook_handle_list)
+ # Run forward to fill in input tensor lut.
+ model.eval()
+ model(input_tensor)
+ # Remove forward hooks.
+ for handle in hook_handle_list:
+ handle.remove()
+ model_converted = deepcopy(model)
+ model_converted.eval()
+ _convert_module(
+ model_converted,
+ input_tensor_size_lut,
+ convert_for_quantize=convert_for_quantize,
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ return model_converted
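+
+
+# Example (illustrative): convert a model containing efficient blocks for deployment.
+# The input size below is an assumption; it must match the intended deployment input.
+#   input_blob = torch.randn(1, 3, 4, 160, 160)
+#   deploy_model = convert_to_deployable_form(model, input_blob)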
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/__init__.py b/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py b/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1040218d67f9572aaf517434f6ef6b6ac62564d1
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py
@@ -0,0 +1,35 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from abc import abstractmethod
+
+import torch.nn as nn
+
+
+class EfficientBlockBase(nn.Module):
+ """
+ PyTorchVideo/accelerator provides a set of efficient blocks
+ that have optimal efficiency for each target hardware device.
+
+ Each efficient block has two forms:
+ - original form: this form is for training. When efficient block is instantiated,
+ it is in this original form.
+ - deployable form: this form is for deployment. Once the network is ready for
+ deploy, it can be converted into deployable form for efficient execution
+ on target hardware. One block is transformed into deployable form by calling
+ convert() method. By conversion to deployable form,
+      various optimizations (operator fusion, kernel optimization, etc.) are applied.
+
+ EfficientBlockBase is the base class for efficient blocks.
+ All efficient blocks should inherit this base class
+ and implement following methods:
+ - forward(): same as required by nn.Module
+ - convert(): called to convert block into deployable form
+ """
+
+ @abstractmethod
+ def convert(self):
+ pass
+
+ @abstractmethod
+ def forward(self):
+ pass
diff --git a/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py b/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..81ce0aa5716b2477da24ed2bc478079c9cd866fc
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py
@@ -0,0 +1,26 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import torch.nn as nn
+
+from .efficient_block_base import EfficientBlockBase
+
+
+class NoOpConvertBlock(EfficientBlockBase):
+ """
+    This class provides an interface with EfficientBlockBase for modules that do not
+    need conversion.
+    Args:
+        model (nn.Module): NoOpConvertBlock takes model as input and generates a wrapper
+        instance of EfficientBlockBase with the same functionality as model; no change is
+        applied when convert() is called.
+ """
+
+ def __init__(self, model: nn.Module):
+ super().__init__()
+ self.model = model
+
+ def convert(self, *args, **kwargs):
+ pass
+
+ def forward(self, x):
+ return self.model(x)
diff --git a/code/pytorchvideo/pytorchvideo/data/__init__.py b/code/pytorchvideo/pytorchvideo/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7316dc8b01ddd51b108c54220339bf5221fa5f0
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .ava import Ava # noqa
+from .charades import Charades # noqa
+from .clip_sampling import ( # noqa; noqa
+ ClipSampler,
+ make_clip_sampler,
+ RandomClipSampler,
+ UniformClipSampler,
+)
+from .domsev import DomsevFrameDataset, DomsevVideoDataset # noqa
+from .epic_kitchen_forecasting import EpicKitchenForecasting # noqa
+from .epic_kitchen_recognition import EpicKitchenRecognition # noqa
+from .hmdb51 import Hmdb51 # noqa
+from .kinetics import Kinetics # noqa
+from .labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset # noqa
+from .ssv2 import SSv2
+from .ucf101 import Ucf101 # noqa
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/__init__.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6aa5aa1026d947dc6858b376a5514213f906af5
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/ava.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/ava.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b7e6b5bd6f080a5459904a7ae1431f56681f630
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/ava.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/charades.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/charades.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98832f40b6627be40e83f80b0577fb6304c23b02
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/charades.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/clip_sampling.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/clip_sampling.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85ce69deefb30ba9ebe10f9fcb7db71c488e93a6
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/clip_sampling.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/dataset_manifest_utils.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/dataset_manifest_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f457a018714bb4a95b7fb3f3cfcce681c3824d1b
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/dataset_manifest_utils.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/decoder.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bdb2920463a92bd3093c71af65899de0be9b9d6
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/decoder.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/domsev.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/domsev.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dc32375eede9f63000f92b9ac522c2fa198939d
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/domsev.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/encoded_video.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/encoded_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88395c098864882657e0c3507bfe28f47548d21d
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/encoded_video.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/encoded_video_decord.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/encoded_video_decord.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f965031466e934a88ad5931821e1d6d8168a812
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/encoded_video_decord.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/epic_kitchen_forecasting.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/epic_kitchen_forecasting.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21a07d3782794d4f294658befb2805dc1f0fe66e
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/epic_kitchen_forecasting.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/epic_kitchen_recognition.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/epic_kitchen_recognition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27c8a6a461229f14eff6a8351e114303bf905ebd
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/epic_kitchen_recognition.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/frame_video.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/frame_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16d051e392b987a876d43f1211ec4cf148758123
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/frame_video.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/hmdb51.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/hmdb51.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40a6e745bb7f2539e961daffd9c97b20e4d57aab
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/hmdb51.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/kinetics.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/kinetics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa767341901041ce1577fddb0efe32a350fc69f1
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/kinetics.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/labeled_video_dataset.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/labeled_video_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..186058f1f4ff94abc6327a56c5e60361384d0eac
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/labeled_video_dataset.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/labeled_video_paths.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/labeled_video_paths.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5c62a8a229b969d9e6307683e25e2a186ebf68e
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/labeled_video_paths.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/ssv2.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/ssv2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b2c260c21e87a6253390c54a1d29226ed3dc194
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/ssv2.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/ucf101.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/ucf101.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b062e0a473ecbc61ac592fe0c00598bf28ddad0e
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/ucf101.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/utils.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2cffa33eebc56e538c0cd3121cb4e62e449f8fc
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/utils.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/__pycache__/video.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/__pycache__/video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94ef0f8ad414092f05d6f2a56c7904c6fdf56d29
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/__pycache__/video.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/ava.py b/code/pytorchvideo/pytorchvideo/data/ava.py
new file mode 100644
index 0000000000000000000000000000000000000000..aed7c5e6c748bf57d3a253ec07f0887f373f5888
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/ava.py
@@ -0,0 +1,375 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from typing import Any, Callable, Dict, Optional, Set, Tuple, Type
+
+import torch
+from iopath.common.file_io import g_pathmgr
+from pytorchvideo.data.clip_sampling import ClipInfo, ClipSampler
+from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
+
+
+class AvaLabeledVideoFramePaths:
+ """
+    Pre-processor for the AVA Actions dataset stored as image frames.
+ This class handles the parsing of all the necessary
+ csv files containing frame paths and frame labels.
+ """
+
+ # Range of valid annotated frames in Ava dataset
+ AVA_VALID_FRAMES = list(range(902, 1799))
+ FPS = 30
+ AVA_VIDEO_START_SEC = 900
+
+ @classmethod
+ def _aggregate_bboxes_labels(cls, inp: Dict):
+
+ # Needed for aggregating the bounding boxes
+ labels = inp["labels"]
+ extra_info = inp["extra_info"]
+ boxes = inp["boxes"]
+
+ labels_agg = []
+ extra_info_agg = []
+ boxes_agg = []
+ bb_dict = {}
+
+ for i in range(len(labels)):
+ box_label, box_extra_info = labels[i], extra_info[i]
+
+ bbox_key = "{:.2f},{:.2f},{:.2f},{:.2f}".format(
+ boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]
+ )
+
+ if bbox_key not in bb_dict:
+ bb_dict[bbox_key] = len(boxes_agg)
+ boxes_agg.append(boxes[i])
+ labels_agg.append([])
+ extra_info_agg.append([])
+
+ idx = bb_dict[bbox_key]
+ labels_agg[idx].append(box_label)
+ extra_info_agg[idx].append(box_extra_info)
+
+ return {
+ "labels": labels_agg,
+ "boxes": boxes_agg,
+ "extra_info": extra_info_agg,
+ }
+
+ @classmethod
+ def from_csv(
+ cls,
+ frame_paths_file: str,
+ frame_labels_file: str,
+ video_path_prefix: str,
+ label_map_file: Optional[str] = None,
+ ) -> AvaLabeledVideoFramePaths:
+ """
+ Args:
+            frame_labels_file (str): Path to the file containing labels
+                per key frame. Acceptable file formats are:
+                Type 1:
+
+                Type 2:
+
+            frame_paths_file (str): Path to a file containing relative paths
+                to all the frames in the video. Each line in the file is of the
+                form
+            video_path_prefix (str): Path to be prepended to each relative frame
+                path to get the global frame path.
+            label_map_file (str): Path to a .pbtxt containing class id's and class names.
+                If not set, label_map is not loaded and bbox labels are not pruned
+                based on allowable class_id's in label_map.
+        Returns:
+            A list of tuples of the form (video_frames directory, label dictionary).
+ """
+ if label_map_file is not None:
+ _, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map(
+ label_map_file
+ )
+ else:
+ allowed_class_ids = None
+
+ (
+ image_paths,
+ video_idx_to_name,
+ video_name_to_idx,
+ ) = AvaLabeledVideoFramePaths.load_image_lists(
+ frame_paths_file, video_path_prefix
+ )
+
+ video_frame_labels = AvaLabeledVideoFramePaths.load_and_parse_labels_csv(
+ frame_labels_file,
+ video_name_to_idx,
+ allowed_class_ids,
+ )
+
+ # Populate keyframes list
+ labeled_video_paths = []
+ for video_id in video_frame_labels.keys():
+ for frame_video_sec in video_frame_labels[video_id].keys():
+ labels = video_frame_labels[video_id][frame_video_sec]
+ if len(labels["labels"]) > 0:
+ labels = AvaLabeledVideoFramePaths._aggregate_bboxes_labels(labels)
+ labels["video_index"] = video_id
+ labels["clip_index"] = frame_video_sec
+ video_frames_dir = os.path.dirname(image_paths[video_id][0])
+ labeled_video_paths.append((video_frames_dir, labels))
+
+ return labeled_video_paths
+
+ @staticmethod
+ def load_and_parse_labels_csv(
+ frame_labels_file: str,
+ video_name_to_idx: dict,
+ allowed_class_ids: Optional[Set] = None,
+ ):
+ """
+ Parses AVA per frame labels .csv file.
+ Args:
+ frame_labels_file (str): Path to the file containing labels
+ per key frame. Each row is a comma separated list of the form
+ (video_id, frame_timestamp, bbox_x1, bbox_y1, bbox_x2, bbox_y2,
+ action_label, extra_info), where extra_info holds either a
+ detection iou or a person id depending on the csv format.
+ video_name_to_idx (dict): Dictionary mapping video names to indices.
+ allowed_class_ids (set): A set of integer unique class (bbox label)
+ id's that are allowed in the dataset. If not set, all class id's
+ are allowed in the bbox labels.
+ Returns:
+ (dict): A dictionary of dictionaries containing labels per keyframe
+ in each video. Here, the label for each keyframe is again a dict
+ of the form,
+ {
+ 'labels': a list of action labels for each bounding box
+ 'boxes': a list of bounding boxes
+ 'extra_info': a list of extra information containing either
+ detection iou's or person id's depending on the
+ csv format.
+ }
+ """
+ labels_dict = {}
+ with g_pathmgr.open(frame_labels_file, "r") as f:
+ for line in f:
+ row = line.strip().split(",")
+
+ video_name = row[0]
+ video_idx = video_name_to_idx[video_name]
+
+ frame_sec = float(row[1])
+ if (
+ frame_sec > AvaLabeledVideoFramePaths.AVA_VALID_FRAMES[-1]
+ or frame_sec < AvaLabeledVideoFramePaths.AVA_VALID_FRAMES[0]
+ ):
+ continue
+
+ # Since frame labels in video start from 0 not at 900 secs
+ frame_sec = frame_sec - AvaLabeledVideoFramePaths.AVA_VIDEO_START_SEC
+
+ # Box with format [x1, y1, x2, y2] with a range of [0, 1] as float.
+ bbox = list(map(float, row[2:6]))
+
+ # Label
+ label = -1 if row[6] == "" else int(row[6])
+ # Continue if the current label is not in allowed labels.
+ if (allowed_class_ids is not None) and (label not in allowed_class_ids):
+ continue
+
+ # Both id's and iou's are treated as float
+ extra_info = float(row[7])
+
+ if video_idx not in labels_dict:
+ labels_dict[video_idx] = {}
+
+ if frame_sec not in labels_dict[video_idx]:
+ labels_dict[video_idx][frame_sec] = defaultdict(list)
+
+ labels_dict[video_idx][frame_sec]["boxes"].append(bbox)
+ labels_dict[video_idx][frame_sec]["labels"].append(label)
+ labels_dict[video_idx][frame_sec]["extra_info"].append(extra_info)
+ return labels_dict
+
+ @staticmethod
+ def load_image_lists(frame_paths_file: str, video_path_prefix: str) -> Tuple:
+ """
+ Loading image paths from the corresponding file.
+ Args:
+ frame_paths_file (str): Path to a file containing relative paths
+ to all the frames in the video. Each line in the file is of the
+ form: original_vido_id video_id frame_id path labels
+ video_path_prefix (str): Path prepended to each relative
+ frame path to get the global frame path.
+ Returns:
+ (tuple): A tuple of the following,
+ image_paths_list: List of list containing absolute frame paths.
+ Wherein the outer list is per video and inner list is per
+ timestamp.
+ video_idx_to_name: A list mapping video index to video name
+ video_name_to_idx: A dictionary mapping video name to index
+ """
+
+ image_paths = []
+ video_name_to_idx = {}
+ video_idx_to_name = []
+
+ with g_pathmgr.open(frame_paths_file, "r") as f:
+ f.readline()
+ for line in f:
+ row = line.split()
+ # The format of each row should follow:
+ # original_vido_id video_id frame_id path labels.
+ assert len(row) == 5
+ video_name = row[0]
+
+ if video_name not in video_name_to_idx:
+ idx = len(video_name_to_idx)
+ video_name_to_idx[video_name] = idx
+ video_idx_to_name.append(video_name)
+ image_paths.append({})
+
+ data_key = video_name_to_idx[video_name]
+ frame_id = int(row[2])
+ image_paths[data_key][frame_id] = os.path.join(
+ video_path_prefix, row[3]
+ )
+
+ image_paths_list = []
+ for i in range(len(image_paths)):
+ image_paths_list.append([])
+ sorted_keys = sorted(image_paths[i])
+ for key in sorted_keys:
+ image_paths_list[i].append(image_paths[i][key])
+
+ return image_paths_list, video_idx_to_name, video_name_to_idx
+
+ @staticmethod
+ def read_label_map(label_map_file: str) -> Tuple:
+ """
+ Read label map and class ids.
+ Args:
+ label_map_file (str): Path to a .pbtxt containing class id's
+ and class names
+ Returns:
+ (tuple): A tuple of the following,
+ label_map (dict): A dictionary mapping class id to
+ the associated class names.
+ class_ids (set): A set of integer unique class id's
+ """
+ label_map = {}
+ class_ids = set()
+ name = ""
+ class_id = ""
+ with g_pathmgr.open(label_map_file, "r") as f:
+ for line in f:
+ if line.startswith(" name:"):
+ name = line.split('"')[1]
+ elif line.startswith(" id:") or line.startswith(" label_id:"):
+ class_id = int(line.strip().split(" ")[-1])
+ label_map[class_id] = name
+ class_ids.add(class_id)
+ return label_map, class_ids
+
+
+class TimeStampClipSampler:
+ """
+ A specialized clip sampler for sampling video clips around specific
+ timestamps. This is particularly used in datasets like Ava wherein only
+ a specific subset of clips in the video have annotations.
+ """
+
+ def __init__(self, clip_sampler: ClipSampler) -> None:
+ """
+ Args:
+ clip_sampler (`pytorchvideo.data.ClipSampler`): Strategy used for sampling
+ between the untrimmed clip boundary.
+ """
+ self.clip_sampler = clip_sampler
+
+ def __call__(
+ self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
+ ) -> ClipInfo:
+ """
+ Args:
+ last_clip_time (float): Not used for TimeStampClipSampler.
+ video_duration: (float): Not used for TimeStampClipSampler.
+ annotation (Dict): Dict containing the time step to sample around.
+ Returns:
+ clip_info (ClipInfo): includes the clip information of (clip_start_time,
+ clip_end_time, clip_index, aug_index, is_last_clip). The times are in seconds.
+ clip_index, aug_index and is_last_clip are always 0, 0 and True, respectively.
+ """
+ center_frame_sec = annotation["clip_index"] # a.k.a timestamp
+ clip_start_sec = center_frame_sec - self.clip_sampler._clip_duration / 2.0
+ return ClipInfo(
+ clip_start_sec,
+ clip_start_sec + self.clip_sampler._clip_duration,
+ 0,
+ 0,
+ True,
+ )
+
+ def reset(self) -> None:
+ pass
+
+
+def Ava(
+ frame_paths_file: str,
+ frame_labels_file: str,
+ video_path_prefix: str = "",
+ label_map_file: Optional[str] = None,
+ clip_sampler: Callable = ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[dict], Any]] = None,
+) -> LabeledVideoDataset:
+ """
+ Args:
+ frame_paths_file (str): Path to a file containing relative paths
+ to all the frames in the video. Each line in the file is of the
+ form: original_vido_id video_id frame_id path labels
+ frame_labels_file (str): Path to the file containing labels
+ per key frame. Each row is a comma separated list of the form
+ (video_id, frame_timestamp, bbox_x1, bbox_y1, bbox_x2, bbox_y2,
+ action_label, extra_info), where extra_info holds either a
+ detection iou or a person id depending on the csv format.
+ video_path_prefix (str): Path prepended to each relative frame
+ path to get the global frame path.
+ label_map_file (str): Path to a .pbtxt containing class id's
+ and class names. If not set, label_map is not loaded and bbox labels are
+ not pruned based on allowable class_id's in label_map.
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video.
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+ transform (Optional[Callable]): This callable is evaluated on the clip output
+ and the corresponding bounding boxes before the clip and the bounding boxes
+ are returned. It can be used for user defined preprocessing and
+ augmentations to the clips. If transform is None, the clip and bounding
+ boxes are returned as it is.
+ """
+ labeled_video_paths = AvaLabeledVideoFramePaths.from_csv(
+ frame_paths_file,
+ frame_labels_file,
+ video_path_prefix,
+ label_map_file,
+ )
+ return LabeledVideoDataset(
+ labeled_video_paths=labeled_video_paths,
+ clip_sampler=TimeStampClipSampler(clip_sampler),
+ transform=transform,
+ video_sampler=video_sampler,
+ decode_audio=False,
+ )
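A minimal usage sketch for the vendored `ava.py` above. The file paths are hypothetical placeholders; only the `Ava` entry point and `make_clip_sampler` (from `clip_sampling.py`, also added in this diff) are taken from the code itself.

```python
# Hedged sketch, not part of the diff: paths below are hypothetical.
from pytorchvideo.data.ava import Ava
from pytorchvideo.data.clip_sampling import make_clip_sampler

dataset = Ava(
    frame_paths_file="ava/frame_lists/train.csv",   # hypothetical path
    frame_labels_file="ava/annotations/train.csv",  # hypothetical path
    video_path_prefix="ava/frames",                 # hypothetical frame root
    label_map_file=None,                            # keep all class ids
    clip_sampler=make_clip_sampler("random", 2.0),  # 2-second clips
)

# Ava() returns a LabeledVideoDataset; each sample carries the clip frames
# plus the aggregated per-box labels built by AvaLabeledVideoFramePaths.
sample = next(iter(dataset))
print(sample.keys())
```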
diff --git a/code/pytorchvideo/pytorchvideo/data/charades.py b/code/pytorchvideo/pytorchvideo/data/charades.py
new file mode 100644
index 0000000000000000000000000000000000000000..c211a6131737efc17760dde101a5dc1b1504b8ed
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/charades.py
@@ -0,0 +1,220 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import csv
+import functools
+import itertools
+import os
+from collections import defaultdict
+from typing import Any, Callable, List, Optional, Tuple, Type
+
+import torch
+import torch.utils.data
+from iopath.common.file_io import g_pathmgr
+from pytorchvideo.data.clip_sampling import ClipSampler
+from pytorchvideo.data.frame_video import FrameVideo
+
+from .utils import MultiProcessSampler
+
+
+class Charades(torch.utils.data.IterableDataset):
+ """
+ Action recognition video dataset for Charades, stored as image frames.
+
+ This dataset handles the parsing of frames, loading and clip sampling for the
+ videos. All io is done through :code:`iopath.common.file_io.PathManager`, enabling
+ non-local storage uri's to be used.
+ """
+
+ # Number of classes represented by this dataset's annotated labels.
+ NUM_CLASSES = 157
+
+ def __init__(
+ self,
+ data_path: str,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[dict], Any]] = None,
+ video_path_prefix: str = "",
+ frames_per_clip: Optional[int] = None,
+ ) -> None:
+ """
+ Args:
+ data_path (str): Path to the data file. This file must be a space
+ separated csv with the format: (original_vido_id video_id frame_id
+ path labels)
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Optional[Callable]): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations on the clips. The clip output format is described in __next__().
+
+ video_path_prefix (str): prefix path to add to all paths from data_path.
+
+ frames_per_clip (Optional[int]): The number of frames per clip to sample.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.Charades.__init__")
+
+ self._transform = transform
+ self._clip_sampler = clip_sampler
+ (
+ self._path_to_videos,
+ self._labels,
+ self._video_labels,
+ ) = _read_video_paths_and_labels(data_path, prefix=video_path_prefix)
+ self._video_sampler = video_sampler(self._path_to_videos)
+ self._video_sampler_iter = None # Initialized on first call to self.__next__()
+ self._frame_filter = (
+ functools.partial(
+ Charades._sample_clip_frames,
+ frames_per_clip=frames_per_clip,
+ )
+ if frames_per_clip is not None
+ else None
+ )
+
+ # Depending on the clip sampler type, we may want to sample multiple clips
+ # from one video. In that case, we store the video, label and previously sampled
+ # clip time in these variables.
+ self._loaded_video = None
+ self._loaded_clip = None
+ self._next_clip_start_time = 0.0
+
+ @staticmethod
+ def _sample_clip_frames(
+ frame_indices: List[int], frames_per_clip: int
+ ) -> List[int]:
+ """
+ Args:
+ frame_indices (list): list of frame indices.
+ frames_per_clip (int): The number of frames per clip to sample.
+
+ Returns:
+ (list): Outputs a subsampled list with num_samples frames.
+ """
+ num_frames = len(frame_indices)
+ indices = torch.linspace(0, num_frames - 1, frames_per_clip)
+ indices = torch.clamp(indices, 0, num_frames - 1).long()
+
+ return [frame_indices[idx] for idx in indices]
+
+ @property
+ def video_sampler(self) -> torch.utils.data.Sampler:
+ return self._video_sampler
+
+ def __next__(self) -> dict:
+ """
+ Retrieves the next clip based on the clip sampling strategy and video sampler.
+
+ Returns:
+ A dictionary with the following format.
+
+ .. code-block:: text
+
+ {
+ 'video': <video_frames_tensor>,
+ 'label': <per_frame_label_lists>,
+ 'video_label': <video_level_label_list>,
+ 'video_name': <video_name>,
+ 'video_index': <video_index>,
+ 'clip_index': <clip_index>,
+ 'aug_index': <aug_index>,
+ }
+ """
+ if not self._video_sampler_iter:
+ # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
+ self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))
+
+ if self._loaded_video:
+ video, video_index = self._loaded_video
+ else:
+ video_index = next(self._video_sampler_iter)
+ path_to_video_frames = self._path_to_videos[video_index]
+ video = FrameVideo.from_frame_paths(path_to_video_frames)
+ self._loaded_video = (video, video_index)
+
+ clip_start, clip_end, clip_index, aug_index, is_last_clip = self._clip_sampler(
+ self._next_clip_start_time, video.duration, {}
+ )
+ # Only load the clip once and reuse previously stored clip if there are multiple
+ # views for augmentations to perform on the same clip.
+ if aug_index == 0:
+ self._loaded_clip = video.get_clip(clip_start, clip_end, self._frame_filter)
+
+ frames, frame_indices = (
+ self._loaded_clip["video"],
+ self._loaded_clip["frame_indices"],
+ )
+ self._next_clip_start_time = clip_end
+
+ if is_last_clip:
+ self._loaded_video = None
+ self._next_clip_start_time = 0.0
+
+ # Merge unique labels from each frame into clip label.
+ labels_by_frame = [
+ self._labels[video_index][i]
+ for i in range(min(frame_indices), max(frame_indices) + 1)
+ ]
+ sample_dict = {
+ "video": frames,
+ "label": labels_by_frame,
+ "video_label": self._video_labels[video_index],
+ "video_name": str(video_index),
+ "video_index": video_index,
+ "clip_index": clip_index,
+ "aug_index": aug_index,
+ }
+ if self._transform is not None:
+ sample_dict = self._transform(sample_dict)
+
+ return sample_dict
+
+ def __iter__(self):
+ return self
+
+
+def _read_video_paths_and_labels(
+ video_path_label_file: str, prefix: str = ""
+) -> Tuple[List[List[str]], List[List[List[int]]], List[List[int]]]:
+ """
+ Args:
+ video_path_label_file (str): a file that contains frame paths for each
+ video and the corresponding frame label. The file must be a space separated
+ csv of the format:
+ `original_vido_id video_id frame_id path labels`
+
+ prefix (str): prefix path to add to all paths from video_path_label_file.
+
+ """
+ image_paths = defaultdict(list)
+ labels = defaultdict(list)
+ with g_pathmgr.open(video_path_label_file, "r") as f:
+
+ # Space separated CSV with format: original_vido_id video_id frame_id path labels
+ csv_reader = csv.DictReader(f, delimiter=" ")
+ for row in csv_reader:
+ assert len(row) == 5
+ video_name = row["original_vido_id"]
+ path = os.path.join(prefix, row["path"])
+ image_paths[video_name].append(path)
+ frame_labels = row["labels"].replace('"', "")
+ label_list = []
+ if frame_labels:
+ label_list = [int(x) for x in frame_labels.split(",")]
+
+ labels[video_name].append(label_list)
+
+ # Extract image paths from dictionary and return paths and labels as list.
+ video_names = image_paths.keys()
+ image_paths = [image_paths[key] for key in video_names]
+ labels = [labels[key] for key in video_names]
+ # Aggregate labels from all frames to form video-level labels.
+ video_labels = [list(set(itertools.chain(*label_list))) for label_list in labels]
+ return image_paths, labels, video_labels
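A hedged usage sketch for the vendored Charades dataset above. The csv path and frame root are hypothetical placeholders; only the constructor arguments shown in `charades.py` are assumed.

```python
# Sketch only: the data_path and video_path_prefix are hypothetical.
import torch
from pytorchvideo.data.charades import Charades
from pytorchvideo.data.clip_sampling import make_clip_sampler

dataset = Charades(
    data_path="charades/train.csv",                  # hypothetical frame-list csv
    clip_sampler=make_clip_sampler("random", 2.0),   # 2-second clips
    video_sampler=torch.utils.data.SequentialSampler,
    video_path_prefix="charades/frames",             # hypothetical frame root
    frames_per_clip=8,                               # subsample 8 frames per clip
)

# Charades is an IterableDataset, so it plugs straight into a DataLoader.
loader = torch.utils.data.DataLoader(dataset, batch_size=None)
sample = next(iter(loader))
print(sample["video"].shape, sample["video_label"])
```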
diff --git a/code/pytorchvideo/pytorchvideo/data/clip_sampling.py b/code/pytorchvideo/pytorchvideo/data/clip_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f59c5c1e3f47feb6c339a06d48db4595fcd02618
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/clip_sampling.py
@@ -0,0 +1,413 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import random
+from abc import ABC, abstractmethod
+from fractions import Fraction
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+
+
+class ClipInfo(NamedTuple):
+ """
+ Named-tuple for clip information with:
+ clip_start_sec (Union[float, Fraction]): clip start time.
+ clip_end_sec (Union[float, Fraction]): clip end time.
+ clip_index (int): clip index in the video.
+ aug_index (int): augmentation index for the clip. Different augmentation methods
+ might generate multiple views for the same clip.
+ is_last_clip (bool): a bool specifying whether there are more clips to be
+ sampled from the video.
+ """
+
+ clip_start_sec: Union[float, Fraction]
+ clip_end_sec: Union[float, Fraction]
+ clip_index: int
+ aug_index: int
+ is_last_clip: bool
+
+
+class ClipInfoList(NamedTuple):
+ """
+ Named-tuple for clip information with:
+ clip_start_sec (float): clip start time.
+ clip_end_sec (float): clip end time.
+ clip_index (int): clip index in the video.
+ aug_index (int): augmentation index for the clip. Different augmentation methods
+ might generate multiple views for the same clip.
+ is_last_clip (bool): a bool specifying whether there are more clips to be
+ sampled from the video.
+ """
+
+ clip_start_sec: List[float]
+ clip_end_sec: List[float]
+ clip_index: List[float]
+ aug_index: List[float]
+ is_last_clip: List[float]
+
+
+class ClipSampler(ABC):
+ """
+ Interface for clip samplers that take a video time, previous sampled clip time,
+ and returns a named-tuple ``ClipInfo``.
+ """
+
+ def __init__(self, clip_duration: Union[float, Fraction]) -> None:
+ self._clip_duration = Fraction(clip_duration)
+ self._current_clip_index = 0
+ self._current_aug_index = 0
+
+ @abstractmethod
+ def __call__(
+ self,
+ last_clip_end_time: Union[float, Fraction],
+ video_duration: Union[float, Fraction],
+ annotation: Dict[str, Any],
+ ) -> ClipInfo:
+ pass
+
+ def reset(self) -> None:
+ """Resets any video-specific attributes in preparation for the next video."""
+ pass
+
+
+def make_clip_sampler(sampling_type: str, *args) -> ClipSampler:
+ """
+ Constructs the clip samplers found in ``pytorchvideo.data.clip_sampling`` from the
+ given arguments.
+
+ Args:
+ sampling_type (str): choose clip sampler to return. It has three options:
+
+ * uniform: constructs and return ``UniformClipSampler``
+ * random: construct and return ``RandomClipSampler``
+ * constant_clips_per_video: construct and return ``ConstantClipsPerVideoSampler``
+
+ *args: the args to pass to the chosen clip sampler constructor.
+ """
+ if sampling_type == "uniform":
+ return UniformClipSampler(*args)
+ elif sampling_type == "random":
+ return RandomClipSampler(*args)
+ elif sampling_type == "constant_clips_per_video":
+ return ConstantClipsPerVideoSampler(*args)
+ elif sampling_type == "random_multi":
+ return RandomMultiClipSampler(*args)
+ else:
+ raise NotImplementedError(f"{sampling_type} not supported")
+
+
+class UniformClipSampler(ClipSampler):
+ """
+ Evenly splits the video into clips of size clip_duration.
+ """
+
+ def __init__(
+ self,
+ clip_duration: Union[float, Fraction],
+ stride: Optional[Union[float, Fraction]] = None,
+ backpad_last: bool = False,
+ eps: float = 1e-6,
+ ):
+ """
+ Args:
+ clip_duration (Union[float, Fraction]):
+ The length of the clip to sample (in seconds).
+ stride (Union[float, Fraction], optional):
+ The number of seconds to offset the next clip by. The
+ default value of None is equivalent to no stride, i.e. stride == clip_duration.
+ eps (float):
+ Epsilon for floating point comparisons. Used to check the last clip.
+ backpad_last (bool):
+ Whether to include the last frame(s) by "back padding".
+
+ For instance, if we have a video of 39 frames (30 fps = 1.3s)
+ with a stride of 16 (0.533s) with a clip duration of 32 frames
+ (1.0667s). The clips will be (in frame numbers):
+
+ with backpad_last = False
+ - [0, 31]
+
+ with backpad_last = True
+ - [0, 31]
+ - [8, 39], this is "back-padded" from [16, 48] to fit the last window
+ Note that you can use Fraction for clip_duration and stride if you want to
+ avoid float precision issue and need accurate frames in each clip.
+ """
+ super().__init__(clip_duration)
+ self._stride = stride if stride is not None else self._clip_duration
+ self._eps = eps
+ self._backpad_last = backpad_last
+
+ assert self._stride > 0, "stride must be positive"
+
+ def _clip_start_end(
+ self,
+ last_clip_end_time: Union[float, Fraction],
+ video_duration: Union[float, Fraction],
+ backpad_last: bool,
+ ) -> Tuple[Fraction, Fraction]:
+ """
+ Helper to calculate the start/end clip with backpad logic
+ """
+ delta = self._stride - self._clip_duration
+ last_end_time = -delta if last_clip_end_time is None else last_clip_end_time
+ clip_start = Fraction(last_end_time + delta)
+ clip_end = Fraction(clip_start + self._clip_duration)
+ if backpad_last:
+ buffer_amount = max(0, clip_end - video_duration)
+ clip_start -= buffer_amount
+ clip_start = Fraction(max(0, clip_start)) # handle rounding
+ clip_end = Fraction(clip_start + self._clip_duration)
+
+ return clip_start, clip_end
+
+ def __call__(
+ self,
+ last_clip_end_time: Optional[float],
+ video_duration: float,
+ annotation: Dict[str, Any],
+ ) -> ClipInfo:
+ """
+ Args:
+ last_clip_end_time (float): the last clip end time sampled from this video. This
+ should be 0.0 if the video hasn't had clips sampled yet.
+ video_duration: (float): the duration of the video that's being sampled in seconds
+ annotation (Dict): Not used by this sampler.
+ Returns:
+ clip_info: (ClipInfo): includes the clip information (clip_start_time,
+ clip_end_time, clip_index, aug_index, is_last_clip), where the times are in
+ seconds and is_last_clip is False when there is still more time in the video
+ to be sampled.
+ """
+ clip_start, clip_end = self._clip_start_end(
+ last_clip_end_time, video_duration, backpad_last=self._backpad_last
+ )
+
+ # if they both end at the same time - it's the last clip
+ _, next_clip_end = self._clip_start_end(
+ clip_end, video_duration, backpad_last=self._backpad_last
+ )
+ if self._backpad_last:
+ is_last_clip = abs(next_clip_end - clip_end) < self._eps
+ else:
+ is_last_clip = (next_clip_end - video_duration) > self._eps
+
+ clip_index = self._current_clip_index
+ self._current_clip_index += 1
+
+ if is_last_clip:
+ self.reset()
+
+ return ClipInfo(clip_start, clip_end, clip_index, 0, is_last_clip)
+
+ def reset(self):
+ self._current_clip_index = 0
+
+
+class UniformClipSamplerTruncateFromStart(UniformClipSampler):
+ """
+ Evenly splits the video into clips of size clip_duration.
+ If truncation_duration is set, clips are sampled from [0, truncation_duration].
+ If truncation_duration is not set, defaults to UniformClipSampler.
+ """
+
+ def __init__(
+ self,
+ clip_duration: Union[float, Fraction],
+ stride: Optional[Union[float, Fraction]] = None,
+ backpad_last: bool = False,
+ eps: float = 1e-6,
+ truncation_duration: Optional[float] = None,
+ ) -> None:
+ super().__init__(clip_duration, stride, backpad_last, eps)
+ self.truncation_duration = truncation_duration
+
+ def __call__(
+ self,
+ last_clip_end_time: float,
+ video_duration: float,
+ annotation: Dict[str, Any],
+ ) -> ClipInfo:
+
+ truncated_video_duration = video_duration
+ if self.truncation_duration is not None:
+ truncated_video_duration = min(self.truncation_duration, video_duration)
+
+ return super().__call__(
+ last_clip_end_time, truncated_video_duration, annotation
+ )
+
+
+class RandomClipSampler(ClipSampler):
+ """
+ Randomly samples clip of size clip_duration from the videos.
+ """
+
+ def __call__(
+ self,
+ last_clip_end_time: float,
+ video_duration: float,
+ annotation: Dict[str, Any],
+ ) -> ClipInfo:
+ """
+ Args:
+ last_clip_end_time (float): Not used for RandomClipSampler.
+ video_duration: (float): the duration (in seconds) for the video that's
+ being sampled
+ annotation (Dict): Not used by this sampler.
+ Returns:
+ clip_info (ClipInfo): includes the clip information of (clip_start_time,
+ clip_end_time, clip_index, aug_index, is_last_clip). The times are in seconds.
+ clip_index, aug_index and is_last_clip are always 0, 0 and True, respectively.
+
+ """
+ max_possible_clip_start = max(video_duration - self._clip_duration, 0)
+ clip_start_sec = Fraction(random.uniform(0, max_possible_clip_start))
+ return ClipInfo(
+ clip_start_sec, clip_start_sec + self._clip_duration, 0, 0, True
+ )
+
+
+class RandomMultiClipSampler(RandomClipSampler):
+ """
+ Randomly samples multiple clips of size clip_duration from the videos.
+ """
+
+ def __init__(self, clip_duration: float, num_clips: int) -> None:
+ super().__init__(clip_duration)
+ self._num_clips = num_clips
+
+ def __call__(
+ self,
+ last_clip_end_time: Optional[float],
+ video_duration: float,
+ annotation: Dict[str, Any],
+ ) -> ClipInfoList:
+
+ (
+ clip_start_list,
+ clip_end_list,
+ clip_index_list,
+ aug_index_list,
+ is_last_clip_list,
+ ) = (
+ self._num_clips * [None],
+ self._num_clips * [None],
+ self._num_clips * [None],
+ self._num_clips * [None],
+ self._num_clips * [None],
+ )
+ for i in range(self._num_clips):
+ (
+ clip_start_list[i],
+ clip_end_list[i],
+ clip_index_list[i],
+ aug_index_list[i],
+ is_last_clip_list[i],
+ ) = super().__call__(last_clip_end_time, video_duration, annotation)
+
+ return ClipInfoList(
+ clip_start_list,
+ clip_end_list,
+ clip_index_list,
+ aug_index_list,
+ is_last_clip_list,
+ )
+
+
+class RandomMultiClipSamplerTruncateFromStart(RandomMultiClipSampler):
+ """
+ Randomly samples multiple clips of size clip_duration from the videos.
+ If truncation_duration is set, clips are sampled from [0, truncation_duration].
+ If truncation_duration is not set, defaults to RandomMultiClipSampler.
+ """
+
+ def __init__(
+ self, clip_duration: float, num_clips: int, truncation_duration: Optional[float] = None
+ ) -> None:
+ super().__init__(clip_duration, num_clips)
+ self.truncation_duration = truncation_duration
+
+ def __call__(
+ self,
+ last_clip_end_time: Optional[float],
+ video_duration: float,
+ annotation: Dict[str, Any],
+ ) -> ClipInfoList:
+
+ truncated_video_duration = video_duration
+ if self.truncation_duration is not None:
+ truncated_video_duration = min(self.truncation_duration, video_duration)
+
+ return super().__call__(
+ last_clip_end_time, truncated_video_duration, annotation
+ )
+
+
+class ConstantClipsPerVideoSampler(ClipSampler):
+ """
+ Evenly splits the video into clips_per_video increments and samples clips of size
+ clip_duration at these increments.
+ """
+
+ def __init__(
+ self, clip_duration: float, clips_per_video: int, augs_per_clip: int = 1
+ ) -> None:
+ super().__init__(clip_duration)
+ self._clips_per_video = clips_per_video
+ self._augs_per_clip = augs_per_clip
+
+ def __call__(
+ self,
+ last_clip_end_time: Optional[float],
+ video_duration: float,
+ annotation: Dict[str, Any],
+ ) -> ClipInfo:
+ """
+ Args:
+ last_clip_end_time (float): Not used for ConstantClipsPerVideoSampler.
+ video_duration: (float): the duration (in seconds) for the video that's
+ being sampled.
+ annotation (Dict): Not used by this sampler.
+ Returns:
+ a named-tuple `ClipInfo`: includes the clip information of (clip_start_time,
+ clip_end_time, clip_index, aug_index, is_last_clip). The times are in seconds.
+ is_last_clip is True after clips_per_video clips have been sampled or the end
+ of the video is reached.
+
+ """
+ max_possible_clip_start = Fraction(max(video_duration - self._clip_duration, 0))
+ uniform_clip = Fraction(
+ max_possible_clip_start, max(self._clips_per_video - 1, 1)
+ )
+ clip_start_sec = uniform_clip * self._current_clip_index
+ clip_index = self._current_clip_index
+ aug_index = self._current_aug_index
+
+ self._current_aug_index += 1
+ if self._current_aug_index >= self._augs_per_clip:
+ self._current_clip_index += 1
+ self._current_aug_index = 0
+
+ # Last clip is True if sampled self._clips_per_video or if end of video is reached.
+ is_last_clip = False
+ if (
+ self._current_clip_index >= self._clips_per_video
+ or uniform_clip * self._current_clip_index > max_possible_clip_start
+ ):
+ self._current_clip_index = 0
+ is_last_clip = True
+
+ if is_last_clip:
+ self.reset()
+
+ return ClipInfo(
+ clip_start_sec,
+ clip_start_sec + self._clip_duration,
+ clip_index,
+ aug_index,
+ is_last_clip,
+ )
+
+ def reset(self):
+ self._current_clip_index = 0
+ self._current_aug_index = 0
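A self-contained sketch of how the samplers in `clip_sampling.py` above are driven by a dataset: call the sampler repeatedly, feeding back the previous clip's end time until it reports `is_last_clip`. Only `UniformClipSampler` from this file is used; the 2-second / 10-second numbers are illustrative.

```python
# Sketch: drive a UniformClipSampler by hand over a 10-second video.
from pytorchvideo.data.clip_sampling import UniformClipSampler

sampler = UniformClipSampler(clip_duration=2.0)  # stride defaults to clip_duration
video_duration = 10.0

last_end = None
while True:
    start, end, clip_idx, aug_idx, is_last = sampler(last_end, video_duration, {})
    print(f"clip {clip_idx}: [{float(start):.1f}, {float(end):.1f}] last={is_last}")
    if is_last:
        break
    last_end = end
# Prints five clips: [0,2], [2,4], [4,6], [6,8], [8,10].
```

The same sampler can be obtained via `make_clip_sampler("uniform", 2.0)`, which is the form the dataset constructors in this diff expect.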
diff --git a/code/pytorchvideo/pytorchvideo/data/dataset_manifest_utils.py b/code/pytorchvideo/pytorchvideo/data/dataset_manifest_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..948dbde6a1efa0fd05a4cd6e60f90e053d2df227
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/dataset_manifest_utils.py
@@ -0,0 +1,315 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import datetime
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List, Optional, Union
+
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.data.frame_video import FrameVideo
+from pytorchvideo.data.utils import (
+ DataclassFieldCaster,
+ load_dataclass_dict_from_csv,
+ save_dataclass_objs_to_headered_csv,
+)
+from pytorchvideo.data.video import Video
+
+
+@dataclass
+class EncodedVideoInfo(DataclassFieldCaster):
+ """
+ Class representing the location of an available encoded video.
+ """
+
+ video_id: str
+ file_path: str
+
+
+@dataclass
+class VideoFrameInfo(DataclassFieldCaster):
+ """
+ Class representing the locations of all frames that compose a video.
+ """
+
+ video_id: str
+ location: str
+ frame_file_stem: str
+ frame_string_length: int
+ min_frame_number: int
+ max_frame_number: int
+ file_extension: str
+
+
+@dataclass
+class VideoInfo(DataclassFieldCaster):
+ """
+ Class representing the video-level metadata of a video from an arbitrary video dataset.
+ """
+
+ video_id: str
+ resolution: str
+ duration: float
+ fps: float
+
+
+@dataclass
+class VideoClipInfo(DataclassFieldCaster):
+ video_id: str
+ start_time: float
+ stop_time: float
+
+
+@dataclass
+class ImageFrameInfo(DataclassFieldCaster):
+ """
+ Class representing the metadata (and labels) for a single frame
+ """
+
+ video_id: str
+ frame_id: str
+ frame_number: int
+ frame_file_path: str
+
+
+class VideoDatasetType(Enum):
+ Frame = 1
+ EncodedVideo = 2
+
+
+class ImageDataset:
+ @staticmethod
+ def _load_images(
+ frame_manifest_file_path: Optional[str],
+ video_info_file_path: str,
+ multithreaded_io: bool,
+ ) -> Dict[str, ImageFrameInfo]:
+ video_infos: Dict[str, VideoInfo] = load_dataclass_dict_from_csv(
+ video_info_file_path, VideoInfo, "video_id"
+ )
+ video_frames: Dict[str, VideoFrameInfo] = load_dataclass_dict_from_csv(
+ frame_manifest_file_path, VideoFrameInfo, "video_id"
+ )
+ VideoDataset._remove_video_info_missing_or_incomplete_videos(
+ video_frames, video_infos
+ )
+
+ image_infos = {}
+ for video_id in video_infos:
+ frame_filepaths = VideoDataset._frame_number_to_filepaths(
+ video_id, video_frames, video_infos
+ )
+ video_info = video_infos[video_id]
+ video_frame_info = video_frames[video_info.video_id]
+ for frame_filepath, frame_number in zip(
+ frame_filepaths,
+ range(
+ video_frame_info.min_frame_number, video_frame_info.max_frame_number
+ ),
+ ):
+ frame_id = os.path.splitext(os.path.basename(frame_filepath))[0]
+ image_infos[frame_id] = ImageFrameInfo(
+ video_id, frame_id, frame_number, frame_filepath
+ )
+ return image_infos
+
+
+class VideoDataset:
+ @staticmethod
+ def _load_videos(
+ video_data_manifest_file_path: Optional[str],
+ video_info_file_path: str,
+ multithreaded_io: bool,
+ dataset_type: VideoDatasetType,
+ ) -> Dict[str, Video]:
+ video_infos: Dict[str, VideoInfo] = load_dataclass_dict_from_csv(
+ video_info_file_path, VideoInfo, "video_id"
+ )
+ if dataset_type == VideoDatasetType.Frame:
+ return VideoDataset._load_frame_videos(
+ video_data_manifest_file_path, video_infos, multithreaded_io
+ )
+ elif dataset_type == VideoDatasetType.EncodedVideo:
+ return VideoDataset._load_encoded_videos(
+ video_data_manifest_file_path, video_infos
+ )
+
+ @staticmethod
+ def _load_frame_videos(
+ frame_manifest_file_path: str,
+ video_infos: Dict[str, VideoInfo],
+ multithreaded_io: bool,
+ ):
+ video_frames: Dict[str, VideoFrameInfo] = load_dataclass_dict_from_csv(
+ frame_manifest_file_path, VideoFrameInfo, "video_id"
+ )
+ VideoDataset._remove_video_info_missing_or_incomplete_videos(
+ video_frames, video_infos
+ )
+ return {
+ video_id: FrameVideo(
+ video_frame_paths=VideoDataset._frame_number_to_filepaths(
+ video_id, video_frames, video_infos
+ ),
+ duration=video_infos[video_id].duration,
+ fps=video_infos[video_id].fps,
+ multithreaded_io=multithreaded_io,
+ )
+ for video_id in video_infos
+ }
+
+ @staticmethod
+ def _load_encoded_videos(
+ encoded_video_manifest_file_path: str,
+ video_infos: Dict[str, VideoInfo],
+ ):
+ encoded_video_infos: Dict[str, EncodedVideoInfo] = load_dataclass_dict_from_csv(
+ encoded_video_manifest_file_path, EncodedVideoInfo, "video_id"
+ )
+ VideoDataset._remove_video_info_missing_or_incomplete_videos(
+ encoded_video_infos, video_infos
+ )
+
+ return {
+ video_id: EncodedVideo.from_path(encoded_video_info.file_path)
+ for video_id, encoded_video_info in encoded_video_infos.items()
+ }
+
+ @staticmethod
+ def _frame_number_to_filepaths(
+ video_id: str,
+ video_frames: Dict[str, VideoFrameInfo],
+ video_infos: Dict[str, VideoInfo],
+ ) -> Optional[List[str]]:
+ video_info = video_infos[video_id]
+ video_frame_info = video_frames[video_info.video_id]
+
+ frame_filepaths = []
+ num_frames = (
+ video_frame_info.max_frame_number - video_frame_info.min_frame_number + 1
+ )
+ for frame_index in range(num_frames):
+ frame_number = frame_index + video_frame_info.min_frame_number
+ if (
+ frame_number < video_frame_info.min_frame_number
+ or frame_number > video_frame_info.max_frame_number
+ ):
+ return None
+
+ frame_path_index = str(frame_number)
+ frame_prefix = video_frame_info.frame_file_stem
+ num_zero_pad = (
+ video_frame_info.frame_string_length
+ - len(frame_path_index)
+ - len(frame_prefix)
+ )
+ zero_padding = "0" * num_zero_pad
+ frame_component = (
+ f"{frame_prefix}{zero_padding}{frame_path_index}"
+ f".{video_frame_info.file_extension}"
+ )
+ frame_filepaths.append(f"{video_frame_info.location}/{frame_component}")
+ return frame_filepaths
+
+ @staticmethod
+ def _remove_video_info_missing_or_incomplete_videos(
+ video_data_infos: Dict[str, Union[VideoFrameInfo, EncodedVideoInfo]],
+ video_infos: Dict[str, VideoInfo],
+ ) -> None:
+ # Avoid deleting keys from the dict while iterating over its keys
+ video_ids = list(video_infos)
+ for video_id in video_ids:
+ video_info = video_infos[video_id]
+
+ # Remove videos we have metadata for but don't have video data
+ if video_id not in video_data_infos:
+ del video_infos[video_id]
+ continue
+
+ # Remove videos we have metadata for but don't have the right number of frames
+ if type(video_data_infos[video_id]) == VideoFrameInfo:
+ video_frames_info = video_data_infos[video_id]
+ expected_frames = round(video_info.duration * video_info.fps)
+ num_frames = (
+ video_frames_info.max_frame_number
+ - video_frames_info.min_frame_number
+ )
+ if abs(num_frames - expected_frames) > video_info.fps:
+ del video_data_infos[video_id]
+ del video_infos[video_id]
+
+ video_ids = list(video_data_infos) # Avoid modifying dict during iteration
+ for video_id in video_ids:
+ # Remove videos we have video data for but don't have metadata
+ if video_id not in video_infos:
+
+ del video_data_infos[video_id]
+
+
+def get_seconds_from_hms_time(time_str: str) -> float:
+ """
+ Get seconds from a timestamp of the form 'HH:MM:SS' or 'HH:MM:SS.fff'.
+
+ Args:
+ time_str (str)
+
+ Returns:
+ float of seconds
+
+ """
+ for fmt in ("%H:%M:%S.%f", "%H:%M:%S"):
+ try:
+ time_since_min_time = datetime.datetime.strptime(time_str, fmt)
+ min_time = datetime.datetime.strptime("", "")
+ return float((time_since_min_time - min_time).total_seconds())
+ except ValueError:
+ pass
+ raise ValueError(f"No valid data format found for provided string {time_str}.")
+
+
+def save_encoded_video_manifest(
+ encoded_video_infos: Dict[str, EncodedVideoInfo], file_name: Optional[str] = None
+) -> str:
+ """
+ Saves the encoded video dictionary as a csv file that can be read for future usage.
+
+ Args:
+ video_frames (Dict[str, EncodedVideoInfo]):
+ Dictionary mapping video_ids to metadata about the location of
+ their video data.
+
+ file_name (str):
+ location to save file (will be automatically generated if None).
+
+ Returns:
+ string of the filename where the video info is stored.
+ """
+ file_name = (
+ f"{os.getcwd()}/encoded_video_manifest.csv" if file_name is None else file_name
+ )
+ save_dataclass_objs_to_headered_csv(list(encoded_video_infos.values()), file_name)
+ return file_name
+
+
+def save_video_frame_info(
+ video_frames: Dict[str, VideoFrameInfo], file_name: Optional[str] = None
+) -> str:
+ """
+ Saves the video frame dictionary as a csv file that can be read for future usage.
+
+ Args:
+ video_frames (Dict[str, VideoFrameInfo]):
+ Dictionary mapping video_ids to metadata about the location of
+ their video frame files.
+
+ file_name (str):
+ location to save file (will be automatically generated if None).
+
+ Returns:
+ string of the filename where the video info is stored.
+ """
+ file_name = (
+ f"{os.getcwd()}/video_frame_metadata.csv" if file_name is None else file_name
+ )
+ save_dataclass_objs_to_headered_csv(list(video_frames.values()), file_name)
+ return file_name
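A quick sanity-check sketch for the manifest helpers in `dataset_manifest_utils.py` above. The frame directory and file naming scheme are hypothetical examples; only the dataclass fields and helper signatures shown in the diff are assumed.

```python
# Sketch: exercise the time parser and write a small frame manifest.
from pytorchvideo.data.dataset_manifest_utils import (
    VideoFrameInfo,
    get_seconds_from_hms_time,
    save_video_frame_info,
)

print(get_seconds_from_hms_time("00:01:30"))      # 90.0
print(get_seconds_from_hms_time("00:01:30.500"))  # 90.5

# A manifest entry for frames stored as frames/video_1/frame_000001.jpg ...
frame_info = VideoFrameInfo(
    video_id="video_1",
    location="frames/video_1",   # hypothetical directory
    frame_file_stem="frame_",
    frame_string_length=12,      # len("frame_") + 6 zero-padded digits
    min_frame_number=1,
    max_frame_number=300,
    file_extension="jpg",
)
manifest_path = save_video_frame_info({"video_1": frame_info}, "frame_manifest.csv")
print(manifest_path)
```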
diff --git a/code/pytorchvideo/pytorchvideo/data/decoder.py b/code/pytorchvideo/pytorchvideo/data/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d5194ff20ca0c087a63ac530ff46a6ed6dda7ba
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/decoder.py
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from enum import Enum
+
+
+class DecoderType(Enum):
+ PYAV = "pyav"
+ TORCHVISION = "torchvision"
+ DECORD = "decord"
diff --git a/code/pytorchvideo/pytorchvideo/data/domsev.py b/code/pytorchvideo/pytorchvideo/data/domsev.py
new file mode 100644
index 0000000000000000000000000000000000000000..74f07490c18442d2111ab082cd21797987b9f21e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/domsev.py
@@ -0,0 +1,532 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import math
+import random
+import time
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from iopath.common.file_io import g_pathmgr
+from PIL import Image
+from pytorchvideo.data.dataset_manifest_utils import (
+ ImageDataset,
+ ImageFrameInfo,
+ VideoClipInfo,
+ VideoDataset,
+ VideoDatasetType,
+)
+from pytorchvideo.data.utils import DataclassFieldCaster, load_dataclass_dict_from_csv
+from pytorchvideo.data.video import Video
+
+
+try:
+ import cv2
+except ImportError:
+ _HAS_CV2 = False
+else:
+ _HAS_CV2 = True
+
+
+USER_ENVIRONMENT_MAP = {
+ 0: "none",
+ 1: "indoor",
+ 2: "nature",
+ 3: "crowded_environment",
+ 4: "urban",
+}
+
+USER_ACTIVITY_MAP = {
+ 0: "none",
+ 1: "walking",
+ 2: "running",
+ 3: "standing",
+ 4: "biking",
+ 5: "driving",
+ 6: "playing",
+ 7: "cooking",
+ 8: "eating",
+ 9: "observing",
+ 10: "in_conversation",
+ 11: "browsing",
+ 12: "shopping",
+}
+
+USER_ATTENTION_MAP = {
+ 0: "none",
+ 1: "paying_attention",
+ 2: "interacting",
+}
+
+
+class LabelType(Enum):
+ Environment = 1
+ Activity = 2
+ UserAttention = 3
+
+
+LABEL_TYPE_2_MAP = {
+ LabelType.Environment: USER_ENVIRONMENT_MAP,
+ LabelType.Activity: USER_ACTIVITY_MAP,
+ LabelType.UserAttention: USER_ATTENTION_MAP,
+}
+
+
+@dataclass
+class LabelData(DataclassFieldCaster):
+ """
+ Class representing a contiguous label for a video segment from the DoMSEV dataset.
+ """
+
+ video_id: str
+ start_time: float # Start time of the label, in seconds
+ stop_time: float # Stop time of the label, in seconds
+ start_frame: int # 0-indexed ID of the start frame (inclusive)
+ stop_frame: int # 0-index ID of the stop frame (inclusive)
+ label_id: int
+ label_name: str
+
+
+# Utility functions
+def _seconds_to_frame_index(
+ time_in_seconds: float, fps: int, zero_indexed: Optional[bool] = True
+) -> int:
+ """
+ Converts a point in time (in seconds) within a video clip to its closest
+ frame index (rounding down), based on a specified frame rate.
+
+ Args:
+ time_in_seconds (float): The point in time within the video.
+ fps (int): The frame rate (frames per second) of the video.
+ zero_indexed (Optional[bool]): Whether the returned frame should be
+ zero-indexed (if True) or one-indexed (if False).
+
+ Returns:
+ (int) The index of the nearest frame (rounding down to the nearest integer).
+ """
+ frame_idx = math.floor(time_in_seconds * fps)
+ if not zero_indexed:
+ frame_idx += 1
+ return frame_idx
+
+
+def _get_overlap_for_time_range_pair(
+ t1_start: float, t1_stop: float, t2_start: float, t2_stop: float
+) -> Optional[Tuple[float, float]]:
+ """
+ Calculates the overlap between two time ranges, if one exists.
+
+ Returns:
+ (Optional[Tuple]) A tuple of (overlap_start_time, overlap_stop_time) if
+ an overlap is found, or None otherwise.
+ """
+ # Check if there is an overlap
+ if (t1_start <= t2_stop) and (t2_start <= t1_stop):
+ # Calculate the overlap period
+ overlap_start_time = max(t1_start, t2_start)
+ overlap_stop_time = min(t1_stop, t2_stop)
+ return (overlap_start_time, overlap_stop_time)
+ else:
+ return None
+
+
+class DomsevFrameDataset(torch.utils.data.Dataset):
+ """
+ Egocentric video classification frame-based dataset for
+ the DoMSEV dataset.
+
+ This dataset handles the loading, decoding, and configurable sampling for
+ the image frames.
+ """
+
+ def __init__(
+ self,
+ video_data_manifest_file_path: str,
+ video_info_file_path: str,
+ labels_file_path: str,
+ transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
+ multithreaded_io: bool = False,
+ ) -> None:
+ """
+ Args:
+ video_data_manifest_file_path (str):
+ The path to a json file outlining the available video data for the
+ associated videos. File must be a csv (w/header) with columns:
+ ``{[f.name for f in dataclass_fields(EncodedVideoInfo)]}``
+
+ To generate this file from a directory of video frames, see helper
+ functions in module: ``pytorchvideo.data.domsev.utils``
+
+ video_info_file_path (str):
+ Path or URI to manifest with basic metadata of each video.
+ File must be a csv (w/header) with columns:
+ ``{[f.name for f in dataclass_fields(VideoInfo)]}``
+
+ labels_file_path (str):
+ Path or URI to manifest with temporal annotations for each video.
+ File must be a csv (w/header) with columns:
+ ``{[f.name for f in dataclass_fields(LabelData)]}``
+
+ dataset_type (VideoDatasetType): The data format in which dataset
+ video data is stored (e.g. video frames, encoded video etc).
+
+ transform (Optional[Callable[[Dict[str, Any]], Any]]):
+ This callable is evaluated on the clip output before the clip is returned.
+ It can be used for user-defined preprocessing and augmentations to the clips.
+ The clip output format is described in __next__().
+
+ multithreaded_io (bool):
+ Boolean to control whether io operations are performed across multiple
+ threads.
+ """
+ assert video_info_file_path
+ assert labels_file_path
+ assert video_data_manifest_file_path
+
+ ## Populate image frame and metadata data providers ##
+ # Maps an image frame ID to an `ImageFrameInfo`
+ frames_dict: Dict[str, ImageFrameInfo] = ImageDataset._load_images(
+ video_data_manifest_file_path,
+ video_info_file_path,
+ multithreaded_io,
+ )
+ video_labels: Dict[str, List[LabelData]] = load_dataclass_dict_from_csv(
+ labels_file_path, LabelData, "video_id", list_per_key=True
+ )
+ # Maps an image frame ID to the singular frame label
+ self._labels_per_frame: Dict[
+ str, int
+ ] = DomsevFrameDataset._assign_labels_to_frames(frames_dict, video_labels)
+
+ self._user_transform = transform
+ self._transform = self._transform_frame
+
+ # Shuffle the frames order for iteration
+ self._frames = list(frames_dict.values())
+ random.shuffle(self._frames)
+
+ @staticmethod
+ def _assign_labels_to_frames(
+ frames_dict: Dict[str, ImageFrameInfo],
+ video_labels: Dict[str, List[LabelData]],
+ ):
+ """
+ Args:
+ frames_dict: The mapping of frame_id to ImageFrameInfo for all the frames
+ in the dataset.
+ video_labels: The list of temporal labels for each video
+
+ Returns:
+ A dict mapping each frame ID to the label ID of the temporal label
+ whose frame range contains that frame.
+ """
+ labels_per_frame: Dict[str, int] = {}
+ for frame_id, image_info in frames_dict.items():
+ # Filter labels by only the ones that appear within the clip boundaries,
+ # and unpack the labels so there is one per frame in the clip
+ labels_in_video = video_labels[image_info.video_id]
+ for label in labels_in_video:
+ if (image_info.frame_number >= label.start_frame) and (
+ image_info.frame_number <= label.stop_frame
+ ):
+ labels_per_frame[frame_id] = label.label_id
+
+ return labels_per_frame
+
+ def __getitem__(self, index) -> Dict[str, Any]:
+ """
+ Samples an image frame associated to the given index.
+
+ Args:
+ index (int): index for the image frame
+
+ Returns:
+ An image frame with the following format if transform is None.
+
+ .. code-block:: text
+
+ {{
+ 'frame_id': <frame_id>,
+ 'image': <image_frame>,
+ 'label': <label_id>,
+ }}
+ """
+ frame = self._frames[index]
+ label_in_frame = self._labels_per_frame[frame.frame_id]
+
+ image_data = _load_image_from_path(frame.frame_file_path)
+
+ frame_data = {
+ "frame_id": frame.frame_id,
+ "image": image_data,
+ "label": label_in_frame,
+ }
+
+ if self._transform:
+ frame_data = self._transform(frame_data)
+
+ return frame_data
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ The number of frames in the dataset.
+ """
+ return len(self._frames)
+
+ def _transform_frame(self, frame: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Transforms a given image frame, according to some pre-defined transforms
+ and an optional user transform function (self._user_transform).
+
+ Args:
+ frame (Dict[str, Any]): The frame that will be transformed.
+
+ Returns:
+ (Dict[str, Any]) The transformed frame.
+ """
+ for key in frame:
+ if frame[key] is None:
+ frame[key] = torch.tensor([])
+
+ if self._user_transform:
+ frame = self._user_transform(frame)
+
+ return frame
+
+
+class DomsevVideoDataset(torch.utils.data.Dataset):
+ """
+ Egocentric classification video clip-based dataset for
+ the DoMSEV dataset,
+ stored as an encoded video (with frame-level labels).
+
+ This dataset handles the loading, decoding, and configurable clip
+ sampling for the videos.
+ """
+
+ def __init__(
+ self,
+ video_data_manifest_file_path: str,
+ video_info_file_path: str,
+ labels_file_path: str,
+ clip_sampler: Callable[
+ [Dict[str, Video], Dict[str, List[LabelData]]], List[VideoClipInfo]
+ ],
+ dataset_type: VideoDatasetType = VideoDatasetType.Frame,
+ frames_per_second: int = 1,
+ transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
+ frame_filter: Optional[Callable[[List[int]], List[int]]] = None,
+ multithreaded_io: bool = False,
+ ) -> None:
+ """
+ Args:
+ video_data_manifest_file_path (str):
+ The path to a json file outlining the available video data for the
+ associated videos. File must be a csv (w/header) with columns:
+ ``{[f.name for f in dataclass_fields(EncodedVideoInfo)]}``
+
+ To generate this file from a directory of video frames, see helper
+ functions in module: ``pytorchvideo.data.domsev.utils``
+
+ video_info_file_path (str):
+ Path or URI to manifest with basic metadata of each video.
+ File must be a csv (w/header) with columns:
+ ``{[f.name for f in dataclass_fields(VideoInfo)]}``
+
+ labels_file_path (str):
+ Path or URI to manifest with annotations for each video.
+ File must be a csv (w/header) with columns:
+ ``{[f.name for f in dataclass_fields(LabelData)]}``
+
+ clip_sampler (Callable[[Dict[str, Video], Dict[str, List[LabelData]]],
+ List[VideoClipInfo]]):
+ Defines how clips should be sampled from each video. See the clip
+ sampling documentation for more information.
+
+ dataset_type (VideoDatasetType): The data format in which dataset
+ video data is stored (e.g. video frames, encoded video etc).
+
+ frames_per_second (int): The FPS of the stored videos. (NOTE:
+ this is variable and may be different than the original FPS
+ reported on the DoMSEV dataset website -- it depends on the
+ preprocessed subsampling and frame extraction).
+
+ transform (Optional[Callable[[Dict[str, Any]], Any]]):
+ This callable is evaluated on the clip output before the clip is returned.
+ It can be used for user-defined preprocessing and augmentations to the clips.
+ The clip output format is described in __next__().
+
+ frame_filter (Optional[Callable[[List[int]], List[int]]]):
+ This callable is evaluated on the set of available frame indices to be
+ included in a sampled clip. This can be used to subselect frames within
+ a clip to be loaded.
+
+ multithreaded_io (bool):
+ Boolean to control whether io operations are performed across multiple
+ threads.
+ """
+ assert video_info_file_path
+ assert labels_file_path
+ assert video_data_manifest_file_path
+
+ # Populate video and metadata data providers
+ self._videos: Dict[str, Video] = VideoDataset._load_videos(
+ video_data_manifest_file_path,
+ video_info_file_path,
+ multithreaded_io,
+ dataset_type,
+ )
+
+ self._labels_per_video: Dict[
+ str, List[LabelData]
+ ] = load_dataclass_dict_from_csv(
+ labels_file_path, LabelData, "video_id", list_per_key=True
+ )
+
+ # Sample datapoints
+ self._clips: List[VideoClipInfo] = clip_sampler(
+ self._videos, self._labels_per_video
+ )
+
+ self._frames_per_second = frames_per_second
+ self._user_transform = transform
+ self._transform = self._transform_clip
+ self._frame_filter = frame_filter
+
+ def __getitem__(self, index) -> Dict[str, Any]:
+ """
+ Samples a video clip associated to the given index.
+
+ Args:
+ index (int): index for the video clip.
+
+ Returns:
+ A video clip with the following format if transform is None.
+
+ .. code-block:: text
+
+ {{
+ 'video_id': <video_id>,
+ 'video': <video_tensor>,
+ 'audio': <audio_tensor>,
+ 'labels': <label_id_tensor>,
+ 'start_time': <clip_start_time>,
+ 'stop_time': <clip_stop_time>
+ }}
+ """
+ clip = self._clips[index]
+
+ # Filter labels by only the ones that appear within the clip boundaries,
+ # and unpack the labels so there is one per frame in the clip
+ labels_in_video = self._labels_per_video[clip.video_id]
+ labels_in_clip = []
+ for label_data in labels_in_video:
+ overlap_period = _get_overlap_for_time_range_pair(
+ clip.start_time,
+ clip.stop_time,
+ label_data.start_time,
+ label_data.stop_time,
+ )
+ if overlap_period is not None:
+ overlap_start_time, overlap_stop_time = overlap_period
+
+ # Convert the overlapping period between clip and label to
+ # 0-indexed start and stop frame indexes, so we can unpack 1
+ # label per frame.
+ overlap_start_frame = _seconds_to_frame_index(
+ overlap_start_time, self._frames_per_second
+ )
+ overlap_stop_frame = _seconds_to_frame_index(
+ overlap_stop_time, self._frames_per_second
+ )
+
+ # Append 1 label per frame
+ for _ in range(overlap_start_frame, overlap_stop_frame):
+ labels_in_clip.append(label_data)
+
+ # Convert the list of LabelData objects to a tensor of just the label IDs
+ label_ids = [labels_in_clip[i].label_id for i in range(len(labels_in_clip))]
+ label_ids_tensor = torch.tensor(label_ids)
+
+ clip_data = {
+ "video_id": clip.video_id,
+ **self._videos[clip.video_id].get_clip(clip.start_time, clip.stop_time),
+ "labels": label_ids_tensor,
+ "start_time": clip.start_time,
+ "stop_time": clip.stop_time,
+ }
+
+ if self._transform:
+ clip_data = self._transform(clip_data)
+
+ return clip_data
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ The number of video clips in the dataset.
+ """
+ return len(self._clips)
+
+ def _transform_clip(self, clip: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Transforms a given video clip, according to some pre-defined transforms
+ and an optional user transform function (self._user_transform).
+
+ Args:
+ clip (Dict[str, Any]): The clip that will be transformed.
+
+ Returns:
+ (Dict[str, Any]) The transformed clip.
+ """
+ for key in clip:
+ if clip[key] is None:
+ clip[key] = torch.tensor([])
+
+ if self._user_transform:
+ clip = self._user_transform(clip)
+
+ return clip
+
+
+def _load_image_from_path(image_path: str, num_retries: int = 10) -> Image:
+ """
+ Loads the given image path using PathManager and decodes it as an RGB image.
+
+ Args:
+ image_path (str): the path to the image.
+ num_retries (int): number of times to retry image reading to handle transient errors.
+
+ Returns:
+ A PIL Image containing the decoded RGB image data. Pixel values
+ are of type np.uint8 and in the range [0, 255]. Raises an
+ exception if the image cannot be loaded.
+ """
+ if not _HAS_CV2:
+ raise ImportError(
+ "opencv2 is required to use FrameVideo. Please "
+ "install with 'pip install opencv-python'"
+ )
+
+ img_arr = None
+
+ for i in range(num_retries):
+ with g_pathmgr.open(image_path, "rb") as f:
+ img_str = np.frombuffer(f.read(), np.uint8)
+ img_bgr = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+ if img_rgb is not None:
+ img_arr = img_rgb
+ break
+ else:
+ logging.warning(f"Reading attempt {i}/{num_retries} failed.")
+ time.sleep(1e-6)
+
+ if img_arr is None:
+ raise Exception("Failed to load image from {}".format(image_path))
+
+ pil_image = Image.fromarray(img_arr)
+ return pil_image
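A sketch of the label-to-frame unpacking logic used by `DomsevVideoDataset.__getitem__` above, exercised directly on the two private helper functions from this file (imported here purely for illustration). The timestamps are arbitrary example values.

```python
# Sketch: how an overlapping label is unpacked into one label per frame.
from pytorchvideo.data.domsev import (
    _get_overlap_for_time_range_pair,
    _seconds_to_frame_index,
)

# A clip from t=4s to t=9s and a label spanning t=7s to t=12s overlap on [7, 9].
overlap = _get_overlap_for_time_range_pair(4.0, 9.0, 7.0, 12.0)
print(overlap)  # (7.0, 9.0)

# At 1 fps (the default frames_per_second), that overlap covers frames 7 and 8,
# so the dataset appends the label once per frame in range(7, 9).
start_frame = _seconds_to_frame_index(overlap[0], fps=1)
stop_frame = _seconds_to_frame_index(overlap[1], fps=1)
print(start_frame, stop_frame)  # 7 9
```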
diff --git a/code/pytorchvideo/pytorchvideo/data/ego4d/__init__.py b/code/pytorchvideo/pytorchvideo/data/ego4d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5385bb06ea15356d80d42bc99df9f814a0c3d617
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/ego4d/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .ego4d_dataset import Ego4dMomentsDataset
diff --git a/code/pytorchvideo/pytorchvideo/data/ego4d/ego4d_dataset.py b/code/pytorchvideo/pytorchvideo/data/ego4d/ego4d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e41c27a7a25f50fae68cb1be77eb48ab233e0f00
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/ego4d/ego4d_dataset.py
@@ -0,0 +1,622 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import csv
+import json
+import logging
+import os
+from bisect import bisect_left
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
+
+import numpy as np
+
+import torch
+import torch.autograd.profiler as profiler
+import torch.utils.data
+import torchaudio
+from iopath.common.file_io import g_pathmgr
+
+from pytorchvideo.data import LabeledVideoDataset
+from pytorchvideo.data.clip_sampling import ClipSampler
+from pytorchvideo.data.ego4d.utils import (
+ Ego4dImuDataBase,
+ get_label_id_map,
+ MomentsClipSampler,
+)
+from pytorchvideo.data.utils import get_logger
+from pytorchvideo.data.video import VideoPathHandler
+from pytorchvideo.transforms import (
+ ApplyTransformToKey,
+ Div255,
+ Normalize,
+ RandomShortSideScale,
+ ShortSideScale,
+)
+from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip
+
+log: logging.Logger = get_logger("Ego4dMomentsDataset")
+
+
+class Ego4dImuData(Ego4dImuDataBase):
+ """
+ Wrapper for Ego4D IMU data loads, assuming one csv per video_uid at the provided path.
+ """
+
+ def __init__(self, imu_path: str) -> None:
+ """
+ Args:
+ imu_path (str):
+ Base path to construct IMU csv file paths.
+ i.e. <imu_path>/<video_uid>.csv
+ """
+ assert imu_path
+
+ self.path_imu = imu_path
+ self.IMU_by_video_uid: Dict[str, Any] = {}
+ for f in g_pathmgr.ls(self.path_imu):
+ self.IMU_by_video_uid[f.split(".")[0]] = f.replace(".csv", "")
+
+ log.info(
+ f"Number of videos with IMU (before filtering) {len(self.IMU_by_video_uid)}"
+ )
+
+ self.imu_video_uid: Optional[str] = None
+ self.imu_video_data: Optional[Tuple[np.ndarray, np.ndarray, int]] = None
+
+ def has_imu(self, video_uid: str) -> bool:
+ return video_uid in self.IMU_by_video_uid
+
+ def _load_csv(self, csv_path: str) -> List[Dict[str, Any]]:
+ with g_pathmgr.open(csv_path, "r") as f:
+ reader = csv.DictReader(f)
+ data = []
+ for row in reader:
+ data.append(row)
+ return data
+
+ def _load_imu(self, video_uid: str) -> Tuple[np.ndarray, np.ndarray, int]:
+ file_path = os.path.join(self.path_imu, video_uid) + ".csv"
+ data_csv = self._load_csv(file_path)
+ data_IMU = defaultdict(list)
+ for row in data_csv:
+ for k, v in row.items():
+ if v != "":
+ data_IMU[k].append(float(v))
+ else:
+ data_IMU[k].append(0.0)
+ signal = np.array(
+ [
+ data_IMU["accl_x"],
+ data_IMU["accl_y"],
+ data_IMU["accl_z"],
+ data_IMU["gyro_x"],
+ data_IMU["gyro_y"],
+ data_IMU["gyro_z"],
+ ]
+ ).transpose()
+ # normalize
+ signal = (signal - signal.mean(axis=0)) / signal.std(axis=0)
+ timestamps = np.array(data_IMU["canonical_timestamp_ms"])
+ sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
+ if sampling_rate < 0:
+ # regenerate timestamps with 198 hz
+ new_timestamps = timestamps[0] + (1000 / 198) * np.arange(len(timestamps))
+ timestamps = np.array(new_timestamps)
+ sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
+ return signal, timestamps, sampling_rate
+
+ def _get_imu_window(
+ self,
+ window_start: float,
+ window_end: float,
+ signal: np.ndarray,
+ timestamps: np.ndarray,
+ sampling_rate: float,
+ ) -> Dict[str, Any]:
+ start_id = bisect_left(timestamps, window_start * 1000)
+ end_id = bisect_left(timestamps, window_end * 1000)
+ if end_id == len(timestamps):
+ end_id -= 1
+
+ sample_dict = {
+ "timestamp": timestamps[start_id:end_id],
+ "signal": signal[start_id:end_id],
+ "sampling_rate": sampling_rate,
+ }
+ return sample_dict
+
+ def get_imu(self, video_uid: str) -> Tuple[np.ndarray, np.ndarray, int]:
+ # Caching/etc?
+ return self._load_imu(video_uid)
+
+ def get_imu_sample(
+ self, video_uid: str, video_start: float, video_end: float
+ ) -> Dict[str, Any]:
+ # Assumes video clips are loaded sequentially, will lazy load imu
+ if not self.imu_video_uid or video_uid != self.imu_video_uid:
+ self.imu_video_uid = video_uid
+ self.imu_video_data = self._load_imu(video_uid)
+ assert self.imu_video_data
+ imu_signal, timestamps, sampling_rate = self.imu_video_data
+
+ return self._get_imu_window(
+ video_start,
+ video_end,
+ imu_signal,
+ timestamps,
+ sampling_rate,
+ )
+
+
+class Ego4dMomentsDataset(LabeledVideoDataset):
+ """
+ Ego4d video/audio/imu dataset for the moments benchmark:
+ ` `
+
+ This dataset handles the parsing of frames, loading and clip sampling for the
+ videos.
+
+ IO is performed via :code:`iopath.common.file_io.PathManager` to support
+ non-local storage URIs.
+ """
+
+ VIDEO_FPS = 30
+ AUDIO_FPS = 48000
+
+ def __init__(
+ self,
+ annotation_path: str,
+ metadata_path: str,
+ split: Optional[str] = None,
+ decode_audio: bool = True,
+ imu: bool = False,
+ clip_sampler: Optional[ClipSampler] = None,
+ video_sampler: Type[
+ torch.utils.data.Sampler
+ ] = torch.utils.data.SequentialSampler,
+ transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ decoder: str = "pyav",
+ filtered_labels: Optional[List[str]] = None,
+ window_sec: int = 10,
+ audio_transform_type: str = "melspectrogram",
+ imu_path: Optional[str] = None,
+ label_id_map: Optional[Dict[str, int]] = None,
+ label_id_map_path: Optional[str] = None,
+ video_path_override: Optional[Callable[[str], str]] = None,
+ video_path_handler: Optional[VideoPathHandler] = None,
+ eligible_video_uids: Optional[Set[str]] = None,
+ ) -> None:
+ """
+ Args:
+ annotation_path (str):
+ Path or URI to the Ego4d moments annotations json (moments.json). Download via:
+ ``
+
+ metadata_path (str):
+ Path or URI to the primary Ego4d metadata json (ego4d.json). Download via:
+ ``
+
+ split (Optional[str]): train/val/test
+
+ decode_audio (bool): If True, decode audio from video.
+
+ imu (bool): If True, load IMU data.
+
+ clip_sampler (ClipSampler):
+ A standard PTV ClipSampler; defaults to `MomentsClipSampler` if not specified.
+
+ video_sampler (VideoSampler):
+ A standard PTV VideoSampler.
+
+ transform (Optional[Callable[[Dict[str, Any]], Any]]):
+ This callable is evaluated on the clip output before the clip is returned.
+ It can be used for user-defined preprocessing and augmentations to the clips.
+
+ The clip input is a dictionary with the following format:
+ {{
+ 'video': ,
+ 'audio': ,
+ 'imu': ,
+ 'start_time': ,
+ 'stop_time':
+ }}
+
+ If transform is None, the raw clip output in the above format is
+ returned unmodified.
+
+ decoder (str): Defines what type of decoder is used to decode a video within
+ `LabeledVideoDataset`.
+
+ filtered_labels (List[str]):
+ Optional list of moments labels to filter samples for training.
+
+ window_sec (int): minimum clip window size, in seconds
+
+ audio_transform_type: melspectrogram / spectrogram / mfcc
+
+ imu_path (Optional[str]):
+ Path to the ego4d IMU csv file. Required if imu=True.
+
+ label_id_map / label_id_map_path:
+ A map of moments labels to consistent integer ids. If specified as a path
+ we expect a vanilla .json dict[str, int]. Exactly one must be specified.
+
+ video_path_override ((str) -> str):
+ An override for video paths, given the video_uid, to support downsampled/etc
+ videos.
+
+ video_path_handler (VideoPathHandler):
+ Primarily provided as an override for `CachedVideoPathHandler`
+
+ Example Usage:
+ Ego4dMomentsDataset(
+ annotation_path="~/ego4d_data/v1/annotations/moments.json",
+ metadata_path="~/ego4d_data/v1/ego4d.json",
+ split="train",
+ decode_audio=True,
+ imu=False,
+ )
+ """
+
+ assert annotation_path
+ assert metadata_path
+ assert split in [
+ "train",
+ "val",
+ "test",
+ ], f"Split '{split}' not supported for ego4d"
+ self.split: str = split
+ self.training: bool = split == "train"
+ self.window_sec = window_sec
+ self._transform_source = transform
+ self.decode_audio = decode_audio
+ self.audio_transform_type = audio_transform_type
+ assert (label_id_map is not None) ^ (
+ label_id_map_path is not None
+ ), f"Either label_id_map or label_id_map_path required ({label_id_map_path} / {label_id_map})" # noqa
+
+ self.video_means = (0.45, 0.45, 0.45)
+ self.video_stds = (0.225, 0.225, 0.225)
+ self.video_crop_size = 224
+ self.video_min_short_side_scale = 256
+ self.video_max_short_side_scale = 320
+
+ try:
+ with g_pathmgr.open(metadata_path, "r") as f:
+ metadata = json.load(f)
+ except Exception:
+ raise FileNotFoundError(
+ f"{metadata_path} must be a valid metadata json for Ego4D"
+ )
+
+ self.video_metadata_map: Dict[str, Any] = {
+ x["video_uid"]: x for x in metadata["videos"]
+ }
+
+ if not g_pathmgr.isfile(annotation_path):
+ raise FileNotFoundError(f"{annotation_path} not found.")
+
+ try:
+ with g_pathmgr.open(annotation_path, "r") as f:
+ moments_annotations = json.load(f)
+ except Exception:
+ raise FileNotFoundError(f"{annotation_path} must be json for Ego4D dataset")
+
+ self.label_name_id_map: Dict[str, int]
+ if label_id_map:
+ self.label_name_id_map = label_id_map
+ else:
+ self.label_name_id_map = get_label_id_map(label_id_map_path)
+ assert self.label_name_id_map
+
+ self.num_classes: int = len(self.label_name_id_map)
+ log.info(f"Label Classes: {self.num_classes}")
+
+ self.imu_data: Optional[Ego4dImuDataBase] = None
+ if imu:
+ assert imu_path, "imu_path not provided"
+ self.imu_data = Ego4dImuData(imu_path)
+
+ video_uids = set()
+ clip_uids = set()
+ clip_video_map = {}
+ labels = set()
+ labels_bypassed = set()
+ cnt_samples_bypassed = 0
+ cnt_samples_bypassed_labels = 0
+ samples = []
+
+ for vid in moments_annotations["videos"]:
+ video_uid = vid["video_uid"]
+ video_uids.add(video_uid)
+ vsplit = vid["split"]
+ if split and vsplit != split:
+ continue
+ # If IMU, filter videos without IMU
+ if self.imu_data and not self.imu_data.has_imu(video_uid):
+ continue
+ if eligible_video_uids and video_uid not in eligible_video_uids:
+ continue
+ for clip in vid["clips"]:
+ clip_uid = clip["clip_uid"]
+ clip_uids.add(clip_uid)
+ clip_video_map[clip_uid] = video_uid
+ clip_start_sec = clip["video_start_sec"]
+ clip_end_sec = clip["video_end_sec"]
+ for vann in clip["annotations"]:
+ for lann in vann["labels"]:
+ label = lann["label"]
+ labels.add(label)
+ start = lann["start_time"]
+ end = lann["end_time"]
+ # skip zero-length samples (start == end)
+ if start == end:
+ continue
+ start_video = lann["video_start_time"]
+ end_video = lann["video_end_time"]
+ assert end_video >= start_video
+
+ if abs(start_video - (clip_start_sec + start)) > 0.5:
+ log.warning(
+ f"Suspect clip/video start mismatch: clip: {clip_start_sec:.2f} + {start:.2f} video: {start_video:.2f}" # noqa
+ )
+
+ # filter annotations based on the provided label filter
+ if filtered_labels and label not in filtered_labels:
+ cnt_samples_bypassed += 1
+ labels_bypassed.add(label)
+ continue
+ metadata = self.video_metadata_map[video_uid]
+
+ if metadata["is_stereo"]:
+ cnt_samples_bypassed += 1
+ continue
+
+ if video_path_override:
+ video_path = video_path_override(video_uid)
+ else:
+ video_path = metadata["manifold_path"]
+ if not video_path:
+ cnt_samples_bypassed += 1
+ log.error("Bypassing invalid video_path: {video_uid}")
+ continue
+
+ sample = {
+ "clip_uid": clip_uid,
+ "video_uid": video_uid,
+ "duration": metadata["duration_sec"],
+ "clip_video_start_sec": clip_start_sec,
+ "clip_video_end_sec": clip_end_sec,
+ "labels": [label],
+ "label_video_start_sec": start_video,
+ "label_video_end_sec": end_video,
+ "video_path": video_path,
+ }
+ assert (
+ sample["label_video_end_sec"]
+ > sample["label_video_start_sec"]
+ )
+
+ if self.label_name_id_map:
+ if label in self.label_name_id_map:
+ sample["labels_id"] = self.label_name_id_map[label]
+ else:
+ cnt_samples_bypassed_labels += 1
+ continue
+ else:
+ log.error("Missing label_name_id_map")
+ samples.append(sample)
+
+ self.cnt_samples: int = len(samples)
+
+ log.info(
+ f"Loaded {self.cnt_samples} samples. Bypass: {cnt_samples_bypassed} Label Lookup Bypass: {cnt_samples_bypassed_labels}" # noqa
+ )
+
+ for sample in samples:
+ assert "labels_id" in sample, f"init: Sample missing labels_id: {sample}"
+
+ if not clip_sampler:
+ clip_sampler = MomentsClipSampler(self.window_sec)
+
+ super().__init__(
+ [(x["video_path"], x) for x in samples],
+ clip_sampler,
+ video_sampler,
+ transform=self._transform_mm,
+ decode_audio=decode_audio,
+ decoder=decoder,
+ )
+
+ if video_path_handler:
+ self.video_path_handler = video_path_handler
+
+ def check_IMU(self, input_dict: Dict[str, Any]) -> bool:
+ if (
+ len(input_dict["imu"]["signal"].shape) != 2
+ or input_dict["imu"]["signal"].shape[0] == 0
+ or input_dict["imu"]["signal"].shape[0] < 200
+ or input_dict["imu"]["signal"].shape[1] != 6
+ ):
+ log.warning(f"Problematic Sample: {input_dict}")
+ return True
+ else:
+ return False
+
+ def _transform_mm(self, sample_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ log.info("_transform_mm")
+ with profiler.record_function("_transform_mm"):
+ video_uid = sample_dict["video_uid"]
+ assert video_uid
+
+ assert sample_dict["video"] is not None
+ assert (
+ "labels_id" in sample_dict
+ ), f"Sample missing labels_id: {sample_dict}"
+
+ video = sample_dict["video"]
+
+ expected = int(self.VIDEO_FPS * self.window_sec)
+ actual = video.size(1)
+ if expected != actual:
+ log.error(
+ f"video size mismatch: actual: {actual} expected: {expected} video: {video.size()} uid: {video_uid}", # noqa
+ stack_info=True,
+ )
+ return None
+
+ start = sample_dict["clip_start"]
+ end = sample_dict["clip_end"]
+ assert start >= 0 and end >= start
+
+ if abs((end - start) - self.window_sec) > 0.01:
+ log.warning(f"Invalid IMU time window: ({start}, {end})")
+
+ if self.imu_data:
+ sample_dict["imu"] = self.imu_data.get_imu_sample(
+ video_uid,
+ start,
+ end,
+ )
+ if self.check_IMU(sample_dict):
+ log.warning(f"Bad IMU sample: ignoring: {video_uid}")
+ return None
+
+ sample_dict = self._video_transform()(sample_dict)
+
+ if self.decode_audio:
+ audio_fps = self.AUDIO_FPS
+ sample_dict["audio"] = self._preproc_audio(
+ sample_dict["audio"], audio_fps
+ )
+ sample_dict["spectrogram"] = sample_dict["audio"]["spectrogram"]
+
+ labels = sample_dict["labels"]
+ one_hot = self.convert_one_hot(labels)
+ sample_dict["labels_onehot"] = one_hot
+
+ if self._transform_source:
+ sample_dict = self._transform_source(sample_dict)
+
+ log.info(
+ f"Sample ({sample_dict['video_name']}): "
+ f"({sample_dict['clip_start']:.2f}, {sample_dict['clip_end']:.2f}) "
+ f" {sample_dict['labels_id']} | {sample_dict['labels']}"
+ )
+
+ return sample_dict
+
+ # pyre-ignore
+ def _video_transform(self):
+ """
+ This function contains example transforms using both PyTorchVideo and
+ TorchVision in the same callable. For the 'train' split, we use augmentations (prepended
+ with 'Random'); for 'val' we use the respective deterministic functions.
+ """
+
+ assert (
+ self.video_means
+ and self.video_stds
+ and self.video_min_short_side_scale > 0
+ and self.video_crop_size > 0
+ )
+
+ video_transforms = ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ # pyre-fixme
+ [Div255(), Normalize(self.video_means, self.video_stds)]
+ + [ # pyre-fixme
+ RandomShortSideScale(
+ min_size=self.video_min_short_side_scale,
+ max_size=self.video_max_short_side_scale,
+ ),
+ RandomCrop(self.video_crop_size),
+ RandomHorizontalFlip(p=0.5),
+ ]
+ if self.training
+ else [
+ ShortSideScale(self.video_min_short_side_scale),
+ CenterCrop(self.video_crop_size),
+ ]
+ ),
+ )
+ return Compose([video_transforms])
+
+ def signal_transform(self, type: str = "spectrogram", sample_rate: int = 48000):
+ if type == "spectrogram":
+ n_fft = 1024
+ win_length = None
+ hop_length = 512
+
+ transform = torchaudio.transforms.Spectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ center=True,
+ pad_mode="reflect",
+ power=2.0,
+ )
+ elif type == "melspectrogram":
+ n_fft = 1024
+ win_length = None
+ hop_length = 512
+ n_mels = 64
+
+ transform = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sample_rate,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ center=True,
+ pad_mode="reflect",
+ power=2.0,
+ norm="slaney",
+ onesided=True,
+ n_mels=n_mels,
+ mel_scale="htk",
+ )
+ elif type == "mfcc":
+ n_fft = 2048
+ win_length = None
+ hop_length = 512
+ n_mels = 256
+ n_mfcc = 256
+
+ transform = torchaudio.transforms.MFCC(
+ sample_rate=sample_rate,
+ n_mfcc=n_mfcc,
+ melkwargs={
+ "n_fft": n_fft,
+ "n_mels": n_mels,
+ "hop_length": hop_length,
+ "mel_scale": "htk",
+ },
+ )
+ else:
+ raise ValueError(type)
+
+ return transform
+
+ def _preproc_audio(self, audio, audio_fps) -> Dict[str, Any]:
+ # convert stereo to mono
+ # https://github.com/pytorch/audio/issues/363
+ waveform_mono = torch.mean(audio, dim=0, keepdim=True)
+ return {
+ "signal": waveform_mono,
+ "spectrogram": self.signal_transform(
+ type=self.audio_transform_type,
+ sample_rate=audio_fps,
+ )(waveform_mono),
+ "sampling_rate": audio_fps,
+ }
+
+ def convert_one_hot(self, label_list: List[str]) -> List[int]:
+ labels = [x for x in label_list if x in self.label_name_id_map.keys()]
+ assert len(labels) == len(
+ label_list
+ ), f"invalid filter {len(label_list)} -> {len(labels)}: {label_list}"
+ one_hot = [0 for _ in range(self.num_classes)]
+ for lab in labels:
+ one_hot[self.label_name_id_map[lab]] = 1
+ return one_hot
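A minimal, hypothetical usage sketch (the paths and the label-id json are placeholders, not shipped defaults) wiring the dataset above into a plain PyTorch DataLoader; the keys read below are the ones populated by `_transform_mm`:

    import torch
    from pytorchvideo.data.ego4d import Ego4dMomentsDataset

    dataset = Ego4dMomentsDataset(
        annotation_path="~/ego4d_data/v1/annotations/moments.json",
        metadata_path="~/ego4d_data/v1/ego4d.json",
        split="val",
        decode_audio=True,
        imu=False,
        label_id_map_path="~/ego4d_data/label_id_map.json",  # assumed dict[str, int] json
    )

    # LabeledVideoDataset is iterable, so batch_size=None yields one sample dict at a time.
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=0)
    for sample in loader:
        video = sample["video"]            # (C, T, H, W) float tensor after _video_transform
        spec = sample["spectrogram"]       # audio features, present since decode_audio=True
        one_hot = sample["labels_onehot"]  # multi-hot labels over num_classes
        break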
diff --git a/code/pytorchvideo/pytorchvideo/data/ego4d/utils.py b/code/pytorchvideo/pytorchvideo/data/ego4d/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..186004fda11a8287321b14d662fa0701ab0fd16f
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/ego4d/utils.py
@@ -0,0 +1,124 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Tuple
+
+from iopath.common.file_io import g_pathmgr
+
+from pytorchvideo.data.clip_sampling import ClipInfo, ClipSampler
+from pytorchvideo.data.utils import get_logger
+
+log: logging.Logger = get_logger("Ego4dDatasetUtils")
+
+
+# TODO: Round to fps (and ideally frame align)
+def check_window_len(
+ s_time: float, e_time: float, w_len: float, video_dur: float
+) -> Tuple[float, float]:
+ """
+ Constrain/slide the given time window to `w_len` size, within the video/clip length.
+ """
+ # adjust to match w_len
+ interval = e_time - s_time
+ if abs(interval - w_len) > 0.001:
+ # TODO: Do we want to sample rather than trim the interior when larger?
+ delta = w_len - (e_time - s_time)
+ s_time = s_time - (delta / 2)
+ e_time = e_time + (delta / 2)
+ if s_time < 0:
+ e_time += -s_time
+ s_time = 0
+ if video_dur:
+ if e_time > video_dur:
+ overlap = e_time - video_dur
+ assert s_time >= overlap, "Incompatible w_len / video_dur"
+ s_time -= overlap
+ e_time -= overlap
+ log.info(
+ f"check_window_len: video overlap ({overlap}) adjusted -> ({s_time:.2f}, {e_time:.2f}) video: {video_dur}" # noqa
+ )
+ if abs((e_time - s_time) - w_len) > 0.01:
+ log.error(
+ f"check_window_len: invalid time interval: {s_time}, {e_time}",
+ stack_info=True,
+ )
+ return s_time, e_time
+
+
+# TODO: Move to FixedClipSampler?
+class MomentsClipSampler(ClipSampler):
+ """
+ ClipSampler for Ego4d moments. Will return a fixed `window_sec` window
+ around the given annotation, shifting where relevant to account for the end
+ of the clip/video.
+
+ clip_start/clip_end is added to the annotation dict to facilitate future lookups.
+ """
+
+ def __init__(self, window_sec: float = 0) -> None:
+ self.window_sec = window_sec
+
+ def __call__(
+ self,
+ last_clip_end_time: float,
+ video_duration: float,
+ annotation: Dict[str, Any],
+ ) -> ClipInfo:
+ assert (
+ last_clip_end_time is None or last_clip_end_time <= video_duration
+ ), f"last_clip_end_time ({last_clip_end_time}) > video_duration ({video_duration})"
+ start = annotation["label_video_start_sec"]
+ end = annotation["label_video_end_sec"]
+ if video_duration is not None and end > video_duration:
+ log.error(f"Invalid video_duration/end_sec: {video_duration} / {end}")
+ # If the overshoot is small, proceed anyway
+ if end > video_duration + 0.1:
+ raise Exception(
+ f"Invalid video_duration/end_sec: {video_duration} / {end} ({annotation['video_name']})" # noqa
+ )
+ assert end >= start, f"end < start: {end:.2f} / {start:.2f}"
+ if self.window_sec > 0:
+ s, e = check_window_len(start, end, self.window_sec, video_duration)
+ if s != start or e != end:
+ # log.info(
+ # f"clip window slid ({start:.2f}|{end:.2f}) -> ({s:.2f}|{e:.2f})"
+ # )
+ start = s
+ end = e
+ annotation["clip_start"] = start
+ annotation["clip_end"] = end
+ return ClipInfo(start, end, 0, 0, True)
+
+
+def get_label_id_map(label_id_map_path: str) -> Dict[str, int]:
+ label_name_id_map: Dict[str, int]
+
+ try:
+ with g_pathmgr.open(label_id_map_path, "r") as f:
+ label_json = json.load(f)
+
+ # TODO: Verify?
+ return label_json
+ except Exception:
+ raise FileNotFoundError(f"{label_id_map_path} must be a valid label id json")
+
+
+class Ego4dImuDataBase(ABC):
+ """
+ Base class placeholder for Ego4d IMU data.
+ """
+
+ def __init__(self, basepath: str):
+ self.basepath = basepath
+
+ @abstractmethod
+ def has_imu(self, video_uid: str) -> bool:
+ pass
+
+ @abstractmethod
+ def get_imu_sample(
+ self, video_uid: str, video_start: float, video_end: float
+ ) -> Dict[str, Any]:
+ pass
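A small, self-contained sketch (numbers are illustrative) of how `check_window_len` pads a short annotation to the fixed window and slides it back inside the video bounds; `MomentsClipSampler` applies the same logic to each annotation:

    from pytorchvideo.data.ego4d.utils import check_window_len

    # A 2s annotation inside a 100s video is padded symmetrically to a 10s window.
    print(check_window_len(s_time=40.0, e_time=42.0, w_len=10.0, video_dur=100.0))
    # -> (36.0, 46.0)

    # Near the end of the video, the window is slid left so it still fits
    # entirely within the 100s duration.
    print(check_window_len(s_time=97.0, e_time=99.0, w_len=10.0, video_dur=100.0))
    # -> (90.0, 100.0)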
diff --git a/code/pytorchvideo/pytorchvideo/data/encoded_video.py b/code/pytorchvideo/pytorchvideo/data/encoded_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..227227adcd9a16e04bebdac9bf3c5ffb1cd37982
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/encoded_video.py
@@ -0,0 +1,75 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import io
+import logging
+import pathlib
+from typing import Any, Dict
+
+from iopath.common.file_io import g_pathmgr
+from pytorchvideo.data.decoder import DecoderType
+
+from .video import Video
+
+
+logger = logging.getLogger(__name__)
+
+
+def select_video_class(decoder: str) -> Video:
+ """
+ Select the class for accessing clips based on provided decoder string
+
+ Args:
+ decoder (str): Defines what type of decoder used to decode a video.
+ """
+ if DecoderType(decoder) == DecoderType.PYAV:
+ from .encoded_video_pyav import EncodedVideoPyAV
+
+ video_cls = EncodedVideoPyAV
+ elif DecoderType(decoder) == DecoderType.TORCHVISION:
+ from .encoded_video_torchvision import EncodedVideoTorchVision
+
+ video_cls = EncodedVideoTorchVision
+ elif DecoderType(decoder) == DecoderType.DECORD:
+ from .encoded_video_decord import EncodedVideoDecord
+
+ video_cls = EncodedVideoDecord
+ else:
+ raise NotImplementedError(f"Unknown decoder type {decoder}")
+
+ return video_cls
+
+
+class EncodedVideo(Video):
+ """
+ EncodedVideo is an abstraction for accessing clips from an encoded video.
+ It supports selective decoding when header information is available.
+ """
+
+ @classmethod
+ def from_path(
+ cls,
+ file_path: str,
+ decode_video: bool = True,
+ decode_audio: bool = True,
+ decoder: str = "pyav",
+ **other_args: Dict[str, Any],
+ ):
+ """
+ Fetches the given video path using PathManager (allowing remote uris to be
+ fetched) and constructs the EncodedVideo object.
+
+ Args:
+ file_path (str): a PathManager file-path.
+ """
+ # We read the file with PathManager so that we can read from remote uris.
+ with g_pathmgr.open(file_path, "rb") as fh:
+ video_file = io.BytesIO(fh.read())
+
+ video_cls = select_video_class(decoder)
+ return video_cls(
+ file=video_file,
+ video_name=pathlib.Path(file_path).name,
+ decode_video=decode_video,
+ decode_audio=decode_audio,
+ **other_args,
+ )
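A brief usage sketch (the file path is a placeholder): `from_path` picks the backend class via `select_video_class`, and `get_clip` returns the "video"/"audio" tensors described in the backend classes below:

    from pytorchvideo.data.encoded_video import EncodedVideo

    video = EncodedVideo.from_path("some_video.mp4", decoder="pyav")  # placeholder path
    print(video.duration)                    # duration in seconds
    clip = video.get_clip(start_sec=0.0, end_sec=2.0)
    frames = clip["video"]                   # (C, T, H, W) float32 tensor, or None
    audio = clip["audio"]                    # 1-D float32 tensor, or None
    video.close()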
diff --git a/code/pytorchvideo/pytorchvideo/data/encoded_video_decord.py b/code/pytorchvideo/pytorchvideo/data/encoded_video_decord.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ae85dc04011e6c1463ab0fa1ebcbb9cd1b2ea44
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/encoded_video_decord.py
@@ -0,0 +1,199 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import math
+from typing import BinaryIO, Dict, Optional, TypeVar
+
+import torch
+
+from .utils import thwc_to_cthw
+from .video import Video
+
+
+logger = logging.getLogger(__name__)
+
+try:
+ import decord
+except ImportError:
+ _HAS_DECORD = False
+else:
+ _HAS_DECORD = True
+
+if _HAS_DECORD:
+ decord.bridge.set_bridge("torch")
+
+DecordDevice = TypeVar("DecordDevice")
+
+
+class EncodedVideoDecord(Video):
+ """
+
+ Accessing clips from an encoded video using Decord video reading API
+ as the decoding backend. For more details, please refer to -
+ `Decord `
+ """
+
+ def __init__(
+ self,
+ file: BinaryIO,
+ video_name: Optional[str] = None,
+ decode_video: bool = True,
+ decode_audio: bool = True,
+ sample_rate: int = 44100,
+ mono: bool = True,
+ width: int = -1,
+ height: int = -1,
+ num_threads: int = 0,
+ fault_tol: int = -1,
+ ) -> None:
+ """
+ Args:
+ file (BinaryIO): a file-like object (e.g. io.BytesIO or io.StringIO) that
+ contains the encoded video.
+ video_name (str): An optional name assigned to the video.
+ decode_video (bool): If disabled, video is not decoded.
+ decode_audio (bool): If disabled, audio is not decoded.
+ sample_rate: int, default is 44100
+ Desired output sample rate of the audio; unchanged if `-1` is specified.
+ mono: bool, default is True
+ Desired output channel layout of the audio. `True` is mono layout. `False`
+ is unchanged.
+ width : int, default is -1
+ Desired output width of the video, unchanged if `-1` is specified.
+ height : int, default is -1
+ Desired output height of the video, unchanged if `-1` is specified.
+ num_threads : int, default is 0
+ Number of decoding thread, auto if `0` is specified.
+ fault_tol : int, default is -1
+ The threshold of corrupted and recovered frames. This prevents silent fault
+ tolerance when, for example, 50% of the frames of a video cannot be decoded and
+ duplicate frames are returned. The fault-tolerant behaviour is convenient in many
+ cases, but not for training models. Let `N = # recovered frames`:
+ If `fault_tol` < 0, nothing will happen.
+ If 0 < `fault_tol` < 1.0, raise `DECORDLimitReachedError` when N > `fault_tol * len(video)`.
+ If `fault_tol` >= 1, raise `DECORDLimitReachedError` when N > `fault_tol`.
+ """
+ if not decode_video:
+ raise NotImplementedError()
+
+ self._decode_audio = decode_audio
+ self._video_name = video_name
+ if not _HAS_DECORD:
+ raise ImportError(
+ "decord is required to use EncodedVideoDecord decoder. Please "
+ "install with 'pip install decord' for CPU-only version and refer to"
+ "'https://github.com/dmlc/decord' for GPU-supported version"
+ )
+ try:
+ if self._decode_audio:
+ self._av_reader = decord.AVReader(
+ uri=file,
+ ctx=decord.cpu(0),
+ sample_rate=sample_rate,
+ mono=mono,
+ width=width,
+ height=height,
+ num_threads=num_threads,
+ fault_tol=fault_tol,
+ )
+ else:
+ self._av_reader = decord.VideoReader(
+ uri=file,
+ ctx=decord.cpu(0),
+ width=width,
+ height=height,
+ num_threads=num_threads,
+ fault_tol=fault_tol,
+ )
+ except Exception as e:
+ raise RuntimeError(f"Failed to open video {video_name} with Decord. {e}")
+
+ if self._decode_audio:
+ self._fps = self._av_reader._AVReader__video_reader.get_avg_fps()
+ else:
+ self._fps = self._av_reader.get_avg_fps()
+
+ self._duration = float(len(self._av_reader)) / float(self._fps)
+
+ @property
+ def name(self) -> Optional[str]:
+ """
+ Returns:
+ name: the name of the stored video if set.
+ """
+ return self._video_name
+
+ @property
+ def duration(self) -> float:
+ """
+ Returns:
+ duration: the video's duration/end-time in seconds.
+ """
+ return self._duration
+
+ def close(self):
+ if self._av_reader is not None:
+ del self._av_reader
+ self._av_reader = None
+
+ def get_clip(
+ self, start_sec: float, end_sec: float
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Retrieves frames from the encoded video at the specified start and end times
+ in seconds (the video always starts at 0 seconds).
+
+ Args:
+ start_sec (float): the clip start time in seconds
+ end_sec (float): the clip end time in seconds
+ Returns:
+ clip_data:
+ A dictionary mapping the entries at "video" and "audio" to tensors.
+
+ "video": A tensor of the clip's RGB frames with shape:
+ (channel, time, height, width). The frames are of type torch.float32 and
+ in the range [0 - 255].
+
+ "audio": A tensor of the clip's audio samples with shape:
+ (samples). The samples are of type torch.float32 and
+ in the range [0 - 255].
+
+ Returns None if no video or audio found within time range.
+
+ """
+ if start_sec > end_sec or start_sec > self._duration:
+ raise RuntimeError(
+ f"Incorrect time window for Decord decoding for video: {self._video_name}."
+ )
+
+ start_idx = math.ceil(self._fps * start_sec)
+ end_idx = math.ceil(self._fps * end_sec)
+ end_idx = min(end_idx, len(self._av_reader))
+ frame_idxs = list(range(start_idx, end_idx))
+ audio = None
+
+ try:
+ outputs = self._av_reader.get_batch(frame_idxs)
+ except Exception as e:
+ logger.debug(f"Failed to decode video with Decord: {self._video_name}. {e}")
+ raise e
+
+ if self._decode_audio:
+ audio, video = outputs
+ if audio is not None:
+ audio = list(audio)
+ audio = torch.cat(audio, dim=1)
+ audio = torch.flatten(audio)
+ audio = audio.to(torch.float32)
+ else:
+ video = outputs
+
+ if video is not None:
+ video = video.to(torch.float32)
+ video = thwc_to_cthw(video)
+
+ return {
+ "video": video,
+ "audio": audio,
+ }
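A hedged sketch of selecting the Decord backend through `EncodedVideo.from_path` (placeholder path, and the decord package must be installed); backend-specific options such as `sample_rate` or `num_threads` are simply forwarded via `**other_args`:

    from pytorchvideo.data.encoded_video import EncodedVideo

    video = EncodedVideo.from_path(
        "some_video.mp4",       # placeholder path
        decoder="decord",
        decode_audio=True,
        sample_rate=16000,      # forwarded to EncodedVideoDecord
        num_threads=2,          # forwarded to EncodedVideoDecord
    )
    clip = video.get_clip(0.0, 2.0)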
diff --git a/code/pytorchvideo/pytorchvideo/data/encoded_video_pyav.py b/code/pytorchvideo/pytorchvideo/data/encoded_video_pyav.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9523815a58dc214ad1c956979dc13865c2045d
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/encoded_video_pyav.py
@@ -0,0 +1,364 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import math
+from fractions import Fraction
+from typing import BinaryIO, Dict, List, Optional, Tuple, Union
+
+import av
+import numpy as np
+import torch
+from pytorchvideo.data.encoded_video import EncodedVideo
+
+from .utils import pts_to_secs, secs_to_pts, thwc_to_cthw
+
+
+logger = logging.getLogger(__name__)
+
+
+class EncodedVideoPyAV(EncodedVideo):
+ """
+ EncodedVideoPyAV is an abstraction for accessing clips from an encoded video using
+ PyAV as the decoding backend. It supports selective decoding when header information
+ is available.
+ """
+
+ def __init__(
+ self,
+ file: BinaryIO,
+ video_name: Optional[str] = None,
+ decode_video: bool = True,
+ decode_audio: bool = True,
+ perform_seek: bool = True,
+ ) -> None:
+ """
+ Args:
+ file (BinaryIO): a file-like object (e.g. io.BytesIO or io.StringIO) that
+ contains the encoded video.
+ perform_seek:
+ Whether or not to seek in the underlying video container.
+
+ NOTE: seeks may be slow on larger files, e.g. on a networked filesystem
+ """
+ self.perform_seek = perform_seek
+ self._video_name = video_name
+ self._decode_video = decode_video
+ self._decode_audio = decode_audio
+
+ try:
+ self._container = av.open(file)
+ except Exception as e:
+ raise RuntimeError(f"Failed to open video {video_name}. {e}")
+
+ if self._container is None or len(self._container.streams.video) == 0:
+ raise RuntimeError(f"Video stream not found {video_name}")
+
+ # Retrieve video header information if available.
+ video_stream = self._container.streams.video[0]
+ self._video_time_base = video_stream.time_base
+ self._video_start_pts = video_stream.start_time
+ if self._video_start_pts is None:
+ self._video_start_pts = 0.0
+
+ video_duration = video_stream.duration
+
+ # Retrieve audio header information if available.
+ audio_duration = None
+ self._has_audio = None
+ if self._decode_audio:
+ self._has_audio = self._container.streams.audio
+ if self._has_audio:
+ self._audio_time_base = self._container.streams.audio[0].time_base
+ self._audio_start_pts = self._container.streams.audio[0].start_time
+ if self._audio_start_pts is None:
+ self._audio_start_pts = 0.0
+
+ audio_duration = self._container.streams.audio[0].duration
+
+ # If duration isn't found in header the whole video is decoded to
+ # determine the duration.
+ self._video, self._audio, self._selective_decoding = (None, None, True)
+ if audio_duration is None and video_duration is None:
+ self._selective_decoding = False
+ self._video, self._audio = self._pyav_decode_video()
+ if self._video is None:
+ raise RuntimeError("Unable to decode video stream")
+
+ video_duration = self._video[-1][1]
+ if self._audio is not None:
+ audio_duration = self._audio[-1][1]
+
+ # Take the larger duration of the video or audio stream.
+ if audio_duration is not None and video_duration is not None:
+ self._duration = max(
+ pts_to_secs(
+ video_duration, self._video_time_base, self._video_start_pts
+ ),
+ pts_to_secs(
+ audio_duration, self._audio_time_base, self._audio_start_pts
+ ),
+ )
+ elif video_duration is not None:
+ self._duration = pts_to_secs(
+ video_duration, self._video_time_base, self._video_start_pts
+ )
+
+ elif audio_duration is not None:
+ self._duration = pts_to_secs(
+ audio_duration, self._audio_time_base, self._audio_start_pts
+ )
+
+ @property
+ def rate(self) -> Union[str, Fraction]:
+ """
+ Returns:
+ rate: the frame rate of the video
+ """
+ return self._container.streams.video[0].rate
+
+ @property
+ def bit_rate(self) -> int:
+ """
+ Returns:
+ bit_rate: the bit rate of the underlying video
+ """
+ return self._container.streams.video[0].bit_rate
+
+ @property
+ def pix_fmt(self) -> int:
+ """
+ Returns:
+ pix_fmt: the pixel format of the underlying video
+ """
+ return self._container.streams.video[0].pix_fmt
+
+ @property
+ def name(self) -> Optional[str]:
+ """
+ Returns:
+ name: the name of the stored video if set.
+ """
+ return self._video_name
+
+ @property
+ def duration(self) -> float:
+ """
+ Returns:
+ duration: the video's duration/end-time in seconds.
+ """
+ return self._duration
+
+ def get_clip(
+ self, start_sec: float, end_sec: float
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Retrieves frames from the encoded video at the specified start and end times
+ in seconds (the video always starts at 0 seconds). Returned frames will be in
+ [start_sec, end_sec). Note that 1) if you want to avoid float precision issues
+ and need accurate frames, please use Fraction for start_sec and end_sec.
+ 2) Since end_sec is exclusive, you may need to use
+ `get_clip(start_sec, duration + EPS)` to get the last frame.
+
+ Args:
+ start_sec (float): the clip start time in seconds
+ end_sec (float): the clip end time in seconds
+ Returns:
+ clip_data:
+ A dictionary mapping the entries at "video" and "audio" to tensors.
+
+ "video": A tensor of the clip's RGB frames with shape:
+ (channel, time, height, width). The frames are of type torch.float32 and
+ in the range [0 - 255].
+
+ "audio": A tensor of the clip's audio samples with shape:
+ (samples). The samples are of type torch.float32 and
+ in the range [0 - 255].
+
+ Returns None if no video or audio found within time range.
+
+ """
+ if self._selective_decoding:
+ self._video, self._audio = self._pyav_decode_video(start_sec, end_sec)
+
+ video_frames = None
+ if self._video is not None:
+ video_start_pts = secs_to_pts(
+ start_sec,
+ self._video_time_base,
+ self._video_start_pts,
+ round_mode="ceil",
+ )
+ video_end_pts = secs_to_pts(
+ end_sec,
+ self._video_time_base,
+ self._video_start_pts,
+ round_mode="ceil",
+ )
+
+ video_frames = [
+ f
+ for f, pts in self._video
+ if pts >= video_start_pts and pts < video_end_pts
+ ]
+
+ audio_samples = None
+ if self._has_audio and self._audio is not None:
+ audio_start_pts = secs_to_pts(
+ start_sec,
+ self._audio_time_base,
+ self._audio_start_pts,
+ round_mode="ceil",
+ )
+ audio_end_pts = secs_to_pts(
+ end_sec,
+ self._audio_time_base,
+ self._audio_start_pts,
+ round_mode="ceil",
+ )
+ audio_samples = [
+ f
+ for f, pts in self._audio
+ if pts >= audio_start_pts and pts < audio_end_pts
+ ]
+ audio_samples = torch.cat(audio_samples, axis=0)
+ audio_samples = audio_samples.to(torch.float32)
+
+ if video_frames is None or len(video_frames) == 0:
+ logger.debug(
+ f"No video found within {start_sec} and {end_sec} seconds. "
+ f"Video starts at time 0 and ends at {self.duration}."
+ )
+
+ video_frames = None
+
+ if video_frames is not None:
+ video_frames = thwc_to_cthw(torch.stack(video_frames)).to(torch.float32)
+
+ return {
+ "video": video_frames,
+ "audio": audio_samples,
+ }
+
+ def close(self):
+ """
+ Closes the internal video container.
+ """
+ if self._container is not None:
+ self._container.close()
+
+ def _pyav_decode_video(
+ self, start_secs: float = 0.0, end_secs: float = math.inf
+ ) -> Tuple[Optional[List], Optional[List]]:
+ """
+ Selectively decodes the video (and audio, if present) between start_secs and
+ end_secs (in seconds), returning (video_and_pts, audio_and_pts) lists of
+ (frame, pts) pairs, or None where nothing was decoded.
+ """
+ video_and_pts = None
+ audio_and_pts = None
+ try:
+ if self._decode_video:
+ pyav_video_frames, _ = _pyav_decode_stream(
+ self._container,
+ secs_to_pts(
+ start_secs,
+ self._video_time_base,
+ self._video_start_pts,
+ round_mode="ceil",
+ ),
+ secs_to_pts(
+ end_secs,
+ self._video_time_base,
+ self._video_start_pts,
+ round_mode="ceil",
+ ),
+ self._container.streams.video[0],
+ {"video": 0},
+ perform_seek=self.perform_seek,
+ )
+ if len(pyav_video_frames) > 0:
+ video_and_pts = [
+ (torch.from_numpy(frame.to_rgb().to_ndarray()), frame.pts)
+ for frame in pyav_video_frames
+ ]
+
+ if self._has_audio:
+ pyav_audio_frames, _ = _pyav_decode_stream(
+ self._container,
+ secs_to_pts(
+ start_secs,
+ self._audio_time_base,
+ self._audio_start_pts,
+ round_mode="ceil",
+ ),
+ secs_to_pts(
+ end_secs,
+ self._audio_time_base,
+ self._audio_start_pts,
+ round_mode="ceil",
+ ),
+ self._container.streams.audio[0],
+ {"audio": 0},
+ perform_seek=self.perform_seek,
+ )
+
+ if len(pyav_audio_frames) > 0:
+ audio_and_pts = [
+ (
+ torch.from_numpy(np.mean(frame.to_ndarray(), axis=0)),
+ frame.pts,
+ )
+ for frame in pyav_audio_frames
+ ]
+
+ except Exception as e:
+ logger.debug(f"Failed to decode video: {self._video_name}. {e}")
+
+ return video_and_pts, audio_and_pts
+
+
+def _pyav_decode_stream(
+ container: av.container.input.InputContainer,
+ start_pts: int,
+ end_pts: int,
+ stream: av.video.stream.VideoStream,
+ stream_name: dict,
+ buffer_size: int = 0,
+ perform_seek: bool = True,
+) -> Tuple[List, float]:
+ """
+ Decode the video with PyAV decoder.
+ Args:
+ container (container): PyAV container.
+ start_pts (int): the starting Presentation TimeStamp to fetch the
+ video frames.
+ end_pts (int): the ending Presentation TimeStamp of the decoded frames.
+ stream (stream): PyAV stream.
+ stream_name (dict): a dictionary of streams. For example, {"video": 0}
+ means video stream at stream index 0.
+ Returns:
+ result (list): list of decoded frames.
+ max_pts (int): max Presentation TimeStamp of the video sequence.
+ """
+
+ # Seeking in the stream is imprecise. Thus, seek to an earlier pts by a
+ # margin pts.
+ margin = 1024
+
+ # NOTE:
+ # Don't want to seek if iterating through a video due to slow-downs. I
+ # believe this is some PyAV bug where seeking after a certain point causes
+ # major slow-downs
+ if perform_seek:
+ seek_offset = max(start_pts - margin, 0)
+ container.seek(int(seek_offset), any_frame=False, backward=True, stream=stream)
+ frames = {}
+ max_pts = 0
+ for frame in container.decode(**stream_name):
+ max_pts = max(max_pts, frame.pts)
+ if frame.pts >= start_pts and frame.pts < end_pts:
+ frames[frame.pts] = frame
+ elif frame.pts >= end_pts:
+ break
+
+ result = [frames[pts] for pts in sorted(frames)]
+ return result, max_pts
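A short sketch of the `perform_seek` knob documented above: keyword arguments not consumed by `from_path` are forwarded to the backend class, so seeking can be disabled when consecutive clips are read from a single file (path is a placeholder):

    from pytorchvideo.data.encoded_video import EncodedVideo

    # perform_seek is forwarded to EncodedVideoPyAV via **other_args.
    video = EncodedVideo.from_path("long_video.mp4", decoder="pyav", perform_seek=False)
    for start in range(0, int(video.duration), 2):
        clip = video.get_clip(start, start + 2)  # consecutive 2-second windows
    video.close()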
diff --git a/code/pytorchvideo/pytorchvideo/data/encoded_video_torchvision.py b/code/pytorchvideo/pytorchvideo/data/encoded_video_torchvision.py
new file mode 100644
index 0000000000000000000000000000000000000000..eee8f17a6b08546f59415c619a64e5fd891b1316
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/encoded_video_torchvision.py
@@ -0,0 +1,276 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+from fractions import Fraction
+ from typing import BinaryIO, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+
+from .utils import pts_to_secs, secs_to_pts, thwc_to_cthw
+from .video import Video
+
+
+logger = logging.getLogger(__name__)
+
+
+class EncodedVideoTorchVision(Video):
+ """
+
+ Accessing clips from an encoded video using Torchvision video reading API
+ (torch.ops.video_reader.read_video_from_memory) as the decoding backend.
+ """
+
+ """
+ av_seek_frame is imprecise, so seek to a timestamp earlier by a margin.
+ The unit of the margin is seconds.
+ """
+ SEEK_FRAME_MARGIN = 0.25
+
+ def __init__(
+ self,
+ file: BinaryIO,
+ video_name: Optional[str] = None,
+ decode_video: bool = True,
+ decode_audio: bool = True,
+ ) -> None:
+ if not decode_video:
+ raise NotImplementedError()
+
+ self._video_tensor = torch.tensor(
+ np.frombuffer(file.getvalue(), dtype=np.uint8)
+ )
+ self._video_name = video_name
+ self._decode_audio = decode_audio
+
+ (
+ self._video,
+ self._video_time_base,
+ self._video_start_pts,
+ video_duration,
+ self._audio,
+ self._audio_time_base,
+ self._audio_start_pts,
+ audio_duration,
+ ) = self._torch_vision_decode_video()
+
+ # Take the larger duration of the video or audio stream.
+ if audio_duration is not None and video_duration is not None:
+ self._duration = max(
+ pts_to_secs(
+ video_duration, self._video_time_base, self._video_start_pts
+ ),
+ pts_to_secs(
+ audio_duration, self._audio_time_base, self._audio_start_pts
+ ),
+ )
+ elif video_duration is not None:
+ self._duration = pts_to_secs(
+ video_duration, self._video_time_base, self._video_start_pts
+ )
+
+ elif audio_duration is not None:
+ self._duration = pts_to_secs(
+ audio_duration, self._audio_time_base, self._audio_start_pts
+ )
+
+ @property
+ def name(self) -> Optional[str]:
+ """
+ Returns:
+ name: the name of the stored video if set.
+ """
+ return self._video_name
+
+ @property
+ def duration(self) -> float:
+ """
+ Returns:
+ duration: the video's duration/end-time in seconds.
+ """
+ return self._duration
+
+ def close(self):
+ pass
+
+ def get_clip(
+ self, start_sec: float, end_sec: float
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Retrieves frames from the encoded video at the specified start and end times
+ in seconds (the video always starts at 0 seconds). Returned frames will be in
+ [start_sec, end_sec). Note that 1) if you want to avoid float precision issues
+ and need accurate frames, please use Fraction for start_sec and end_sec.
+ 2) Since end_sec is exclusive, you may need to use
+ `get_clip(start_sec, duration + EPS)` to get the last frame.
+
+ Args:
+ start_sec (float): the clip start time in seconds
+ end_sec (float): the clip end time in seconds
+ Returns:
+ clip_data:
+ A dictionary mapping the entries at "video" and "audio" to tensors.
+
+ "video": A tensor of the clip's RGB frames with shape:
+ (channel, time, height, width). The frames are of type torch.float32 and
+ in the range [0 - 255].
+
+ "audio": A tensor of the clip's audio samples with shape:
+ (samples). The samples are of type torch.float32 and
+ in the range [0 - 255].
+
+ Returns None if no video or audio found within time range.
+
+ """
+ video_frames = None
+ if self._video is not None:
+ video_start_pts = secs_to_pts(
+ start_sec,
+ self._video_time_base,
+ self._video_start_pts,
+ round_mode="ceil",
+ )
+ video_end_pts = secs_to_pts(
+ end_sec,
+ self._video_time_base,
+ self._video_start_pts,
+ round_mode="ceil",
+ )
+ video_frames = [
+ f
+ for f, pts in self._video
+ if pts >= video_start_pts and pts < video_end_pts
+ ]
+
+ audio_samples = None
+ if self._decode_audio and self._audio:
+ audio_start_pts = secs_to_pts(
+ start_sec,
+ self._audio_time_base,
+ self._audio_start_pts,
+ round_mode="ceil",
+ )
+ audio_end_pts = secs_to_pts(
+ end_sec,
+ self._audio_time_base,
+ self._audio_start_pts,
+ round_mode="ceil",
+ )
+ audio_samples = [
+ f
+ for f, pts in self._audio
+ if pts >= audio_start_pts and pts < audio_end_pts
+ ]
+ audio_samples = torch.cat(audio_samples, axis=0)
+ audio_samples = audio_samples.to(torch.float32)
+
+ if video_frames is None or len(video_frames) == 0:
+ logger.warning(
+ f"No video found within {start_sec} and {end_sec} seconds. "
+ f"Video starts at time 0 and ends at {self.duration}."
+ )
+
+ video_frames = None
+
+ if video_frames is not None:
+ video_frames = thwc_to_cthw(torch.stack(video_frames)).to(torch.float32)
+
+ return {
+ "video": video_frames,
+ "audio": audio_samples,
+ }
+
+ def _torch_vision_decode_video(
+ self, start_pts: int = 0, end_pts: int = -1
+ ) -> Tuple:
+ """
+ Decode the video in the PTS range [start_pts, end_pts]
+ """
+ video_and_pts = None
+ audio_and_pts = None
+
+ width, height, min_dimension, max_dimension = 0, 0, 0, 0
+ video_start_pts, video_end_pts = start_pts, end_pts
+ video_timebase_num, video_timebase_den = 0, 1
+
+ samples, channels = 0, 0
+ audio_start_pts, audio_end_pts = start_pts, end_pts
+ audio_timebase_num, audio_timebase_den = 0, 1
+
+ try:
+ tv_result = torch.ops.video_reader.read_video_from_memory(
+ self._video_tensor,
+ self.SEEK_FRAME_MARGIN,
+ # Set getPtsOnly=0, i.e., read full video rather than just header
+ 0,
+ # Read video stream
+ 1,
+ width,
+ height,
+ min_dimension,
+ max_dimension,
+ video_start_pts,
+ video_end_pts,
+ video_timebase_num,
+ video_timebase_den,
+ # Read audio stream
+ self._decode_audio,
+ samples,
+ channels,
+ audio_start_pts,
+ audio_end_pts,
+ audio_timebase_num,
+ audio_timebase_den,
+ )
+ except Exception as e:
+ logger.warning(f"Failed to decode video of name {self._video_name}. {e}")
+ raise e
+
+ (
+ vframes,
+ vframes_pts,
+ vtimebase,
+ _,
+ vduration,
+ aframes,
+ aframe_pts,
+ atimebase,
+ _,
+ aduration,
+ ) = tv_result
+
+ if vduration < 0:
+ # No header information to infer video duration
+ video_duration = int(vframes_pts[-1])
+ else:
+ video_duration = int(vduration)
+
+ video_and_pts = list(zip(vframes, vframes_pts))
+ video_start_pts = int(vframes_pts[0])
+ video_time_base = Fraction(int(vtimebase[0]), int(vtimebase[1]))
+
+ audio_and_pts = None
+ audio_time_base = None
+ audio_start_pts = None
+ audio_duration = None
+ if self._decode_audio:
+ if aduration < 0:
+ # No header information to infer audio duration
+ audio_duration = int(aframe_pts[-1])
+ else:
+ audio_duration = int(aduration)
+
+ audio_and_pts = list(zip(aframes, aframe_pts))
+ audio_start_pts = int(aframe_pts[0])
+ audio_time_base = Fraction(int(atimebase[0]), int(atimebase[1]))
+
+ return (
+ video_and_pts,
+ video_time_base,
+ video_start_pts,
+ video_duration,
+ audio_and_pts,
+ audio_time_base,
+ audio_start_pts,
+ audio_duration,
+ )
diff --git a/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__init__.py b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea20a04a5856581db7978504c48993ab5c6faa5
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .epic_kitchen_dataset import ActionData, EpicKitchenDataset
diff --git a/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__pycache__/__init__.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8016c6b89ad4061fe3142b06f0e19fec98d091d1
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__pycache__/__init__.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__pycache__/epic_kitchen_dataset.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__pycache__/epic_kitchen_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be4763158a4691744379c26c19ea07beb4d07940
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/__pycache__/epic_kitchen_dataset.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6077517b79adee65df6c0ca3d5d20d0f38ddb002
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py
@@ -0,0 +1,205 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import ast
+from dataclasses import dataclass, fields as dataclass_fields
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from pytorchvideo.data.dataset_manifest_utils import (
+ EncodedVideoInfo,
+ get_seconds_from_hms_time,
+ VideoClipInfo,
+ VideoDataset,
+ VideoDatasetType,
+ VideoFrameInfo,
+ VideoInfo,
+)
+from pytorchvideo.data.frame_video import FrameVideo
+from pytorchvideo.data.utils import DataclassFieldCaster, load_dataclass_dict_from_csv
+from pytorchvideo.data.video import Video
+
+
+@dataclass
+class ActionData(DataclassFieldCaster):
+ """
+ Class representing an action from the Epic Kitchen dataset.
+ """
+
+ participant_id: str
+ video_id: str
+ narration: str
+ start_timestamp: str
+ stop_timestamp: str
+ start_frame: int
+ stop_frame: int
+ verb: str
+ verb_class: int
+ noun: str
+ noun_class: int
+ all_nouns: list = DataclassFieldCaster.complex_initialized_dataclass_field(
+ ast.literal_eval
+ )
+ all_noun_classes: list = DataclassFieldCaster.complex_initialized_dataclass_field(
+ ast.literal_eval
+ )
+
+ @property
+ def start_time(self) -> float:
+ return get_seconds_from_hms_time(self.start_timestamp)
+
+ @property
+ def stop_time(self) -> float:
+ return get_seconds_from_hms_time(self.stop_timestamp)
+
+
+class EpicKitchenDataset(torch.utils.data.Dataset):
+ """
+ Video dataset for EpicKitchen-55 Dataset
+
+
+ This dataset handles the loading, decoding, and configurable clip
+ sampling for the videos.
+ """
+
+ def __init__(
+ self,
+ video_info_file_path: str,
+ actions_file_path: str,
+ clip_sampler: Callable[
+ [Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]
+ ],
+ video_data_manifest_file_path: str,
+ dataset_type: VideoDatasetType = VideoDatasetType.Frame,
+ transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
+ frame_filter: Optional[Callable[[List[int]], List[int]]] = None,
+ multithreaded_io: bool = True,
+ ) -> None:
+ f"""
+ Args:
+ video_info_file_path (str):
+ Path or URI to manifest with basic metadata of each video.
+ File must be a csv (w/header) with columns:
+ {[f.name for f in dataclass_fields(VideoInfo)]}
+
+ actions_file_path (str):
+ Path or URI to manifest with action annotations for each video.
+ File must be a csv (w/header) with columns:
+ {[f.name for f in dataclass_fields(ActionData)]}
+
+ clip_sampler (Callable[[Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]]):
+ This callable takes as input all available videos and their action annotations,
+ and outputs a list of clips to be loaded by the dataset.
+
+ video_data_manifest_file_path (str):
+ The path to a json file outlining the available video data for the
+ associated videos. File must be a csv (w/header) with columns:
+ {[f.name for f in dataclass_fields(VideoFrameInfo)]}
+
+ or
+ {[f.name for f in dataclass_fields(EncodedVideoInfo)]}
+
+ To generate this file from a directory of video frames, see helper
+ functions in Module: pytorchvideo.data.epic_kitchen.utils
+
+ dataset_type (VideoDatasetType): The data format in which the dataset's
+ video data is stored (e.g. video frames, encoded video, etc.).
+
+ transform (Optional[Callable[[Dict[str, Any]], Any]]):
+ This callable is evaluated on the clip output before the clip is returned.
+ It can be used for user-defined preprocessing and augmentations to the clips.
+
+ The clip input is a dictionary with the following format:
+ {{
+ 'video': ,
+ 'audio': ,
+ 'actions': ,
+ 'start_time': ,
+ 'stop_time':
+ }}
+
+ If transform is None, the raw clip output in the above format is
+ returned unmodified.
+
+ frame_filter (Optional[Callable[[List[int]], List[int]]]):
+ This callable is evaluated on the set of available frame indices to be
+ included in a sampled clip. This can be used to subselect frames within
+ a clip to be loaded.
+
+ multithreaded_io (bool):
+ Boolean to control whether parallelizable io operations are performed across
+ multiple threads.
+
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.EpicKitchenDataset.__init__")
+
+ assert video_info_file_path
+ assert actions_file_path
+ assert video_data_manifest_file_path
+ assert clip_sampler
+
+ # Populate video and metadata data providers
+ self._videos: Dict[str, Video] = VideoDataset._load_videos(
+ video_data_manifest_file_path,
+ video_info_file_path,
+ multithreaded_io,
+ dataset_type,
+ )
+
+ self._actions: Dict[str, List[ActionData]] = load_dataclass_dict_from_csv(
+ actions_file_path, ActionData, "video_id", list_per_key=True
+ )
+ # Sample datapoints
+ self._clips: List[VideoClipInfo] = clip_sampler(self._videos, self._actions)
+
+ self._transform = transform
+ self._frame_filter = frame_filter
+
+ def __getitem__(self, index) -> Dict[str, Any]:
+ """
+ Samples a video clip associated to the given index.
+
+ Args:
+ index (int): index for the video clip.
+
+ Returns:
+ A video clip with the following format if transform is None:
+ {{
+ 'video_id': ,
+ 'video': ,
+ 'audio': ,
+ 'actions': ,
+ 'start_time': ,
+ 'stop_time':
+ }}
+ Otherwise, the transform defines the clip output.
+ """
+ clip = self._clips[index]
+ video = self._videos[clip.video_id]
+
+ if isinstance(video, FrameVideo):
+ clip_dict = video.get_clip(
+ clip.start_time, clip.stop_time, self._frame_filter
+ )
+ else:
+ clip_dict = video.get_clip(clip.start_time, clip.stop_time)
+
+ clip_data = {
+ "video_id": clip.video_id,
+ **clip_dict,
+ "actions": self._actions[clip.video_id],
+ "start_time": clip.start_time,
+ "stop_time": clip.stop_time,
+ }
+
+ if self._transform:
+ clip_data = self._transform(clip_data)
+
+ return clip_data
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ The number of video clips in the dataset.
+ """
+ return len(self._clips)
diff --git a/code/pytorchvideo/pytorchvideo/data/epic_kitchen/utils.py b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dedff09199ab7ee3e47f9d434b18c9a0c3eaf9a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/epic_kitchen/utils.py
@@ -0,0 +1,195 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Dict
+
+from iopath.common.file_io import g_pathmgr
+from pytorchvideo.data.dataset_manifest_utils import EncodedVideoInfo, VideoFrameInfo
+from pytorchvideo.data.utils import optional_threaded_foreach
+
+
+def build_frame_manifest_from_flat_directory(
+ data_directory_path: str, multithreaded: bool
+) -> Dict[str, VideoFrameInfo]:
+ """
+ Args:
+ data_directory_path (str): Path or URI to EpicKitchenDataset data.
+ Data at this path must be a folder of structure:
+ {
+ "{video_id}": [
+ "frame_{frame_number}.{file_extension}",
+ "frame_{frame_number}.{file_extension}",
+ "frame_{frame_number}.{file_extension}",
+ ...]
+ ...}
+ multithreaded (bool):
+ controls whether io operations are performed across multiple threads.
+
+ Returns:
+ Dictionary mapping video_id of available videos to the locations of their
+ underlying frame files.
+ """
+
+ video_frames = {}
+ video_ids = g_pathmgr.ls(str(data_directory_path))
+
+ def add_video_frames(video_id: str, video_path: str) -> None:
+ video_frame_file_names = sorted(g_pathmgr.ls(video_path))
+ for frame in video_frame_file_names:
+ file_extension = frame.split(".")[-1]
+ frame_name = frame[: -(len(file_extension) + 1)]
+ stem, path_frame_id = frame_name.split("_")
+ if video_id not in video_frames:
+ video_frames[video_id] = VideoFrameInfo(
+ video_id=video_id,
+ location=video_path,
+ frame_file_stem=f"{stem}_",
+ frame_string_length=len(frame_name),
+ min_frame_number=int(path_frame_id),
+ max_frame_number=int(path_frame_id),
+ file_extension=file_extension,
+ )
+ else:
+ video_frame_info = video_frames[video_id]
+ # Check that this new frame is of the same format as other frames for this video
+ # and that it is the next frame in order, if so update the frame info for this
+ # video to reflect there is an additional frame.
+ # We don't need to check video_id or frame_file_stem as they are a function of
+ # video_id, which is aligned within the dictionary.
+ assert video_frame_info.frame_string_length == len(frame_name)
+ assert video_frame_info.location == video_path, (
+ f"Frames for {video_id} found in two paths: "
+ f"{video_frame_info.location} and {video_path}"
+ )
+ assert video_frame_info.max_frame_number + 1 == int(path_frame_id)
+ assert (
+ video_frame_info.file_extension == file_extension
+ ), f"Frames with two different file extensions found for video {video_id}"
+ video_frames[video_id] = VideoFrameInfo(
+ video_id=video_frame_info.video_id,
+ location=video_frame_info.location,
+ frame_file_stem=video_frame_info.frame_file_stem,
+ frame_string_length=video_frame_info.frame_string_length,
+ min_frame_number=video_frame_info.min_frame_number,
+ max_frame_number=int(path_frame_id), # Update
+ file_extension=video_frame_info.file_extension,
+ )
+
+ video_paths = [
+ (video_id, f"{data_directory_path}/{video_id}") for video_id in video_ids
+ ]
+ # Kick off frame indexing for all participants
+ optional_threaded_foreach(add_video_frames, video_paths, multithreaded)
+
+ return video_frames
+
+
+def build_frame_manifest_from_nested_directory(
+ data_directory_path: str, multithreaded: bool
+) -> Dict[str, VideoFrameInfo]:
+ """
+ Args:
+ data_directory_path (str): Path or URI to EpicKitchenDataset data.
+ If this dataset is to load from the frame-based dataset:
+ Data at this path must be a folder of structure:
+ {
+ "{participant_id}" : [
+ "{participant_id}_{participant_video_id}_{frame_number}.{file_extension}",
+
+ ...],
+ ...}
+
+ multithreaded (bool):
+ controls whether io operations are performed across multiple threads.
+
+ Returns:
+ Dictionary mapping video_id of available videos to the locations of their
+ underlying frame files.
+ """
+
+ participant_ids = g_pathmgr.ls(str(data_directory_path))
+ video_frames = {}
+
+ # Create function to execute in parallel that lists files available for each participant
+ def add_participant_video_frames(
+ participant_id: str, participant_path: str
+ ) -> None:
+ participant_frames = sorted(g_pathmgr.ls(str(participant_path)))
+ for frame_file_name in participant_frames:
+ file_extension = frame_file_name.split(".")[-1]
+ frame_name = frame_file_name[: -(len(file_extension) + 1)]
+ [path_participant_id, path_video_id, path_frame_id] = frame_name.split("_")
+ assert path_participant_id == participant_id
+ video_id = f"{path_participant_id}_{path_video_id}"
+ if (
+ video_id not in video_frames
+ ): # This is the first frame we have seen from video w/ video_id
+ video_frames[video_id] = VideoFrameInfo(
+ video_id=video_id,
+ location=participant_path,
+ frame_file_stem=f"{video_id}_",
+ frame_string_length=len(frame_name),
+ min_frame_number=int(path_frame_id),
+ max_frame_number=int(path_frame_id),
+ file_extension=file_extension,
+ )
+ else:
+ video_frame_info = video_frames[video_id]
+ # Check that this new frame has the same format as the other frames for this
+ # video and that it is the next frame in order; if so, update the frame info
+ # for this video to reflect the additional frame.
+ # We don't need to check video_id or frame_file_stem as they are a function of
+ # video_id, which is already aligned via the dictionary key.
+ assert video_frame_info.frame_string_length == len(frame_name)
+ assert video_frame_info.location == participant_path, (
+ f"Frames for {video_id} found in two paths: "
+ f"{video_frame_info.location} and {participant_path}"
+ )
+ assert video_frame_info.max_frame_number + 1 == int(path_frame_id)
+ assert (
+ video_frame_info.file_extension == file_extension
+ ), f"Frames with two different file extensions found for video {video_id}"
+ video_frames[video_id] = VideoFrameInfo(
+ video_id=video_frame_info.video_id,
+ location=video_frame_info.location,
+ frame_file_stem=video_frame_info.frame_file_stem,
+ frame_string_length=video_frame_info.frame_string_length,
+ min_frame_number=video_frame_info.min_frame_number,
+ max_frame_number=int(path_frame_id), # Update
+ file_extension=video_frame_info.file_extension,
+ )
+
+ participant_paths = [
+ (participant_id, f"{data_directory_path}/{participant_id}")
+ for participant_id in participant_ids
+ ]
+ # Kick off frame indexing for all participants
+ optional_threaded_foreach(
+ add_participant_video_frames, participant_paths, multithreaded
+ )
+
+ return video_frames
+
+
+def build_encoded_manifest_from_nested_directory(
+ data_directory_path: str,
+) -> Dict[str, EncodedVideoInfo]:
+ """
+ Creates a dictionary from video_id to EncodedVideoInfo for
+ encoded videos in the given directory.
+
+ Args:
+ data_directory_path (str): The folder to ls to find encoded
+ video files.
+
+ Returns:
+ Dict[str, EncodedVideoInfo] mapping video_id to EncodedVideoInfo
+ for each file in 'data_directory_path'
+ """
+ encoded_video_infos = {}
+ for participant_id in g_pathmgr.ls(data_directory_path):
+ participant_folder_path = f"{data_directory_path}/{participant_id}"
+ for video_file_name in g_pathmgr.ls(participant_folder_path):
+ video_id = video_file_name[:6]
+ video_full_path = f"{participant_folder_path}/{video_file_name}"
+ encoded_video_infos[video_id] = EncodedVideoInfo(video_id, video_full_path)
+ return encoded_video_infos
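Taken together, these helpers index frame files (or encoded videos) into per-video manifests. A minimal usage sketch follows, assuming a local EPIC-Kitchens-style directory; the path and layout in the comments are placeholders, not part of this diff:

```python
# Hedged usage sketch for the manifest builders above; the local path is an
# assumption for illustration only.
from pytorchvideo.data.epic_kitchen.utils import (
    build_frame_manifest_from_nested_directory,
)

video_frames = build_frame_manifest_from_nested_directory(
    "/data/epic55/frames",  # assumed layout: <participant_id>/<participant>_<video>_<frame>.jpg
    multithreaded=True,
)
for video_id, info in video_frames.items():
    # Each VideoFrameInfo records where the frames live and the contiguous frame range.
    print(video_id, info.location, info.min_frame_number, info.max_frame_number)
```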
diff --git a/code/pytorchvideo/pytorchvideo/data/epic_kitchen_forecasting.py b/code/pytorchvideo/pytorchvideo/data/epic_kitchen_forecasting.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a6ad5e6d8172aac8977abd5b338b68d83102f20
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/epic_kitchen_forecasting.py
@@ -0,0 +1,295 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from dataclasses import fields as dataclass_fields
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from pytorchvideo.data.dataset_manifest_utils import (
+ EncodedVideoInfo,
+ VideoClipInfo,
+ VideoDatasetType,
+ VideoFrameInfo,
+ VideoInfo,
+)
+from pytorchvideo.data.epic_kitchen import ActionData, EpicKitchenDataset
+from pytorchvideo.data.video import Video
+
+
+class ClipSampling(Enum):
+ Random = 1
+
+
+class EpicKitchenForecasting(EpicKitchenDataset):
+ """
+ Action forecasting video dataset for the EpicKitchens-55 dataset.
+
+ This dataset handles the loading, decoding, and clip sampling for the videos.
+ """
+
+ def __init__(
+ self,
+ video_info_file_path: str,
+ actions_file_path: str,
+ video_data_manifest_file_path: str,
+ clip_sampling: ClipSampling = ClipSampling.Random,
+ dataset_type: VideoDatasetType = VideoDatasetType.Frame,
+ seconds_per_clip: float = 2.0,
+ clip_time_stride: float = 10.0,
+ num_input_clips: int = 1,
+ frames_per_clip: Optional[int] = None,
+ num_forecast_actions: int = 1,
+ transform: Callable[[Dict[str, Any]], Any] = None,
+ multithreaded_io: bool = True,
+ ):
+ f"""
+ Args:
+ video_info_file_path (str):
+ Path or URI to manifest with basic metadata of each video.
+ File must be a csv (w/header) with columns:
+ {[f.name for f in dataclass_fields(VideoInfo)]}
+
+ actions_file_path (str):
+ Path or URI to manifest with action annotations for each video.
+ File must be a csv (w/header) with columns:
+ {[f.name for f in dataclass_fields(ActionData)]}
+
+ video_data_manifest_file_path (str):
+ The path to a manifest file outlining the available video data for the
+ associated videos. File must be a csv (w/header) with columns either:
+
+ For Frame Videos:
+ {[f.name for f in dataclass_fields(VideoFrameInfo)]}
+
+ For Encoded Videos:
+ {[f.name for f in dataclass_fields(EncodedVideoInfo)]}
+
+ To generate this file from a directory of video frames, see helper
+ functions in Module: pytorchvideo.data.epic_kitchen.utils
+
+ clip_sampling (ClipSampling):
+ The type of sampling to perform on the videos of the dataset.
+
+ dataset_type (VideoDatasetType): The data format in which the dataset's
+ video data is stored (e.g. video frames, encoded video, etc.).
+
+ seconds_per_clip (float): The length of each sampled subclip in seconds.
+
+ clip_time_stride (float): The time difference in seconds between the start of
+ each input subclip.
+
+ num_input_clips (int): The number of subclips to be included in the input
+ video data.
+
+ frames_per_clip (Optional[int]): The number of frames per clip to sample.
+ If None, all frames in the clip will be included.
+
+ num_forecast_actions (int): The number of actions to be included in the
+ action vector.
+
+ transform (Callable[[Dict[str, Any]], Any]):
+ This callable is evaluated on the clip output before the clip is returned.
+ It can be used for user-defined preprocessing and augmentations to the clips.
+ The clip input is a dictionary with the following format:
+ {{
+ 'video_id': ,
+ 'video': ,
+ 'audio': ,
+ 'label': ,
+ 'start_time': ,
+ 'stop_time':
+ }}
+
+ If transform is None, the raw clip output in the above format is
+ returned unmodified.
+
+ multithreaded_io (bool):
+ Boolean to control whether parallelizable io operations are performed across
+ multiple threads.
+ """
+ define_clip_structure_fn = (
+ EpicKitchenForecasting._define_clip_structure_generator(
+ clip_sampling,
+ seconds_per_clip,
+ clip_time_stride,
+ num_input_clips,
+ num_forecast_actions,
+ )
+ )
+ frame_filter = (
+ EpicKitchenForecasting._frame_filter_generator(
+ frames_per_clip, seconds_per_clip, clip_time_stride, num_input_clips
+ )
+ if frames_per_clip is not None
+ else None
+ )
+ transform = EpicKitchenForecasting._transform_generator(
+ transform, num_forecast_actions, frames_per_clip, num_input_clips
+ )
+
+ super().__init__(
+ video_info_file_path=video_info_file_path,
+ actions_file_path=actions_file_path,
+ video_data_manifest_file_path=video_data_manifest_file_path,
+ dataset_type=dataset_type,
+ transform=transform,
+ frame_filter=frame_filter,
+ clip_sampler=define_clip_structure_fn,
+ multithreaded_io=multithreaded_io,
+ )
+
+ @staticmethod
+ def _transform_generator(
+ transform: Callable[[Dict[str, Any]], Dict[str, Any]],
+ num_forecast_actions: int,
+ frames_per_clip: int,
+ num_input_clips: int,
+ ) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
+ """
+ Args:
+ transform (Callable[[Dict[str, Any]], Dict[str, Any]]): A function that performs
+ any operation on a clip before it is returned in the default transform function.
+ num_forecast_actions: (int) The number of actions to be included in the
+ action vector.
+ frames_per_clip (int): The number of frames per clip to sample.
+ num_input_clips (int): The number of subclips to be included in the video data.
+
+ Returns:
+ A function that performs any operation on a clip and returns the transformed clip.
+ """
+
+ def transform_clip(clip: Dict[str, Any]) -> Dict[str, Any]:
+ assert all(
+ clip["actions"][i].start_time <= clip["actions"][i + 1].start_time
+ for i in range(len(clip["actions"]) - 1)
+ ), "Actions must be sorted"
+ next_k_actions: List[ActionData] = [
+ a for a in clip["actions"] if (a.start_time > clip["stop_time"])
+ ][:num_forecast_actions]
+ clip["actions"] = next_k_actions
+
+ assert clip["video"].size()[1] == num_input_clips * frames_per_clip
+ clip_video_tensor = torch.stack(
+ [
+ clip["video"][
+ :, (i * frames_per_clip) : ((i + 1) * frames_per_clip), :, :
+ ]
+ for i in range(num_input_clips)
+ ]
+ )
+ clip["video"] = clip_video_tensor
+
+ for key in clip:
+ if clip[key] is None:
+ clip[key] = torch.tensor([])
+
+ if transform:
+ clip = transform(clip)
+
+ return clip
+
+ return transform_clip
+
+ @staticmethod
+ def _frame_filter_generator(
+ frames_per_clip: int,
+ seconds_per_clip: float,
+ clip_time_stride: float,
+ num_input_clips: int,
+ ) -> Callable[[List[int]], List[int]]:
+ """
+ Args:
+ frames_per_clip (int): The number of frames per clip to sample.
+ seconds_per_clip (float): The length of each sampled subclip in seconds.
+ clip_time_stride (float): The time difference in seconds between the start of
+ each input subclip.
+ num_input_clips (int): The number of subclips to be included in the video data.
+
+ Returns:
+ A function that takes in a list of frame indices and outputs a subsampled list.
+ """
+ time_window_length = seconds_per_clip + (num_input_clips - 1) * clip_time_stride
+ desired_frames_per_second = frames_per_clip / seconds_per_clip
+
+ def frame_filter(frame_indices: List[int]) -> List[int]:
+ num_available_frames_for_all_clips = len(frame_indices)
+ available_frames_per_second = (
+ num_available_frames_for_all_clips / time_window_length
+ )
+ intra_clip_sampling_stride = int(
+ available_frames_per_second // desired_frames_per_second
+ )
+ selected_frames = set()
+ for i in range(num_input_clips):
+ clip_start_index = int(
+ i * clip_time_stride * available_frames_per_second
+ )
+ for j in range(frames_per_clip):
+ selected_frames.add(
+ clip_start_index + j * intra_clip_sampling_stride
+ )
+ return [x for i, x in enumerate(frame_indices) if i in selected_frames]
+
+ return frame_filter
+
+ @staticmethod
+ def _define_clip_structure_generator(
+ clip_sampling: str,
+ seconds_per_clip: float,
+ clip_time_stride: float,
+ num_input_clips: int,
+ num_forecast_actions: int,
+ ) -> Callable[[Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]]:
+ """
+ Args:
+ clip_sampling (ClipSampling):
+ The type of sampling to perform on the videos of the dataset.
+ seconds_per_clip (float): The length of each sampled clip in seconds.
+ clip_time_stride: The time difference in seconds between the start of
+ each input subclip.
+ num_input_clips (int): The number of subclips to be included in the video data.
+ num_forecast_actions (int): The number of actions to be included in the
+ action vector.
+
+ Returns:
+ A function that takes a dictionary of videos and outputs a list of sampled
+ clips.
+ """
+ # TODO(T77683480)
+ if not clip_sampling == ClipSampling.Random:
+ raise NotImplementedError(
+ f"Only {ClipSampling.Random} is implemented. "
+ f"{clip_sampling} not implemented."
+ )
+
+ time_window_length = seconds_per_clip + (num_input_clips - 1) * clip_time_stride
+
+ def define_clip_structure(
+ videos: Dict[str, Video], video_actions: Dict[str, List[ActionData]]
+ ) -> List[VideoClipInfo]:
+ candidate_sample_clips = []
+ for video_id, actions in video_actions.items():
+ for i, action in enumerate(actions[: (-1 * num_forecast_actions)]):
+ # Only actions with num_forecast_actions after to predict
+ # Confirm there are >= num_forecast_actions available
+ # (it is possible for actions to overlap)
+ number_valid_actions = 0
+ for j in range(i + 1, len(actions)):
+ if actions[j].start_time > action.stop_time:
+ number_valid_actions += 1
+ if number_valid_actions == num_forecast_actions:
+ if (
+ action.start_time - time_window_length >= 0
+ ): # Only add clips that have the full input video available
+ candidate_sample_clips.append(
+ VideoClipInfo(
+ video_id,
+ action.stop_time - time_window_length,
+ action.stop_time,
+ )
+ )
+ break
+ return candidate_sample_clips
+
+ return define_clip_structure
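A construction sketch for this class follows; the manifest paths are placeholders, and the shape noted in the comment is derived from the slicing and stacking performed in `_transform_generator` above:

```python
# Construction sketch for EpicKitchenForecasting; all file paths are placeholders
# and the manifests are assumed to have been generated with the utils above.
from pytorchvideo.data.epic_kitchen_forecasting import (
    ClipSampling,
    EpicKitchenForecasting,
)

dataset = EpicKitchenForecasting(
    video_info_file_path="manifests/video_info.csv",       # assumed path
    actions_file_path="manifests/actions.csv",             # assumed path
    video_data_manifest_file_path="manifests/frames.csv",  # assumed path
    clip_sampling=ClipSampling.Random,
    seconds_per_clip=2.0,
    clip_time_stride=10.0,
    num_input_clips=2,
    frames_per_clip=8,
    num_forecast_actions=3,
)
# Each clip dict carries a (num_input_clips, C, frames_per_clip, H, W) "video"
# tensor and the next `num_forecast_actions` actions under "actions".
```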
diff --git a/code/pytorchvideo/pytorchvideo/data/epic_kitchen_recognition.py b/code/pytorchvideo/pytorchvideo/data/epic_kitchen_recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a6f688e1bcef6ededf07d37185038c2e0a71781
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/epic_kitchen_recognition.py
@@ -0,0 +1,212 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import random
+from dataclasses import fields as dataclass_fields
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from pytorchvideo.data.dataset_manifest_utils import (
+ EncodedVideoInfo,
+ VideoClipInfo,
+ VideoDatasetType,
+ VideoFrameInfo,
+ VideoInfo,
+)
+from pytorchvideo.data.epic_kitchen import ActionData, EpicKitchenDataset
+from pytorchvideo.data.video import Video
+
+
+class ClipSampling(Enum):
+ RandomOffsetUniform = 1
+
+
+class EpicKitchenRecognition(EpicKitchenDataset):
+ """
+ Action recognition video dataset for the EpicKitchens-55 dataset.
+
+ This dataset handles the loading, decoding, and clip sampling for the videos.
+ """
+
+ def __init__(
+ self,
+ video_info_file_path: str,
+ actions_file_path: str,
+ video_data_manifest_file_path: str,
+ clip_sampling: ClipSampling = ClipSampling.RandomOffsetUniform,
+ dataset_type: VideoDatasetType = VideoDatasetType.Frame,
+ seconds_per_clip: float = 2.0,
+ frames_per_clip: Optional[int] = None,
+ transform: Callable[[Dict[str, Any]], Any] = None,
+ multithreaded_io: bool = True,
+ ):
+ f"""
+ Args:
+ video_info_file_path (str):
+ Path or URI to manifest with basic metadata of each video.
+ File must be a csv (w/header) with columns:
+ {[f.name for f in dataclass_fields(VideoInfo)]}
+
+ actions_file_path (str):
+ Path or URI to manifest with action annotations for each video.
+ File must be a csv (w/header) with columns:
+ {[f.name for f in dataclass_fields(ActionData)]}
+
+ video_data_manifest_file_path (str):
+ The path to a manifest file outlining the available video data for the
+ associated videos. File must be a csv (w/header) with columns either:
+
+ For Frame Videos:
+ {[f.name for f in dataclass_fields(VideoFrameInfo)]}
+
+ For Encoded Videos:
+ {[f.name for f in dataclass_fields(EncodedVideoInfo)]}
+
+ To generate this file from a directory of video frames, see helper
+ functions in Module: pytorchvideo.data.epic_kitchen.utils
+
+ clip_sampling (ClipSampling):
+ The type of sampling to perform on the videos of the dataset.
+
+ dataset_type (VideoDatasetType): The data format in which the dataset's
+ video data is stored (e.g. video frames, encoded video, etc.).
+
+ seconds_per_clip (float): The length of each sampled clip in seconds.
+
+ frames_per_clip (Optional[int]): The number of frames per clip to sample.
+
+ transform (Callable[[Dict[str, Any]], Any]):
+ This callable is evaluated on the clip output before the clip is returned.
+ It can be used for user-defined preprocessing and augmentations to the clips.
+ The clip input is a dictionary with the following format:
+ {{
+ 'video_id': ,
+ 'video': ,
+ 'audio': ,
+ 'label': ,
+ 'start_time': ,
+ 'stop_time':
+ }}
+
+ If transform is None, the raw clip output in the above format is
+ returned unmodified.
+
+ multithreaded_io (bool):
+ Boolean to control whether parallelizable io operations are performed across
+ multiple threads.
+ """
+ define_clip_structure_fn = (
+ EpicKitchenRecognition._define_clip_structure_generator(
+ seconds_per_clip, clip_sampling
+ )
+ )
+ transform = EpicKitchenRecognition._transform_generator(transform)
+ frame_filter = (
+ EpicKitchenRecognition._frame_filter_generator(frames_per_clip)
+ if frames_per_clip is not None
+ else None
+ )
+
+ super().__init__(
+ video_info_file_path=video_info_file_path,
+ actions_file_path=actions_file_path,
+ dataset_type=dataset_type,
+ video_data_manifest_file_path=video_data_manifest_file_path,
+ transform=transform,
+ frame_filter=frame_filter,
+ clip_sampler=define_clip_structure_fn,
+ multithreaded_io=multithreaded_io,
+ )
+
+ @staticmethod
+ def _transform_generator(
+ transform: Callable[[Dict[str, Any]], Dict[str, Any]]
+ ) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
+ """
+ Args:
+ transform (Callable[[Dict[str, Any]], Dict[str, Any]]): A function that performs
+ any operation on a clip before it is returned in the default transform function.
+
+ Returns:
+ A function that performs any operation on a clip and returns the transformed clip.
+ """
+
+ def transform_clip(clip: Dict[str, Any]) -> Dict[str, Any]:
+ actions_in_clip: List[ActionData] = [
+ a
+ for a in clip["actions"]
+ if (
+ a.start_time <= clip["stop_time"]
+ and a.stop_time >= clip["start_time"]
+ )
+ ]
+ clip["actions"] = actions_in_clip
+
+ for key in clip:
+ if clip[key] is None:
+ clip[key] = torch.tensor([])
+
+ if transform:
+ clip = transform(clip)
+
+ return clip
+
+ return transform_clip
+
+ @staticmethod
+ def _frame_filter_generator(
+ frames_per_clip: int,
+ ) -> Callable[[List[int]], List[int]]:
+ """
+ Args:
+ frames_per_clip (int): The number of frames per clip to sample.
+
+ Returns:
+ A function that takes in a list of frame indices and outputs a subsampled list.
+ """
+
+ def frame_filter(frame_indices: List[int]) -> List[int]:
+ num_frames = len(frame_indices)
+ frame_step = int(num_frames // frames_per_clip)
+ selected_frames = set(range(0, num_frames, frame_step))
+ return [x for i, x in enumerate(frame_indices) if i in selected_frames]
+
+ return frame_filter
+
+ @staticmethod
+ def _define_clip_structure_generator(
+ seconds_per_clip: float, clip_sampling: ClipSampling
+ ) -> Callable[[Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]]:
+ """
+ Args:
+ seconds_per_clip (float): The length of each sampled clip in seconds.
+ clip_sampling (ClipSampling):
+ The type of sampling to perform on the videos of the dataset.
+
+ Returns:
+ A function that takes a dictionary of videos and a dictionary of the actions
+ for each video and outputs a list of sampled clips.
+ """
+ if not clip_sampling == ClipSampling.RandomOffsetUniform:
+ raise NotImplementedError(
+ f"Only {ClipSampling.RandomOffsetUniform} is implemented. "
+ f"{clip_sampling} not implemented."
+ )
+
+ def define_clip_structure(
+ videos: Dict[str, Video], actions: Dict[str, List[ActionData]]
+ ) -> List[VideoClipInfo]:
+ clips = []
+ for video_id, video in videos.items():
+ offset = random.random() * seconds_per_clip
+ num_clips = int((video.duration - offset) // seconds_per_clip)
+
+ for i in range(num_clips):
+ start_time = i * seconds_per_clip + offset
+ stop_time = start_time + seconds_per_clip
+ clip = VideoClipInfo(video_id, start_time, stop_time)
+ clips.append(clip)
+ return clips
+
+ return define_clip_structure
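The recognition frame filter above simply subsamples frame indices uniformly. The standalone sketch below mirrors that logic, with an added `max(1, ...)` guard (an addition for illustration, not present in the code above) so the step never becomes zero when a clip has fewer frames than `frames_per_clip`:

```python
# Standalone sketch of the uniform frame subsampling performed by
# _frame_filter_generator; the max(1, ...) guard is an illustrative addition.
from typing import List


def uniform_frame_filter(frame_indices: List[int], frames_per_clip: int) -> List[int]:
    # Step through the available indices so roughly frames_per_clip remain.
    frame_step = max(1, len(frame_indices) // frames_per_clip)
    selected = set(range(0, len(frame_indices), frame_step))
    return [x for i, x in enumerate(frame_indices) if i in selected]


print(uniform_frame_filter(list(range(100, 160)), 4))  # [100, 115, 130, 145]
```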
diff --git a/code/pytorchvideo/pytorchvideo/data/frame_video.py b/code/pytorchvideo/pytorchvideo/data/frame_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3aacf2f293e5cc138027979aa2d31340c5d6fe1
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/frame_video.py
@@ -0,0 +1,258 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from __future__ import annotations
+
+import logging
+import math
+import os
+import re
+import time
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.utils.data
+from iopath.common.file_io import g_pathmgr
+from pytorchvideo.data.utils import optional_threaded_foreach
+
+from .utils import thwc_to_cthw
+from .video import Video
+
+
+try:
+ import cv2
+except ImportError:
+ _HAS_CV2 = False
+else:
+ _HAS_CV2 = True
+
+
+logger = logging.getLogger(__name__)
+
+
+class FrameVideo(Video):
+ """
+ FrameVideo is an abstraction for accessing clips based on their start and end
+ time for a video where each frame is stored as an image. PathManager is used for
+ frame image reading, allowing non-local URIs to be used.
+ """
+
+ def __init__(
+ self,
+ duration: float,
+ fps: float,
+ video_frame_to_path_fn: Callable[[int], str] = None,
+ video_frame_paths: List[str] = None,
+ multithreaded_io: bool = False,
+ ) -> None:
+ """
+ Args:
+ duration (float): the duration of the video in seconds.
+ fps (float): the target fps for the video. This is needed to link the frames
+ to a second timestamp in the video.
+ video_frame_to_path_fn (Callable[[int], str]): a function that maps from a frame
+ index integer to the file path where the frame is located.
+ video_frame_paths (List[str]): a list of frame paths, ordered by frame index.
+ multithreaded_io (bool): controls whether parallelizable io operations are
+ performed across multiple threads.
+ """
+ if not _HAS_CV2:
+ raise ImportError(
+ "opencv2 is required to use FrameVideo. Please "
+ "install with 'pip install opencv-python'"
+ )
+
+ self._duration = duration
+ self._fps = fps
+ self._multithreaded_io = multithreaded_io
+
+ assert (video_frame_to_path_fn is None) != (
+ video_frame_paths is None
+ ), "Only one of video_frame_to_path_fn or video_frame_paths can be provided"
+ self._video_frame_to_path_fn = video_frame_to_path_fn
+ self._video_frame_paths = video_frame_paths
+
+ # Set the pathname to the parent directory of the first frame.
+ self._name = os.path.basename(
+ os.path.dirname(self._video_frame_to_path(frame_index=0))
+ )
+
+ @classmethod
+ def from_directory(
+ cls,
+ path: str,
+ fps: float = 30.0,
+ multithreaded_io=False,
+ path_order_cache: Optional[Dict[str, List[str]]] = None,
+ ):
+ """
+ Args:
+ path (str): path to frame video directory.
+ fps (float): the target fps for the video. This is needed to link the frames
+ to a second timestamp in the video.
+ multithreaded_io (bool): controls whether parallelizable io operations are
+ performed across multiple threads.
+ path_order_cache (dict): An optional mapping from directory-path to list
+ of frames in the directory in numerical order. Used for speedup by
+ caching the frame paths.
+ """
+ if path_order_cache is not None and path in path_order_cache:
+ return cls.from_frame_paths(path_order_cache[path], fps, multithreaded_io)
+
+ assert g_pathmgr.isdir(path), f"{path} is not a directory"
+ rel_frame_paths = g_pathmgr.ls(path)
+
+ def natural_keys(text):
+ return [int(c) if c.isdigit() else c for c in re.split(r"(\d+)", text)]
+
+ rel_frame_paths.sort(key=natural_keys)
+ frame_paths = [os.path.join(path, f) for f in rel_frame_paths]
+ if path_order_cache is not None:
+ path_order_cache[path] = frame_paths
+ return cls.from_frame_paths(frame_paths, fps, multithreaded_io)
+
+ @classmethod
+ def from_frame_paths(
+ cls,
+ video_frame_paths: List[str],
+ fps: float = 30.0,
+ multithreaded_io: bool = False,
+ ):
+ """
+ Args:
+ video_frame_paths (List[str]): a list of paths to each frame in the video.
+ fps (float): the target fps for the video. This is needed to link the frames
+ to a second timestamp in the video.
+ multithreaded_io (bool): controls whether parallelizable io operations are
+ performed across multiple threads.
+ """
+ assert len(video_frame_paths) != 0, "video_frame_paths is empty"
+ return cls(
+ len(video_frame_paths) / fps,
+ fps,
+ video_frame_paths=video_frame_paths,
+ multithreaded_io=multithreaded_io,
+ )
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def duration(self) -> float:
+ """
+ Returns:
+ duration: the video's duration/end-time in seconds.
+ """
+ return self._duration
+
+ def _get_frame_index_for_time(self, time_sec: float) -> int:
+ return math.ceil(self._fps * time_sec)
+
+ def get_clip(
+ self,
+ start_sec: float,
+ end_sec: float,
+ frame_filter: Optional[Callable[[List[int]], List[int]]] = None,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Retrieves frames from the stored video at the specified start and end times
+ in seconds (the video always starts at 0 seconds). Returned frames will be
+ in [start_sec, end_sec). Since PathManager may be fetching the frames from
+ network storage, frame reading is retried N times to handle transient errors.
+ Note that because end_sec is exclusive, you may need to use
+ `get_clip(start_sec, duration + EPS)` to get the last frame.
+
+ Args:
+ start_sec (float): the clip start time in seconds
+ end_sec (float): the clip end time in seconds
+ frame_filter (Optional[Callable[List[int], List[int]]]):
+ function to subsample frames in a clip before loading.
+ If None, no subsampling is peformed.
+ Returns:
+ clip_data: A dictionary with the following entries.
+
+ "video": A tensor of the clip's RGB frames with shape:
+ (channel, time, height, width). The frames are of type torch.float32 and
+ in the range [0 - 255]. Raises an exception if unable to load images.
+
+ "frame_indices": A list of indices for each frame relative to all frames in the
+ video.
+
+ "audio": None, as FrameVideo does not decode audio.
+
+ Returns None if no frames are found.
+ """
+ if start_sec < 0 or start_sec > self._duration:
+ logger.warning(
+ f"No frames found within {start_sec} and {end_sec} seconds. Video starts"
+ f"at time 0 and ends at {self._duration}."
+ )
+ return None
+
+ end_sec = min(end_sec, self._duration)
+
+ start_frame_index = self._get_frame_index_for_time(start_sec)
+ end_frame_index = min(
+ self._get_frame_index_for_time(end_sec), len(self._video_frame_paths)
+ )
+ frame_indices = list(range(start_frame_index, end_frame_index))
+ # Frame filter function to allow for subsampling before loading
+ if frame_filter:
+ frame_indices = frame_filter(frame_indices)
+
+ clip_paths = [self._video_frame_to_path(i) for i in frame_indices]
+ clip_frames = _load_images_with_retries(
+ clip_paths, multithreaded=self._multithreaded_io
+ )
+ clip_frames = thwc_to_cthw(clip_frames).to(torch.float32)
+ return {"video": clip_frames, "frame_indices": frame_indices, "audio": None}
+
+ def _video_frame_to_path(self, frame_index: int) -> str:
+ if self._video_frame_to_path_fn:
+ return self._video_frame_to_path_fn(frame_index)
+ elif self._video_frame_paths:
+ return self._video_frame_paths[frame_index]
+ else:
+ raise Exception(
+ "One of _video_frame_to_path_fn or _video_frame_paths must be set"
+ )
+
+
+def _load_images_with_retries(
+ image_paths: List[str], num_retries: int = 10, multithreaded: bool = True
+) -> torch.Tensor:
+ """
+ Loads the given image paths using PathManager, decodes them as RGB images and
+ returns them as a stacked tensor.
+ Args:
+ image_paths (List[str]): a list of paths to images.
+ num_retries (int): number of times to retry image reading to handle transient errors.
+ multithreaded (bool): whether images are fetched via multiple threads in parallel.
+ Returns:
+ A tensor of the clip's RGB frames with shape:
+ (time, height, width, channel). The frames are of type torch.uint8 and
+ in the range [0 - 255]. Raises an exception if unable to load images.
+ """
+ imgs = [None for i in image_paths]
+
+ def fetch_image(image_index: int, image_path: str) -> None:
+ for i in range(num_retries):
+ with g_pathmgr.open(image_path, "rb") as f:
+ img_str = np.frombuffer(f.read(), np.uint8)
+ img_bgr = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+ if img_rgb is not None:
+ imgs[image_index] = img_rgb
+ return
+ else:
+ logging.warning(f"Reading attempt {i}/{num_retries} failed.")
+ time.sleep(1e-6)
+
+ optional_threaded_foreach(fetch_image, enumerate(image_paths), multithreaded)
+
+ if any((img is None for img in imgs)):
+ raise Exception("Failed to load images from {}".format(image_paths))
+
+ return torch.as_tensor(np.stack(imgs))
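A short usage sketch for ``FrameVideo``; the frame directory path is a placeholder:

```python
# Usage sketch for FrameVideo; the frame directory path and fps are assumptions
# for illustration.
from pytorchvideo.data.frame_video import FrameVideo

video = FrameVideo.from_directory("/data/frames/P01_01", fps=30.0)  # assumed path
clip = video.get_clip(start_sec=0.0, end_sec=2.0)
frames = clip["video"]              # (C, T, H, W) float32 tensor in [0, 255]
frame_indices = clip["frame_indices"]
```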
diff --git a/code/pytorchvideo/pytorchvideo/data/hmdb51.py b/code/pytorchvideo/pytorchvideo/data/hmdb51.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb87eb3b6d95986138ac298437923c411f1360dd
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/hmdb51.py
@@ -0,0 +1,231 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from __future__ import annotations
+
+import logging
+import os
+import pathlib
+from typing import Any, Callable, List, Optional, Tuple, Type, Union
+
+import torch
+import torch.utils.data
+from iopath.common.file_io import g_pathmgr
+
+from .clip_sampling import ClipSampler
+from .labeled_video_dataset import LabeledVideoDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class Hmdb51LabeledVideoPaths:
+ """
+ Pre-processor for the HMDB51 dataset described here -
+ https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/
+
+ This dataset consists of classwise folds with each class consisting of 3
+ folds (splits).
+
+ The videos directory is of the format,
+ video_dir_path/class_x/<video_name>.avi
+ ...
+ video_dir_path/class_y/<video_name>.avi
+
+ The splits/fold directory is of the format,
+ folds_dir_path/class_x_test_split_1.txt
+ folds_dir_path/class_x_test_split_2.txt
+ folds_dir_path/class_x_test_split_3.txt
+ ...
+ folds_dir_path/class_y_test_split_1.txt
+ folds_dir_path/class_y_test_split_2.txt
+ folds_dir_path/class_y_test_split_3.txt
+
+ Each text file in the splits directory, class_x_test_split_<1 or 2 or 3>.txt,
+ lists videos as lines of the form
+ <video_name> <0 or 1 or 2>
+ where 0, 1, 2 correspond to the unused, train and test splits respectively.
+
+ Each video has a name of the format
+ <some_name>_<tag1>_<tag2>_..._<tag_n>.avi
+ For more details on tags -
+ https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/
+ """
+
+ _allowed_splits = [1, 2, 3]
+ _split_type_dict = {"train": 1, "test": 2, "unused": 0}
+
+ @classmethod
+ def from_dir(
+ cls, data_path: str, split_id: int = 1, split_type: str = "train"
+ ) -> Hmdb51LabeledVideoPaths:
+ """
+ Factory function that creates a Hmdb51LabeledVideoPaths object from a splits/folds
+ directory.
+
+ Args:
+ data_path (str): The path to the splits/folds directory of HMDB51.
+ split_id (int): Fold id to be loaded. Belongs to [1,2,3]
+ split_type (str): Split/Fold type to be loaded. It belongs to one of the
+ following,
+ - "train"
+ - "test"
+ - "unused" (This is a small set of videos that are neither
+ of part of test or train fold.)
+ """
+ data_path = pathlib.Path(data_path)
+ if not data_path.is_dir():
+ return RuntimeError(f"{data_path} not found or is not a directory.")
+ if not int(split_id) in cls._allowed_splits:
+ return RuntimeError(
+ f"{split_id} not found in allowed split id's {cls._allowed_splits}."
+ )
+ file_name_format = "_test_split" + str(int(split_id))
+ file_paths = sorted(
+ (
+ f
+ for f in data_path.iterdir()
+ if f.is_file() and f.suffix == ".txt" and file_name_format in f.stem
+ )
+ )
+ return cls.from_csvs(file_paths, split_type)
+
+ @classmethod
+ def from_csvs(
+ cls, file_paths: List[Union[pathlib.Path, str]], split_type: str = "train"
+ ) -> Hmdb51LabeledVideoPaths:
+ """
+ Factory function that creates a Hmdb51LabeledVideoPaths object from a list of
+ split files of .txt type
+
+ Args:
+ file_paths (List[Union[pathlib.Path, str]]): The paths to the split files
+ (one .txt file per class and fold).
+ split_type (str): Split/Fold type to be loaded.
+ - "train"
+ - "test"
+ - "unused"
+ """
+ video_paths_and_label = []
+ for file_path in file_paths:
+ file_path = pathlib.Path(file_path)
+ assert g_pathmgr.exists(file_path), f"{file_path} not found."
+ if not (file_path.suffix == ".txt" and "_test_split" in file_path.stem):
+ return RuntimeError(f"Ivalid file: {file_path}")
+
+ action_name = "_"
+ action_name = action_name.join((file_path.stem).split("_")[:-2])
+ with g_pathmgr.open(file_path, "r") as f:
+ for path_label in f.read().splitlines():
+ line_split = path_label.rsplit(None, 1)
+
+ if not int(line_split[1]) == cls._split_type_dict[split_type]:
+ continue
+
+ file_path = os.path.join(action_name, line_split[0])
+ meta_tags = line_split[0].split("_")[-6:-1]
+ video_paths_and_label.append(
+ (file_path, {"label": action_name, "meta_tags": meta_tags})
+ )
+
+ assert (
+ len(video_paths_and_label) > 0
+ ), f"Failed to load dataset from {file_path}."
+ return cls(video_paths_and_label)
+
+ def __init__(
+ self, paths_and_labels: List[Tuple[str, Optional[dict]]], path_prefix=""
+ ) -> None:
+ """
+ Args:
+ paths_and_labels (List[Tuple[str, Optional[dict]]]): a list of tuples containing
+ the video path and an annotation dict with the label and meta tags.
+ """
+ self._paths_and_labels = paths_and_labels
+ self._path_prefix = path_prefix
+
+ def path_prefix(self, prefix):
+ self._path_prefix = prefix
+
+ path_prefix = property(None, path_prefix)
+
+ def __getitem__(self, index: int) -> Tuple[str, dict]:
+ """
+ Args:
+ index (int): the path and label index.
+
+ Returns:
+ The path and label tuple for the given index.
+ """
+ path, label = self._paths_and_labels[index]
+ return (os.path.join(self._path_prefix, path), label)
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ The number of video paths and label pairs.
+ """
+ return len(self._paths_and_labels)
+
+
+def Hmdb51(
+ data_path: pathlib.Path,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[dict], Any]] = None,
+ video_path_prefix: str = "",
+ split_id: int = 1,
+ split_type: str = "train",
+ decode_audio=True,
+ decoder: str = "pyav",
+) -> LabeledVideoDataset:
+ """
+ A helper function to create a ``LabeledVideoDataset`` object for the HMDB51 dataset.
+
+ Args:
+ data_path (pathlib.Path): Path to the HMDB51 splits/folds directory that is
+ parsed by ``Hmdb51LabeledVideoPaths.from_dir`` (one
+ class_x_test_split_<split_id>.txt file per class).
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Callable): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations to the clips. See the ``LabeledVideoDataset`` class for
+ clip output format.
+
+ video_path_prefix (str): Path to root directory with the videos that are
+ loaded in LabeledVideoDataset. All the video paths before loading
+ are prefixed with this path.
+
+ split_id (int): Fold id to be loaded. Options are 1, 2 or 3
+
+ split_type (str): Split/Fold type to be loaded. Options are ("train", "test" or
+ "unused")
+
+ decode_audio (bool): If True, also decode audio from video.
+
+ decoder (str): Defines which backend should be used to decode videos.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.Hmdb51")
+
+ labeled_video_paths = Hmdb51LabeledVideoPaths.from_dir(
+ data_path, split_id=split_id, split_type=split_type
+ )
+ labeled_video_paths.path_prefix = video_path_prefix
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler,
+ video_sampler,
+ transform,
+ decode_audio=decode_audio,
+ decoder=decoder,
+ )
+
+ return dataset
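A construction sketch for the ``Hmdb51`` helper; the directory paths are placeholders, and ``make_clip_sampler`` is assumed to be available from ``pytorchvideo.data.clip_sampling``:

```python
# Construction sketch; paths are placeholders and make_clip_sampler is an
# assumed helper from pytorchvideo.data.clip_sampling.
from pytorchvideo.data.clip_sampling import make_clip_sampler
from pytorchvideo.data.hmdb51 import Hmdb51

dataset = Hmdb51(
    data_path="hmdb51/test_train_splits",   # assumed folds directory
    clip_sampler=make_clip_sampler("random", 2.0),
    video_path_prefix="hmdb51/videos",      # assumed video root
    split_id=1,
    split_type="train",
    decode_audio=False,
)
sample = next(iter(dataset))  # dict with "video", "label", "meta_tags", ...
```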
diff --git a/code/pytorchvideo/pytorchvideo/data/json_dataset.py b/code/pytorchvideo/pytorchvideo/data/json_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c86c1b51a0edaa6c8199b650dc27d40c7bae08a8
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/json_dataset.py
@@ -0,0 +1,254 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import json
+import logging
+import os
+from typing import Any, Callable, Dict, Optional, Type
+
+import torch
+from iopath.common.file_io import g_pathmgr
+from pytorchvideo.data.clip_sampling import ClipInfo, ClipSampler
+from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+def video_only_dataset(
+ data_path: str,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ video_path_prefix: str = "",
+ decode_audio: bool = True,
+ decoder: str = "pyav",
+):
+ """
+ Builds a LabeledVideoDataset with no annotations from a json file with the following
+ format:
+
+ .. code-block:: text
+
+ {
+ "video_name1": {...}
+ "video_name2": {...}
+ ....
+ "video_nameN": {...}
+ }
+
+ Args:
+ data_path (str): Path to the json annotation file in the format shown above.
+ Each video path is interpreted as a frame video if it is a folder,
+ otherwise it must be an encoded video.
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Callable): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations on the clips. The clip output format is described in __next__().
+
+ decode_audio (bool): If True, also decode audio from video.
+
+ decoder (str): Defines what type of decoder used to decode a video. Not used for
+ frame videos.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.json_dataset.video_only_dataset")
+
+ if g_pathmgr.isfile(data_path):
+ try:
+ with g_pathmgr.open(data_path, "r") as f:
+ annotations = json.load(f)
+ except Exception:
+ raise FileNotFoundError(f"{data_path} must be json for Ego4D dataset")
+
+ # LabeledVideoDataset requires the data to be a list of tuples with format:
+ # (video_path, annotation_dict); since there are no annotations we pass an empty dict.
+ video_paths = [
+ (os.path.join(video_path_prefix, x), {}) for x in annotations.keys()
+ ]
+ else:
+ raise FileNotFoundError(f"{data_path} not found.")
+
+ dataset = LabeledVideoDataset(
+ video_paths,
+ clip_sampler,
+ video_sampler,
+ transform,
+ decode_audio=decode_audio,
+ decoder=decoder,
+ )
+ return dataset
+
+
+def clip_recognition_dataset(
+ data_path: str,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ video_path_prefix: str = "",
+ decode_audio: bool = True,
+ decoder: str = "pyav",
+):
+ """
+ Builds a LabeledVideoDataset with noun, verb annotations from a json file with the following
+ format:
+
+ .. code-block:: text
+
+ {
+ "video_name1": {
+ {
+ "benchmarks": {
+ "forecasting_hands_objects": [
+ {
+ "critical_frame_selection_parent_start_sec":
+ "critical_frame_selection_parent_end_sec":
+ {
+ "taxonomy: {
+ "noun": ,
+ "verb": ,
+ }
+ }
+ },
+ {
+ ...
+ }
+ ]
+ }
+ }
+ }
+ "video_name2": {...}
+ ....
+ "video_nameN": {...}
+ }
+
+ Args:
+ data_path (str): Path to the json annotation file in the format shown above.
+ Each video path is interpreted as a frame video if it is a folder,
+ otherwise it must be an encoded video.
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Callable): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations on the clips. The clip output format is described in __next__().
+
+ decode_audio (bool): If True, also decode audio from video.
+
+ decoder (str): Defines what type of decoder used to decode a video. Not used for
+ frame videos.
+ """
+ if g_pathmgr.isfile(data_path):
+ try:
+ with g_pathmgr.open(data_path, "r") as f:
+ annotations = json.load(f)
+ except Exception:
+ raise FileNotFoundError(f"{data_path} must be json for Ego4D dataset")
+
+ # LabeledVideoDataset requires the data to be a list of tuples with format:
+ # (video_path, annotation_dict); here the dict carries the clip boundaries and labels.
+ untrimmed_clip_annotations = []
+ for video_name, child in annotations.items():
+ video_path = os.path.join(video_path_prefix, video_name)
+ for clip_annotation in child["benchmarks"]["forecasting_hands_objects"]:
+ clip_start = clip_annotation[
+ "critical_frame_selection_parent_start_sec"
+ ]
+ clip_end = clip_annotation["critical_frame_selection_parent_end_sec"]
+ taxonomy = clip_annotation["taxonomy"]
+ noun_label = taxonomy["noun"]
+ verb_label = taxonomy["verb"]
+ verb_unsure = taxonomy["verb_unsure"]
+ noun_unsure = taxonomy["noun_unsure"]
+ if (
+ noun_label is None
+ or verb_label is None
+ or verb_unsure
+ or noun_unsure
+ ):
+ continue
+
+ untrimmed_clip_annotations.append(
+ (
+ video_path,
+ {
+ "clip_start_sec": clip_start,
+ "clip_end_sec": clip_end,
+ "noun_label": noun_label,
+ "verb_label": verb_label,
+ },
+ )
+ )
+ else:
+ raise FileNotFoundError(f"{data_path} not found.")
+
+ # Map noun and verb key words to unique index.
+ def map_labels_to_index(label_name):
+ labels = list({info[label_name] for _, info in untrimmed_clip_annotations})
+ label_to_idx = {label: i for i, label in enumerate(labels)}
+ for i in range(len(untrimmed_clip_annotations)):
+ label = untrimmed_clip_annotations[i][1][label_name]
+ untrimmed_clip_annotations[i][1][label_name] = label_to_idx[label]
+
+ map_labels_to_index("noun_label")
+ map_labels_to_index("verb_label")
+
+ dataset = LabeledVideoDataset(
+ untrimmed_clip_annotations,
+ UntrimmedClipSampler(clip_sampler),
+ video_sampler,
+ transform,
+ decode_audio=decode_audio,
+ decoder=decoder,
+ )
+ return dataset
+
+
+class UntrimmedClipSampler:
+ """
+ A wrapper for adapting untrimmed annotated clips from the json_dataset to the
+ standard `pytorchvideo.data.ClipSampler` expected format. Specifically, for each
+ clip it uses the provided `clip_sampler` to sample between "clip_start_sec" and
+ "clip_end_sec" from the json_dataset clip annotation.
+ """
+
+ def __init__(self, clip_sampler: ClipSampler) -> None:
+ """
+ Args:
+ clip_sampler (`pytorchvideo.data.ClipSampler`): Strategy used for sampling
+ between the untrimmed clip boundary.
+ """
+ self._trimmed_clip_sampler = clip_sampler
+
+ def __call__(
+ self, last_clip_time: float, video_duration: float, clip_info: Dict[str, Any]
+ ) -> ClipInfo:
+ clip_start_boundary = clip_info["clip_start_sec"]
+ clip_end_boundary = clip_info["clip_end_sec"]
+ duration = clip_end_boundary - clip_start_boundary
+
+ # Sample between 0 and duration of untrimmed clip, then add back start boundary.
+ clip_info = self._trimmed_clip_sampler(last_clip_time, duration, clip_info)
+ return ClipInfo(
+ clip_info.clip_start_sec + clip_start_boundary,
+ clip_info.clip_end_sec + clip_start_boundary,
+ clip_info.clip_index,
+ clip_info.aug_index,
+ clip_info.is_last_clip,
+ )
+
+ def reset(self) -> None:
+ pass
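The wrapper above lets a standard clip sampler operate inside an annotated, untrimmed window and then shifts the sampled times back onto the full-video timeline. A small illustration, assuming ``make_clip_sampler`` is available from ``pytorchvideo.data.clip_sampling``:

```python
# Illustration of UntrimmedClipSampler: the inner sampler works within the
# trimmed [0, clip_end_sec - clip_start_sec] window, and the wrapper shifts the
# sampled clip back onto the untrimmed video timeline.
from pytorchvideo.data.clip_sampling import make_clip_sampler
from pytorchvideo.data.json_dataset import UntrimmedClipSampler

sampler = UntrimmedClipSampler(make_clip_sampler("random", 2.0))
clip = sampler(None, 100.0, {"clip_start_sec": 30.0, "clip_end_sec": 40.0})
print(clip.clip_start_sec, clip.clip_end_sec)  # a 2 s window somewhere in [30, 40]
```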
diff --git a/code/pytorchvideo/pytorchvideo/data/kinetics.py b/code/pytorchvideo/pytorchvideo/data/kinetics.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cdb36463ea3d8a909355c314210260d8fc590fb
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/kinetics.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Callable, Dict, Optional, Type
+
+import torch
+from pytorchvideo.data.clip_sampling import ClipSampler
+
+from .labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset
+
+
+"""
+ Action recognition video dataset for Kinetics-{400,600,700}
+
+"""
+
+
+def Kinetics(
+ data_path: str,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ video_path_prefix: str = "",
+ decode_audio: bool = True,
+ decoder: str = "pyav",
+) -> LabeledVideoDataset:
+ """
+ A helper function to create a ``LabeledVideoDataset`` object for the Kinetics dataset.
+
+ Args:
+ data_path (str): Path to the data. The path type defines how the data
+ should be read:
+
+ * For a file path, the file is read and each line is parsed into a
+ video path and label.
+ * For a directory, the directory structure defines the classes
+ (i.e. each subdirectory is a class).
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Callable): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations to the clips. See the ``LabeledVideoDataset`` class for clip
+ output format.
+
+ video_path_prefix (str): Path to root directory with the videos that are
+ loaded in ``LabeledVideoDataset``. All the video paths before loading
+ are prefixed with this path.
+
+ decode_audio (bool): If True, also decode audio from video.
+
+ decoder (str): Defines what type of decoder used to decode a video.
+
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.Kinetics")
+
+ return labeled_video_dataset(
+ data_path,
+ clip_sampler,
+ video_sampler,
+ transform,
+ video_path_prefix,
+ decode_audio,
+ decoder,
+ )
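A typical construction sketch; the csv path is a placeholder, and ``make_clip_sampler`` is assumed to be available from ``pytorchvideo.data.clip_sampling``:

```python
# Construction sketch for the Kinetics helper; the csv path is a placeholder
# with one "<video_path> <label>" pair per line.
from pytorchvideo.data.clip_sampling import make_clip_sampler
from pytorchvideo.data.kinetics import Kinetics

dataset = Kinetics(
    data_path="kinetics400/train.csv",             # assumed "<path> <label>" csv
    clip_sampler=make_clip_sampler("random", 2.0),
    decode_audio=False,
)
sample = next(iter(dataset))  # {"video": (C, T, H, W) tensor, "label": int, ...}
```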
diff --git a/code/pytorchvideo/pytorchvideo/data/labeled_video_dataset.py b/code/pytorchvideo/pytorchvideo/data/labeled_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0b3646db922f4f8ad65e45ecf8de42d8c7b1f27
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/labeled_video_dataset.py
@@ -0,0 +1,306 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from __future__ import annotations
+
+import gc
+import logging
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+
+import torch.utils.data
+from pytorchvideo.data.clip_sampling import ClipSampler
+from pytorchvideo.data.video import VideoPathHandler
+
+from .labeled_video_paths import LabeledVideoPaths
+from .utils import MultiProcessSampler
+
+
+logger = logging.getLogger(__name__)
+
+
+class LabeledVideoDataset(torch.utils.data.IterableDataset):
+ """
+ LabeledVideoDataset handles the storage, loading, decoding and clip sampling for a
+ video dataset. It assumes each video is stored as either an encoded video
+ (e.g. mp4, avi) or a frame video (e.g. a folder of jpg or png files).
+ """
+
+ _MAX_CONSECUTIVE_FAILURES = 10
+
+ def __init__(
+ self,
+ labeled_video_paths: List[Tuple[str, Optional[dict]]],
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[dict], Any]] = None,
+ decode_audio: bool = True,
+ decode_video: bool = True,
+ decoder: str = "pyav",
+ ) -> None:
+ """
+ Args:
+ labeled_video_paths (List[Tuple[str, Optional[dict]]]): List containing
+ video file paths and associated labels. If video paths are a folder
+ it's interpreted as a frame video, otherwise it must be an encoded
+ video.
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Callable): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations on the clips. The clip output format is described in __next__().
+
+ decode_audio (bool): If True, decode audio from video.
+
+ decode_video (bool): If True, decode video frames from a video container.
+
+ decoder (str): Defines what type of decoder used to decode a video. Not used for
+ frame videos.
+ """
+ self._decode_audio = decode_audio
+ self._decode_video = decode_video
+ self._transform = transform
+ self._clip_sampler = clip_sampler
+ self._labeled_videos = labeled_video_paths
+ self._decoder = decoder
+
+ # If a RandomSampler is used we need to pass in a custom random generator that
+ # ensures all PyTorch multiprocess workers have the same random seed.
+ self._video_random_generator = None
+ if video_sampler == torch.utils.data.RandomSampler:
+ self._video_random_generator = torch.Generator()
+ self._video_sampler = video_sampler(
+ self._labeled_videos, generator=self._video_random_generator
+ )
+ else:
+ self._video_sampler = video_sampler(self._labeled_videos)
+
+ self._video_sampler_iter = None # Initialized on first call to self.__next__()
+
+ # Depending on the clip sampler type, we may want to sample multiple clips
+ # from one video. In that case, we keep the store video, label and previous sampled
+ # clip time in these variables.
+ self._loaded_video_label = None
+ self._loaded_clip = None
+ self._last_clip_end_time = None
+ self.video_path_handler = VideoPathHandler()
+
+ @property
+ def video_sampler(self):
+ """
+ Returns:
+ The video sampler that defines video sample order. Note that you'll need to
+ use this property to set the epoch for a torch.utils.data.DistributedSampler.
+ """
+ return self._video_sampler
+
+ @property
+ def num_videos(self):
+ """
+ Returns:
+ Number of videos in dataset.
+ """
+ return len(self.video_sampler)
+
+ def __next__(self) -> dict:
+ """
+ Retrieves the next clip based on the clip sampling strategy and video sampler.
+
+ Returns:
+ A dictionary with the following format.
+
+ .. code-block:: text
+
+ {
+ 'video': ,
+ 'label': ,
+ 'video_label':
+ 'video_index': ,
+ 'clip_index': ,
+ 'aug_index': ,
+ }
+ """
+ if not self._video_sampler_iter:
+ # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
+ self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))
+
+ for i_try in range(self._MAX_CONSECUTIVE_FAILURES):
+ # Reuse previously stored video if there are still clips to be sampled from
+ # the last loaded video.
+ if self._loaded_video_label:
+ video, info_dict, video_index = self._loaded_video_label
+ else:
+ video_index = next(self._video_sampler_iter)
+ try:
+ video_path, info_dict = self._labeled_videos[video_index]
+ video = self.video_path_handler.video_from_path(
+ video_path,
+ decode_audio=self._decode_audio,
+ decode_video=self._decode_video,
+ decoder=self._decoder,
+ )
+ self._loaded_video_label = (video, info_dict, video_index)
+ except Exception as e:
+ logger.debug(
+ "Failed to load video with error: {}; trial {}".format(
+ e,
+ i_try,
+ )
+ )
+ logger.exception("Video load exception")
+ continue
+
+ (
+ clip_start,
+ clip_end,
+ clip_index,
+ aug_index,
+ is_last_clip,
+ ) = self._clip_sampler(self._last_clip_end_time, video.duration, info_dict)
+
+ if isinstance(clip_start, list): # multi-clip in each sample
+
+ # Only load the clips once and reuse previously stored clips if there are multiple
+ # views for augmentations to perform on the same clips.
+ if aug_index[0] == 0:
+ self._loaded_clip = {}
+ loaded_clip_list = []
+ for i in range(len(clip_start)):
+ clip_dict = video.get_clip(clip_start[i], clip_end[i])
+ if clip_dict is None or clip_dict["video"] is None:
+ self._loaded_clip = None
+ break
+ loaded_clip_list.append(clip_dict)
+
+ if self._loaded_clip is not None:
+ for key in loaded_clip_list[0].keys():
+ self._loaded_clip[key] = [x[key] for x in loaded_clip_list]
+
+ else: # single clip case
+
+ # Only load the clip once and reuse previously stored clip if there are multiple
+ # views for augmentations to perform on the same clip.
+ if aug_index == 0:
+ self._loaded_clip = video.get_clip(clip_start, clip_end)
+
+ self._last_clip_end_time = clip_end
+
+ video_is_null = (
+ self._loaded_clip is None or self._loaded_clip["video"] is None
+ )
+ if (
+ is_last_clip[-1] if isinstance(is_last_clip, list) else is_last_clip
+ ) or video_is_null:
+ # Close the loaded encoded video and reset the last sampled clip time ready
+ # to sample a new video on the next iteration.
+ self._loaded_video_label[0].close()
+ self._loaded_video_label = None
+ self._last_clip_end_time = None
+ self._clip_sampler.reset()
+
+ # Force garbage collection to release video container immediately
+ # otherwise memory can spike.
+ gc.collect()
+
+ if video_is_null:
+ logger.debug(
+ "Failed to load clip {}; trial {}".format(video.name, i_try)
+ )
+ continue
+
+ frames = self._loaded_clip["video"]
+ audio_samples = self._loaded_clip["audio"]
+ sample_dict = {
+ "video": frames,
+ "video_name": video.name,
+ "video_index": video_index,
+ "clip_index": clip_index,
+ "aug_index": aug_index,
+ **info_dict,
+ **({"audio": audio_samples} if audio_samples is not None else {}),
+ }
+ if self._transform is not None:
+ sample_dict = self._transform(sample_dict)
+
+ # User can force dataset to continue by returning None in transform.
+ if sample_dict is None:
+ continue
+
+ return sample_dict
+ else:
+ raise RuntimeError(
+ f"Failed to load video after {self._MAX_CONSECUTIVE_FAILURES} retries."
+ )
+
+ def __iter__(self):
+ self._video_sampler_iter = None # Reset video sampler
+
+ # If we're in a PyTorch DataLoader multiprocessing context, we need to use the
+ # same seed for each worker's RandomSampler generator. The workers at each
+ # __iter__ call are created from the unique value: worker_info.seed - worker_info.id,
+ # which we can use for this seed.
+ worker_info = torch.utils.data.get_worker_info()
+ if self._video_random_generator is not None and worker_info is not None:
+ base_seed = worker_info.seed - worker_info.id
+ self._video_random_generator.manual_seed(base_seed)
+
+ return self
+
+
+def labeled_video_dataset(
+ data_path: str,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ video_path_prefix: str = "",
+ decode_audio: bool = True,
+ decoder: str = "pyav",
+) -> LabeledVideoDataset:
+ """
+ A helper function to create a ``LabeledVideoDataset`` object for the Ucf101 and Kinetics datasets.
+
+ Args:
+ data_path (str): Path to the data. The path type defines how the data
+ should be read:
+
+ * For a file path, the file is read and each line is parsed into a
+ video path and label.
+ * For a directory, the directory structure defines the classes
+ (i.e. each subdirectory is a class).
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Callable): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations to the clips. See the ``LabeledVideoDataset`` class for clip
+ output format.
+
+ video_path_prefix (str): Path to root directory with the videos that are
+ loaded in ``LabeledVideoDataset``. All the video paths before loading
+ are prefixed with this path.
+
+ decode_audio (bool): If True, also decode audio from video.
+
+ decoder (str): Defines what type of decoder used to decode a video.
+
+ """
+ labeled_video_paths = LabeledVideoPaths.from_path(data_path)
+ labeled_video_paths.path_prefix = video_path_prefix
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler,
+ video_sampler,
+ transform,
+ decode_audio=decode_audio,
+ decoder=decoder,
+ )
+ return dataset
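Because ``__next__`` treats a transform returning ``None`` as a signal to skip to the next clip, a filtering transform can be expressed directly. A sketch under that assumption (the csv path and frame threshold are placeholders):

```python
# Sketch of a transform that skips unusable samples by returning None, which the
# __next__ loop above interprets as "move on to the next clip".
from pytorchvideo.data.clip_sampling import make_clip_sampler
from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset


def skip_short_clips(sample):
    # Hypothetical filter: drop clips that decoded to fewer than 8 frames.
    if sample["video"].shape[1] < 8:
        return None
    return sample


dataset = labeled_video_dataset(
    data_path="videos/train.csv",                  # assumed "<path> <label>" csv
    clip_sampler=make_clip_sampler("random", 2.0),
    transform=skip_short_clips,
    decode_audio=False,
)
```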
diff --git a/code/pytorchvideo/pytorchvideo/data/labeled_video_paths.py b/code/pytorchvideo/pytorchvideo/data/labeled_video_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..81009855c81edc743a6efeb0473802092d2d124a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/labeled_video_paths.py
@@ -0,0 +1,141 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from __future__ import annotations
+
+import os
+import pathlib
+from typing import List, Optional, Tuple
+
+from iopath.common.file_io import g_pathmgr
+from torchvision.datasets.folder import make_dataset
+
+
+class LabeledVideoPaths:
+ """
+ LabeledVideoPaths contains pairs of video path and integer index label.
+ """
+
+ @classmethod
+ def from_path(cls, data_path: str) -> LabeledVideoPaths:
+ """
+ Factory function that creates a LabeledVideoPaths object depending on the path
+ type.
+ - If it is a directory path, it uses the LabeledVideoPaths.from_directory function.
+ - If it is a file path, it uses the LabeledVideoPaths.from_csv function.
+ Args:
+ data_path (str): The path to the file or directory to be read.
+ """
+
+ if g_pathmgr.isfile(data_path):
+ return LabeledVideoPaths.from_csv(data_path)
+ elif g_pathmgr.isdir(data_path):
+ return LabeledVideoPaths.from_directory(data_path)
+ else:
+ raise FileNotFoundError(f"{data_path} not found.")
+
+ @classmethod
+ def from_csv(cls, file_path: str) -> LabeledVideoPaths:
+ """
+ Factory function that creates a LabeledVideoPaths object by reading a file with the
+ following format, where each line contains a whitespace-separated video path and
+ integer label (the label may be omitted, e.g. for a test split):
+
+ <video_path> <integer_label>
+ ...
+
+ Args:
+ file_path (str): The path to the file to be read.
+ """
+ assert g_pathmgr.exists(file_path), f"{file_path} not found."
+ video_paths_and_label = []
+ with g_pathmgr.open(file_path, "r") as f:
+ for path_label in f.read().splitlines():
+ line_split = path_label.rsplit(None, 1)
+
+ # The video path file may not contain labels (e.g. for a test split). We
+ # assume this is the case if only 1 path is found and set the label to
+ # -1 if so.
+ if len(line_split) == 1:
+ file_path = line_split[0]
+ label = -1
+ else:
+ file_path, label = line_split
+
+ video_paths_and_label.append((file_path, int(label)))
+
+ assert (
+ len(video_paths_and_label) > 0
+ ), f"Failed to load dataset from {file_path}."
+ return cls(video_paths_and_label)
+
+ @classmethod
+ def from_directory(cls, dir_path: str) -> LabeledVideoPaths:
+ """
+ Factory function that creates a LabeledVideoPaths object by parsing the structure
+ of the given directory's subdirectories into the classification labels. It
+ expects the directory format to be the following:
+ dir_path/<class_name>/<video_name>.mp4
+
+ Classes are indexed alphabetically, from 0 to the number of classes minus 1.
+
+ E.g.
+ dir_path/class_x/xxx.ext
+ dir_path/class_x/xxy.ext
+ dir_path/class_x/xxz.ext
+ dir_path/class_y/123.ext
+ dir_path/class_y/nsdf3.ext
+ dir_path/class_y/asd932_.ext
+
+ Would produce two classes labeled 0 and 1, with 3 video paths associated with each.
+
+ Args:
+ dir_path (str): Root directory of the video class directories.
+ """
+ assert g_pathmgr.exists(dir_path), f"{dir_path} not found."
+
+ # Find all classes based on directory names. These classes are then sorted and indexed
+ # from 0 to the number of classes.
+ classes = sorted(
+ (f.name for f in pathlib.Path(dir_path).iterdir() if f.is_dir())
+ )
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ video_paths_and_label = make_dataset(
+ dir_path, class_to_idx, extensions=("mp4", "avi")
+ )
+ assert (
+ len(video_paths_and_label) > 0
+ ), f"Failed to load dataset from {dir_path}."
+ return cls(video_paths_and_label)
+
+ def __init__(
+ self, paths_and_labels: List[Tuple[str, Optional[int]]], path_prefix=""
+ ) -> None:
+ """
+ Args:
+ paths_and_labels [(str, int)]: a list of tuples containing the video
+ path and integer label.
+ """
+ self._paths_and_labels = paths_and_labels
+ self._path_prefix = path_prefix
+
+ def path_prefix(self, prefix):
+ self._path_prefix = prefix
+
+ path_prefix = property(None, path_prefix)
+
+ def __getitem__(self, index: int) -> Tuple[str, int]:
+ """
+ Args:
+ index (int): the path and label index.
+
+ Returns:
+ The path and label tuple for the given index.
+ """
+ path, label = self._paths_and_labels[index]
+ return (os.path.join(self._path_prefix, path), {"label": label})
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ The number of video paths and label pairs.
+ """
+ return len(self._paths_and_labels)
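+
+
+# A short illustrative sketch (not part of this module; the paths and label value are
+# assumptions): the factory picks from_directory for a folder layout and from_csv for a
+# file, and indexing returns a (prefixed_path, {"label": int}) tuple.
+#
+#   paths = LabeledVideoPaths.from_path("/path/to/dataset")   # directory or csv file
+#   paths.path_prefix = "/mnt/videos"                         # write-only property
+#   video_path, extra = paths[0]                              # ("/mnt/videos/...", {"label": 3})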
diff --git a/code/pytorchvideo/pytorchvideo/data/ssv2.py b/code/pytorchvideo/pytorchvideo/data/ssv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2938bed96b8dbe83fbc876f64ab4385dc7db24ee
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/ssv2.py
@@ -0,0 +1,249 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import csv
+import functools
+import json
+import os
+import random
+from collections import defaultdict
+from typing import Any, Callable, List, Optional, Tuple, Type
+
+import numpy as np
+import torch
+import torch.utils.data
+from iopath.common.file_io import g_pathmgr
+from pytorchvideo.data.clip_sampling import ClipSampler
+from pytorchvideo.data.frame_video import FrameVideo
+
+from .utils import MultiProcessSampler
+
+
+class SSv2(torch.utils.data.IterableDataset):
+ """
+ Action recognition video dataset for
+ Something-something v2 (SSv2), stored
+ as image frames.
+
+ This dataset handles the parsing of frames, loading and clip sampling for the
+ videos. All IO is done through :code:`iopath.common.file_io.PathManager`, enabling
+ non-local storage URIs to be used.
+ """
+
+ def __init__(
+ self,
+ label_name_file: str,
+ video_label_file: str,
+ video_path_label_file: str,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[dict], Any]] = None,
+ video_path_prefix: str = "",
+ frames_per_clip: Optional[int] = None,
+ rand_sample_frames: bool = False,
+ ) -> None:
+ """
+ Args:
+ label_name_file (str): SSV2 label file that contains the label names and
+ indexes.
+
+ video_label_file (str): a file that contains video ids and the corresponding
+ video label.
+
+ video_path_label_file (str): a file that contains frame paths for each
+ video and the corresponding frame label. The file must be a space separated
+ csv of the format: (original_vido_id video_id frame_id path labels).
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Optional[Callable]): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations on the clips. The clip output format is described in __next__().
+
+ video_path_prefix (str): prefix path to add to all paths from data_path.
+
+ frames_per_clip (Optional[int]): The number of frames per clip to sample.
+
+ rand_sample_frames (bool): If True, randomly sample frames for each clip.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.SSv2.__init__")
+
+ self._transform = transform
+ self._clip_sampler = clip_sampler
+ self._path_to_videos, self._labels = _read_video_paths_and_labels(
+ label_name_file,
+ video_label_file,
+ video_path_label_file,
+ prefix=video_path_prefix,
+ )
+ self._video_sampler = video_sampler(self._path_to_videos)
+ self._video_sampler_iter = None # Initialized on first call to self.__next__()
+ self._frame_filter = (
+ functools.partial(
+ SSv2._sample_clip_frames,
+ frames_per_clip=frames_per_clip,
+ rand_sample=rand_sample_frames,
+ )
+ if frames_per_clip is not None
+ else None
+ )
+
+ # Depending on the clip sampler type, we may want to sample multiple clips
+ # from one video. In that case, we keep the stored video, label and previously
+ # sampled clip time in these variables.
+ self._loaded_video = None
+ self._next_clip_start_time = 0.0
+
+ @staticmethod
+ def _sample_clip_frames(
+ frame_indices: List[int], frames_per_clip: int, rand_sample: bool
+ ) -> List[int]:
+ """
+ Use segment-based input frame sampling that splits each video into segments;
+ from each segment, one frame is sampled to form the clip.
+
+ Args:
+ frame_indices (list): list of frame indices.
+ frames_per_clip (int): The number of frames per clip to sample.
+ rand_sample (bool): if True, randomly sample a frame from each segment.
+
+ Returns:
+ (list): A subsampled list of frames_per_clip frame indices.
+ """
+ num_frames = len(frame_indices)
+
+ seg_size = float(num_frames - 1) / frames_per_clip
+ seq = []
+ for i in range(frames_per_clip):
+ start = int(np.round(seg_size * i))
+ end = int(np.round(seg_size * (i + 1)))
+ if rand_sample:
+ seq.append(random.randint(start, end))
+ else:
+ seq.append((start + end) // 2)
+
+ return [frame_indices[idx] for idx in seq]
+
+ @property
+ def video_sampler(self):
+ return self._video_sampler
+
+ def __next__(self) -> dict:
+ """
+ Retrieves the next clip based on the clip sampling strategy and video sampler.
+
+ Returns:
+ A dictionary with the following format.
+
+ .. code-block:: text
+
+ {
+ 'video': <video_tensor>,
+ 'label': <action_label_index>,
+ 'video_name': <video_name>,
+ 'video_index': <video_index>,
+ 'clip_index': <clip_index>,
+ 'aug_index': <aug_index>,
+ }
+ """
+ if not self._video_sampler_iter:
+ # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
+ self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))
+
+ if self._loaded_video:
+ video, video_index = self._loaded_video
+ else:
+ video_index = next(self._video_sampler_iter)
+ path_to_video_frames = self._path_to_videos[video_index]
+ video = FrameVideo.from_frame_paths(path_to_video_frames)
+ self._loaded_video = (video, video_index)
+
+ clip_start, clip_end, clip_index, aug_index, is_last_clip = self._clip_sampler(
+ self._next_clip_start_time, video.duration, {}
+ )
+ # Only load the clip once and reuse previously stored clip if there are multiple
+ # views for augmentations to perform on the same clip.
+ if aug_index == 0:
+ self._loaded_clip = video.get_clip(0, video.duration, self._frame_filter)
+
+ self._next_clip_start_time = clip_end
+
+ if is_last_clip:
+ self._loaded_video = None
+ self._next_clip_start_time = 0.0
+
+ sample_dict = {
+ "video": self._loaded_clip["video"],
+ "label": self._labels[video_index],
+ "video_name": str(video_index),
+ "video_index": video_index,
+ "clip_index": clip_index,
+ "aug_index": aug_index,
+ }
+ if self._transform is not None:
+ sample_dict = self._transform(sample_dict)
+
+ return sample_dict
+
+ def __iter__(self):
+ return self
+
+
+def _read_video_paths_and_labels(
+ label_name_file: str,
+ video_label_file: str,
+ video_path_label_file: str,
+ prefix: str = "",
+) -> Tuple[List[str], List[int]]:
+ """
+ Args:
+ label_name_file (str): ssv2 label file that contains the label names and
+ indexes. ('/path/to/folder/something-something-v2-labels.json')
+ video_label_file (str): a file that contains video ids and the corresponding
+ video label. (e.g., '/path/to/folder/something-something-v2-train.json')
+ video_path_label_file (str): a file that contains frame paths for each
+ video and the corresponding frame label. The file must be a space separated
+ csv of the format:
+ `original_vido_id video_id frame_id path labels`
+ prefix (str): prefix path to add to all paths from video_path_label_file.
+
+ Returns:
+ image_paths (list): list of list containing path to each frame.
+ labels (list): list containing label of each video.
+ """
+ # Loading image paths.
+ paths = defaultdict(list)
+ with g_pathmgr.open(video_path_label_file, "r") as f:
+ # Space separated CSV with format: original_vido_id video_id frame_id path labels
+ csv_reader = csv.DictReader(f, delimiter=" ")
+ for row in csv_reader:
+ assert len(row) == 5
+ video_name = row["original_vido_id"]
+ path = os.path.join(prefix, row["path"])
+ paths[video_name].append(path)
+
+ # Loading label names.
+ with g_pathmgr.open(label_name_file, "r") as f:
+ label_name_dict = json.load(f)
+
+ with g_pathmgr.open(video_label_file, "r") as f:
+ video_label_json = json.load(f)
+
+ labels = []
+ image_paths = []
+ for video in video_label_json:
+ video_name = video["id"]
+ if video_name in paths:
+ template = video["template"]
+ template = template.replace("[", "")
+ template = template.replace("]", "")
+ label = int(label_name_dict[template])
+ image_paths.append(paths[video_name])
+ labels.append(label)
+
+ return image_paths, labels
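+
+
+# Worked example (illustrative) for SSv2._sample_clip_frames: with 9 frame indices and
+# frames_per_clip=4, seg_size = (9 - 1) / 4 = 2.0, giving segments [0,2], [2,4], [4,6]
+# and [6,8]. Without random sampling the segment midpoints are taken, so
+#
+#   SSv2._sample_clip_frames(list(range(9)), frames_per_clip=4, rand_sample=False)
+#
+# returns [1, 3, 5, 7]; with rand_sample=True one index is drawn uniformly from each
+# segment instead.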
diff --git a/code/pytorchvideo/pytorchvideo/data/ucf101.py b/code/pytorchvideo/pytorchvideo/data/ucf101.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6453c8d3b19313e322766bab4486fd4a941c9cb
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/ucf101.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Callable, Dict, Optional, Type
+
+import torch
+from pytorchvideo.data.clip_sampling import ClipSampler
+
+from .labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset
+
+
+"""
+ Action recognition video dataset for UCF101
+
+"""
+
+
+def Ucf101(
+ data_path: str,
+ clip_sampler: ClipSampler,
+ video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
+ transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ video_path_prefix: str = "",
+ decode_audio: bool = True,
+ decoder: str = "pyav",
+) -> LabeledVideoDataset:
+ """
+ A helper function to create ``LabeledVideoDataset`` object for the Ucf101 dataset.
+
+ Args:
+ data_path (str): Path to the data. The path type defines how the data
+ should be read:
+
+ * For a file path, the file is read and each line is parsed into a
+ video path and label.
+ * For a directory, the directory structure defines the classes
+ (i.e. each subdirectory is a class).
+
+ clip_sampler (ClipSampler): Defines how clips should be sampled from each
+ video. See the clip sampling documentation for more information.
+
+ video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
+ video container. This defines the order videos are decoded and,
+ if necessary, the distributed split.
+
+ transform (Callable): This callable is evaluated on the clip output before
+ the clip is returned. It can be used for user defined preprocessing and
+ augmentations to the clips. See the ``LabeledVideoDataset`` class for clip
+ output format.
+
+ video_path_prefix (str): Path to root directory with the videos that are
+ loaded in ``LabeledVideoDataset``. All the video paths before loading
+ are prefixed with this path.
+
+ decode_audio (bool): If True, also decode audio from video.
+
+ decoder (str): Defines which backend is used to decode the video.
+
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.Ucf101")
+
+ return labeled_video_dataset(
+ data_path,
+ clip_sampler,
+ video_sampler,
+ transform,
+ video_path_prefix,
+ decode_audio,
+ decoder,
+ )
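+
+
+# A minimal usage sketch (illustrative; the csv path, `make_clip_sampler` import and the
+# transform are assumptions, not part of this module):
+#
+#   from pytorchvideo.data import make_clip_sampler
+#
+#   def scale_video(sample):
+#       sample["video"] = sample["video"].float() / 255.0
+#       return sample
+#
+#   dataset = Ucf101(
+#       data_path="/path/to/ucf101/train.csv",        # "<video_path> <label>" per line
+#       clip_sampler=make_clip_sampler("uniform", 2.0),
+#       transform=scale_video,
+#       decode_audio=False,
+#   )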
diff --git a/code/pytorchvideo/pytorchvideo/data/utils.py b/code/pytorchvideo/pytorchvideo/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d90b65884525e05d902efb2379cc5f98fb70a6ee
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/utils.py
@@ -0,0 +1,397 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from __future__ import annotations
+
+import csv
+import itertools
+import logging
+import math
+import sys
+import threading
+from collections import defaultdict
+from dataclasses import Field, field as dataclass_field, fields as dataclass_fields
+from fractions import Fraction
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import av
+import numpy as np
+import torch
+from iopath.common.file_io import g_pathmgr
+
+
+logger = logging.getLogger(__name__)
+
+
+def thwc_to_cthw(data: torch.Tensor) -> torch.Tensor:
+ """
+ Permute tensor from (time, height, width, channel) to
+ (channel, time, height, width).
+ """
+ return data.permute(3, 0, 1, 2)
+
+
+def secs_to_pts(
+ time_in_seconds: float,
+ time_base: float,
+ start_pts: int,
+ round_mode: str = "floor",
+) -> int:
+ """
+ Converts a time (in seconds) into a presentation timestamp in the given time base,
+ offset by start_pts. round_mode specifies the rounding mode used in the conversion.
+
+ Returns:
+ pts (int): The time in the given time base.
+ """
+ if time_in_seconds == math.inf:
+ return math.inf
+
+ assert round_mode in ["floor", "ceil"], f"round_mode={round_mode} is not supported!"
+
+ if round_mode == "floor":
+ return math.floor(time_in_seconds / time_base) + start_pts
+ else:
+ return math.ceil(time_in_seconds / time_base) + start_pts
+
+
+def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
+ """
+ Converts a present time with the given time base and start_pts offset to seconds.
+
+ Returns:
+ time_in_seconds (float): The corresponding time in seconds.
+ """
+ if pts == math.inf:
+ return math.inf
+
+ return int(pts - start_pts) * time_base
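+# Worked example (illustrative, using a simplified time_base of 1/16 so the floating
+# point arithmetic is exact): secs_to_pts(2.5, 1 / 16, 0) == math.floor(2.5 * 16) == 40,
+# and pts_to_secs(40, 1 / 16, 0) == 2.5, i.e. the two functions round-trip up to the
+# floor/ceil rounding applied by secs_to_pts.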
+
+
+def export_video_array(
+ video: Union[np.ndarray, torch.Tensor],
+ output_path: str,
+ rate: Union[str, Fraction],
+ bit_rate: Optional[int] = None,
+ pix_fmt: Optional[str] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ in_format: Optional[str] = "rgb24",
+ out_format: Optional[str] = "bgr24",
+ video_codec: Optional[str] = "mpeg4",
+ options: Optional[Dict[str, Any]] = None,
+) -> av.VideoStream:
+ """
+ Encodes and exports an ndarray or torch tensor representing frames of a video to output_path
+
+ Args:
+ video (Union[np.ndarray, torch.Tensor]):
+ A 4d array/tensor returned by EncodedVideoPyAV.get_clip. Axis 0 is channel,
+ Axis 1 is frame index/time, the remaining axes are the frame pixels
+
+ output_path (str):
+ the path to write the video to
+
+ rate (Union[str, Fraction]):
+ the frame rate of the output video
+
+ bit_rate (int):
+ the bit rate of the output video. If not set, defaults to 1024000
+
+ pix_fmt (str):
+ the pixel format of the output video. If not set, defaults to yuv420p
+
+ height (int):
+ the height of the output video. if not set, defaults to the dimensions of input video
+
+ width (int):
+ the width of the output video. if not set, defaults to the dimensions of input video
+
+ in_format (str):
+ The encoding format of the input video. Defaults to rgb24
+
+ out_format (str):
+ The encoding format of the output video. Defaults to bgr24
+
+ video_codec (str):
+ The video codec to use for the output video. Defaults to mpeg4
+
+ options (Dict[str, Any]):
+ Dictionary of options for PyAV video encoder
+ Returns:
+ Stream object which contains metadata about encoded and exported video.
+ """
+ stream = None
+ with g_pathmgr.open(output_path, "wb") as oh:
+ output = av.open(oh, mode="wb", format="mp4")
+ stream = output.add_stream(codec_name=video_codec, rate=rate)
+ if height:
+ stream.height = height
+ else:
+ stream.height = video.shape[-2]
+ if width:
+ stream.width = width
+ else:
+ stream.width = video.shape[-1]
+ if bit_rate:
+ stream.bit_rate = bit_rate
+ if pix_fmt:
+ stream.pix_fmt = pix_fmt
+ else:
+ stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
+ if video_codec == "libx264rgb":
+ out_format = "rgb24"
+ if options:
+ stream.options = options
+ if isinstance(video, torch.Tensor):
+ video = video.numpy()
+ for np_frame in np.moveaxis(video, 0, -1):
+ frame = av.VideoFrame.from_ndarray(
+ np_frame.astype("uint8"), format=in_format
+ )
+ if in_format != out_format:
+ frame = frame.reformat(format=out_format)
+ frame.pict_type = "NONE"
+ for packet in stream.encode(frame):
+ output.mux(packet)
+ for packet in stream.encode():
+ output.mux(packet)
+ output.close()
+ return stream
+
+
+class MultiProcessSampler(torch.utils.data.Sampler):
+ """
+ MultiProcessSampler splits sample indices from a PyTorch Sampler evenly across
+ workers spawned by a PyTorch DataLoader.
+ """
+
+ def __init__(self, sampler: torch.utils.data.Sampler) -> None:
+ self._sampler = sampler
+
+ def __iter__(self):
+ """
+ Returns:
+ Iterator for underlying PyTorch Sampler indices split by worker id.
+ """
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None and worker_info.num_workers != 0:
+
+ # Split sampler indexes by worker.
+ video_indexes = range(len(self._sampler))
+ worker_splits = np.array_split(video_indexes, worker_info.num_workers)
+ worker_id = worker_info.id
+ worker_split = worker_splits[worker_id]
+ if len(worker_split) == 0:
+ logger.warning(
+ f"More data workers({worker_info.num_workers}) than videos"
+ f"({len(self._sampler)}). For optimal use of processes "
+ "reduce num_workers."
+ )
+ return iter(())
+
+ iter_start = worker_split[0]
+ iter_end = worker_split[-1] + 1
+ worker_sampler = itertools.islice(iter(self._sampler), iter_start, iter_end)
+ else:
+
+ # If no worker processes found, we return the full sampler.
+ worker_sampler = iter(self._sampler)
+
+ return worker_sampler
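+# Worked example (illustrative): with a sampler over 10 videos and a DataLoader using
+# num_workers=3, np.array_split(range(10), 3) yields index ranges [0..3], [4..6] and
+# [7..9], so worker 0 iterates the first 4 sampler entries, worker 1 the next 3 and
+# worker 2 the last 3. Outside a worker process the full sampler is returned unchanged.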
+
+
+def optional_threaded_foreach(
+ target: Callable, args_iterable: Iterable[Tuple], multithreaded: bool
+):
+ """
+ Applies the 'target' function to each tuple of args in 'args_iterable'.
+ If 'multithreaded', a thread is spawned for each function application.
+
+ Args:
+ target (Callable):
+ A function that takes as input the parameters in each args_iterable Tuple.
+
+ args_iterable (Iterable[Tuple]):
+ An iterable of the tuples each containing a set of parameters to pass to
+ target.
+
+ multithreaded (bool):
+ Whether or not the target applications are parallelized by thread.
+ """
+
+ if multithreaded:
+ threads = []
+ for args in args_iterable:
+ thread = threading.Thread(target=target, args=args)
+ thread.start()
+ threads.append(thread)
+
+ for t in threads: # Wait for all threads to complete
+ t.join()
+ else:
+ for args in args_iterable:
+ target(*args)
+
+
+class DataclassFieldCaster:
+ """
+ Class to allow subclasses wrapped in @dataclass to automatically
+ cast fields to their relevant type by default.
+
+ Also allows for an arbitrary initialization function to be applied
+ for a given field.
+ """
+
+ COMPLEX_INITIALIZER = "DataclassFieldCaster__complex_initializer"
+
+ def __post_init__(self) -> None:
+ f"""
+ This function is run by the dataclass library after '__init__'.
+
+ Here we use this to ensure all fields are cast to their declared types
+ and to apply any complex field_initializer functions that have been
+ declared via the 'complex_initialized_dataclass_field' method of
+ this class.
+
+ A complex field_initializer for a given field would be stored in the
+ field.metadata dictionary at:
+ key = '{self.COMPLEX_INITIALIZER}' (self.COMPLEX_INITIALIZER)
+
+ """
+ for field in dataclass_fields(self):
+ value = getattr(self, field.name)
+ # First check if the datafield has been set to the declared type or
+ # if the datafield has a declared complex field_initializer.
+ if (
+ not isinstance(value, field.type)
+ or DataclassFieldCaster.COMPLEX_INITIALIZER in field.metadata
+ ):
+ # Apply the complex field_initializer function for this field's value,
+ # assert that the resultant type is the declared type of the field.
+ if DataclassFieldCaster.COMPLEX_INITIALIZER in field.metadata:
+ setattr(
+ self,
+ field.name,
+ field.metadata[DataclassFieldCaster.COMPLEX_INITIALIZER](value),
+ )
+ assert isinstance(getattr(self, field.name), field.type), (
+ f"'field_initializer' function of {field.name} must return "
+ f"type {field.type} but returned type {type(getattr(self, field.name))}"
+ )
+ else:
+ # Otherwise attempt to cast the field's value to its declared type.
+ setattr(self, field.name, field.type(value))
+
+ @staticmethod
+ def complex_initialized_dataclass_field(
+ field_initializer: Callable, **kwargs
+ ) -> Field:
+ """
+ Allows for the setting of a function to be called on the
+ named parameter associated with a field during initialization,
+ after __init__() completes.
+
+ Args:
+ field_initializer (Callable):
+ The function to be called on the field
+
+ **kwargs: To be passed downstream to the dataclasses.field method
+
+ Returns:
+ (dataclasses.Field) that contains the field_initializer and kwargs info.
+ """
+ metadata = kwargs.get("metadata") or {}
+ assert DataclassFieldCaster.COMPLEX_INITIALIZER not in metadata
+ metadata[DataclassFieldCaster.COMPLEX_INITIALIZER] = field_initializer
+ kwargs["metadata"] = metadata
+ return dataclass_field(**kwargs)
+
+
+def load_dataclass_dict_from_csv(
+ input_csv_file_path: str,
+ dataclass_class: type,
+ dict_key_field: str,
+ list_per_key: bool = False,
+) -> Dict[Any, Union[Any, List[Any]]]:
+ """
+ Args:
+ input_csv_file_path (str): File path of the csv to read from
+ dataclass_class (type): The dataclass to read each row into.
+ dict_key_field (str): The field of 'dataclass_class' to use as
+ the dictionary key.
+ list_per_key (bool) = False: If the output data structure
+ contains a list of dataclass objects per key, rather than a
+ single unique dataclass object.
+
+ Returns:
+ Dict[Any, Union[Any, List[Any]]] mapping from the dataclass
+ value at attr = dict_key_field to either:
+
+ if 'list_per_key', a list of all dataclass objects that
+ have equal values at attr = dict_key_field, equal to the key
+
+ if not 'list_per_key', the unique dataclass object
+ for which the value at attr = dict_key_field is equal to the key
+
+ Raises:
+ AssertionError: if not 'list_per_key' and there are
+ dataclass objects with equal values at attr = dict_key_field
+ """
+
+ output_dict = defaultdict(list) if list_per_key else {}
+ with g_pathmgr.open(input_csv_file_path) as dataclass_file:
+ reader = csv.reader(dataclass_file, delimiter=",", quotechar='"')
+ column_index = {header: i for i, header in enumerate(next(reader))}
+ for line in reader:
+ datum = dataclass_class(
+ *(
+ line[column_index[field.name]]
+ for field in dataclass_fields(dataclass_class)
+ )
+ )
+ dict_key = getattr(datum, dict_key_field)
+ if list_per_key:
+ output_dict[dict_key].append(datum)
+ else:
+ assert (
+ dict_key not in output_dict
+ ), f"Multiple entries for {output_dict} in {dataclass_file}"
+ output_dict[dict_key] = datum
+ return output_dict
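+# A minimal sketch (illustrative; the dataclass and "frames.csv" are assumptions):
+#
+#   @dataclass
+#   class FrameRow(DataclassFieldCaster):
+#       video_id: str
+#       frame_id: int
+#       path: str
+#
+#   # frames.csv starts with a "video_id,frame_id,path" header row; each data row is
+#   # cast to a FrameRow (frame_id becomes int) and grouped into a list per video_id.
+#   rows_by_video = load_dataclass_dict_from_csv(
+#       "frames.csv", FrameRow, "video_id", list_per_key=True
+#   )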
+
+
+def save_dataclass_objs_to_headered_csv(
+ dataclass_objs: List[Any], file_name: str
+) -> None:
+ """
+ Saves a list of @dataclass objects to the specified csv file.
+
+ Args:
+ dataclass_objs (List[Any]):
+ A list of @dataclass objects to be saved.
+
+ file_name (str):
+ file_name to save csv data to.
+ """
+ dataclass_type = type(dataclass_objs[0])
+ field_names = [f.name for f in dataclass_fields(dataclass_type)]
+ with g_pathmgr.open(file_name, "w") as f:
+ writer = csv.writer(f, delimiter=",", quotechar='"')
+ writer.writerow(field_names)
+ for obj in dataclass_objs:
+ writer.writerow([getattr(obj, f) for f in field_names])
+
+
+def get_logger(name: str) -> logging.Logger:
+ logger: logging.Logger = logging.getLogger(name)
+ logger.setLevel(logging.INFO)
+ if not logger.hasHandlers():
+ sh = logging.StreamHandler(sys.stdout)
+ sh.setFormatter(
+ logging.Formatter(
+ "[%(asctime)s] %(levelname)s %(message)s \t[%(filename)s.%(funcName)s:%(lineno)d]", # noqa
+ datefmt="%y%m%d %H:%M:%S",
+ )
+ )
+ logger.addHandler(sh)
+ return logger
diff --git a/code/pytorchvideo/pytorchvideo/data/video.py b/code/pytorchvideo/pytorchvideo/data/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..090edb5ae9479dfdb321d1960462878e11bcadd0
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/data/video.py
@@ -0,0 +1,101 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from abc import ABC, abstractmethod
+from typing import BinaryIO, Dict, Optional
+
+import torch
+from iopath.common.file_io import g_pathmgr
+
+
+class VideoPathHandler(object):
+ """
+ Utility class that handles all deciphering and caching of video paths for
+ encoded and frame videos.
+ """
+
+ def __init__(self) -> None:
+ # PathManager isn't guaranteed to list paths in a deterministic order and
+ # sorting is expensive, so we cache the sorted paths for frame videos and reuse them.
+ self.path_order_cache = {}
+
+ def video_from_path(
+ self, filepath, decode_video=True, decode_audio=False, decoder="pyav", fps=30
+ ):
+ try:
+ is_file = g_pathmgr.isfile(filepath)
+ is_dir = g_pathmgr.isdir(filepath)
+ except NotImplementedError:
+
+ # Not all PathManager handlers support is{file,dir} functions, when this is the
+ # case, we default to assuming the path is a file.
+ is_file = True
+ is_dir = False
+
+ if is_file:
+ from pytorchvideo.data.encoded_video import EncodedVideo
+
+ return EncodedVideo.from_path(
+ filepath,
+ decode_video=decode_video,
+ decode_audio=decode_audio,
+ decoder=decoder,
+ )
+ elif is_dir:
+ from pytorchvideo.data.frame_video import FrameVideo
+
+ assert not decode_audio, "decode_audio must be False when using FrameVideo"
+ return FrameVideo.from_directory(
+ filepath, fps, path_order_cache=self.path_order_cache
+ )
+ else:
+ raise FileNotFoundError(f"{filepath} not found.")
+
+
+class Video(ABC):
+ """
+ Video provides an interface to access clips from a video container.
+ """
+
+ @abstractmethod
+ def __init__(
+ self,
+ file: BinaryIO,
+ video_name: Optional[str] = None,
+ decode_audio: bool = True,
+ ) -> None:
+ """
+ Args:
+ file (BinaryIO): a file-like object (e.g. io.BytesIO or io.StringIO) that
+ contains the encoded video.
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def duration(self) -> float:
+ """
+ Returns:
+ duration of the video in seconds
+ """
+ pass
+
+ @abstractmethod
+ def get_clip(
+ self, start_sec: float, end_sec: float
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Retrieves frames from the internal video at the specified start and end times
+ in seconds (the video always starts at 0 seconds).
+
+ Args:
+ start_sec (float): the clip start time in seconds
+ end_sec (float): the clip end time in seconds
+ Returns:
+ video_data_dictionary: A dictionary mapping strings to tensors of the clip's
+ underlying data.
+
+ """
+ pass
+
+ def close(self):
+ pass
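+
+
+# A minimal usage sketch (illustrative; the file path is an assumption):
+#
+#   handler = VideoPathHandler()
+#   video = handler.video_from_path("clips/demo.mp4", decode_audio=False)
+#   clip = video.get_clip(start_sec=0.0, end_sec=2.0)   # {"video": tensor, "audio": None}
+#   video.close()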
diff --git a/code/pytorchvideo/pytorchvideo/layers/__init__.py b/code/pytorchvideo/pytorchvideo/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0715e08c2d3d8d6b4e9c8cdd31a4ff40719f49
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from .attention import Mlp, MultiScaleAttention, MultiScaleBlock
+from .attention_torchscript import ScriptableMultiScaleBlock
+from .drop_path import DropPath
+from .fusion import ConcatFusion, make_fusion_layer, ReduceFusion
+from .mlp import make_multilayer_perceptron
+from .positional_encoding import PositionalEncoding, SpatioTemporalClsPositionalEncoding
+from .positional_encoding_torchscript import (
+ ScriptableSpatioTemporalClsPositionalEncoding,
+)
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/__init__.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/__init__.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..db74384cc41ba17fb6fedd4b1aa3c5eefc0d2dba
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py
@@ -0,0 +1,103 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""
+This file contains the activation functions supported in efficient blocks, plus helper
+code. All supported activation functions are child classes of EfficientBlockBase and are
+included in supported_act_functions.
+"""
+import torch
+import torch.nn as nn
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+from pytorchvideo.layers.swish import Swish as SwishCustomOp
+
+
+class _NaiveSwish(nn.Module):
+ """
+ Helper class implementing naive swish for deployment. It is not intended to be used
+ to build networks.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.mul_func = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ return self.mul_func.mul(x, torch.sigmoid(x))
+
+
+class Swish(EfficientBlockBase):
+ """
+ Swish activation function for efficient blocks. In its original form (for training),
+ it uses the custom-op version of swish for better training memory efficiency. In its
+ deployable form, it uses naive swish, as the custom op cannot run on PyTorch Mobile.
+ For better latency on mobile CPU, use HardSwish instead.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.act = SwishCustomOp()
+
+ def forward(self, x):
+ return self.act(x)
+
+ def convert(self, *args, **kwarg):
+ self.act = _NaiveSwish()
+
+
+class HardSwish(EfficientBlockBase):
+ """
+ Hardswish activation function. It is natively supported by PyTorch Mobile and has
+ better latency than Swish in int8 mode.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.act = nn.Hardswish()
+
+ def forward(self, x):
+ return self.act(x)
+
+ def convert(self, *args, **kwarg):
+ pass
+
+
+class ReLU(EfficientBlockBase):
+ """
+ ReLU activation function for EfficientBlockBase.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.act = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.act(x)
+
+ def convert(self, *args, **kwarg):
+ pass
+
+
+class Identity(EfficientBlockBase):
+ """
+ Identity operation for EfficientBlockBase.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.act = nn.Identity()
+
+ def forward(self, x):
+ return self.act(x)
+
+ def convert(self, *args, **kwarg):
+ pass
+
+
+supported_act_functions = {
+ "relu": ReLU,
+ "swish": Swish,
+ "hswish": HardSwish,
+ "identity": Identity,
+}
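+
+
+# A minimal sketch (illustrative): activations are looked up by name from
+# supported_act_functions and converted in place before deployment.
+#
+#   act = supported_act_functions["swish"]()    # original form (custom-op swish)
+#   act.convert()                               # swaps in _NaiveSwish for PyTorch Mobile
+#   y = act(torch.randn(1, 8, 4, 4, 4))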
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/attention.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a6309e4a30e2f12dab03b94273d98ed17a2e888
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/attention.py
@@ -0,0 +1,109 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from copy import deepcopy
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from fvcore.nn.squeeze_excitation import SqueezeExcitation as SqueezeExcitationFVCore
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+
+from .conv_helper import _Reshape, _SkipConnectMul
+
+
+class SqueezeExcitation(EfficientBlockBase):
+ """
+ Efficient Squeeze-Excitation (SE). The Squeeze-Excitation block is described in:
+ *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
+ This implementation has the same instantiation interface as the SE implementation in
+ fvcore, and in its original mode for training it is just a wrapped version of the
+ fvcore SE. Since the conv3d in the original fvcore SE implementation is not well
+ supported by QNNPACK, a convert() method is implemented here which converts the class
+ instance into an equivalent, efficient, deployable form.
+
+ convert_flag variable is to record whether the SqueezeExcitation instance
+ has been converted; SqueezeExcitation is in original form if convert_flag is false,
+ while it is in deployable form if convert_flag is true.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ num_channels_reduced: Optional[int] = None,
+ reduction_ratio: float = 2.0,
+ is_3d: bool = False,
+ activation: Optional[nn.Module] = None,
+ ) -> None:
+ """
+ Args:
+ num_channels (int): Number of input channels.
+ num_channels_reduced (int):
+ Number of reduced channels. If none, uses reduction_ratio to calculate.
+ reduction_ratio (float):
+ How much num_channels should be reduced if num_channels_reduced is not provided.
+ is_3d (bool): Whether we're operating on 3d data (or 2d), default 2d.
+ activation (nn.Module): Activation function used, defaults to ReLU.
+ """
+ super().__init__()
+ # Implement SE from FVCore here for training.
+ self.se = SqueezeExcitationFVCore(
+ num_channels,
+ num_channels_reduced=num_channels_reduced,
+ reduction_ratio=reduction_ratio,
+ is_3d=is_3d,
+ activation=activation,
+ )
+ self.is_3d = is_3d
+ self.convert_flag = False
+
+ def convert(self, input_blob_size, **kwargs):
+ """
+ Converts into an efficient version of squeeze-excite (SE) for CPU.
+ It changes the convs in the original SE into linear layers (better supported on CPU).
+ """
+ if self.is_3d:
+ avg_pool = nn.AdaptiveAvgPool3d(1)
+ else:
+ avg_pool = nn.AdaptiveAvgPool2d(1)
+ """
+ Reshape tensor size to (B, C) for linear layer.
+ """
+ reshape0 = _Reshape((input_blob_size[0], input_blob_size[1]))
+ fc0 = nn.Linear(
+ self.se.block[0].in_channels,
+ self.se.block[0].out_channels,
+ bias=(not (self.se.block[0].bias is None)),
+ )
+ state_dict_fc0 = deepcopy(self.se.block[0].state_dict())
+ state_dict_fc0["weight"] = state_dict_fc0["weight"].squeeze()
+ fc0.load_state_dict(state_dict_fc0)
+ activation = deepcopy(self.se.block[1])
+ fc1 = nn.Linear(
+ self.se.block[2].in_channels,
+ self.se.block[2].out_channels,
+ bias=(not (self.se.block[2].bias is None)),
+ )
+ state_dict_fc1 = deepcopy(self.se.block[2].state_dict())
+ state_dict_fc1["weight"] = state_dict_fc1["weight"].squeeze()
+ fc1.load_state_dict(state_dict_fc1)
+ sigmoid = deepcopy(self.se.block[3])
+ """
+ Output of linear layer has output shape of (B, C). Need to reshape to proper
+ shape before multiplying with input tensor.
+ """
+ reshape_size_after_sigmoid = (input_blob_size[0], input_blob_size[1], 1, 1) + (
+ (1,) if self.is_3d else ()
+ )
+ reshape1 = _Reshape(reshape_size_after_sigmoid)
+ se_layers = nn.Sequential(
+ avg_pool, reshape0, fc0, activation, fc1, sigmoid, reshape1
+ )
+ # Add final elementwise multiplication and replace self.se
+ self.se = _SkipConnectMul(se_layers)
+ self.convert_flag = True
+
+ def forward(self, x) -> torch.Tensor:
+ out = self.se(x)
+ return out
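+
+
+# A minimal usage sketch (illustrative shapes): the block trains in its fvcore form and
+# is converted to the linear-layer form for a fixed input blob size before deployment.
+#
+#   se = SqueezeExcitation(num_channels=16, is_3d=True)
+#   x = torch.randn(1, 16, 4, 8, 8)               # (B, C, T, H, W)
+#   y_train = se(x)                               # original (fvcore) form
+#   se.convert(input_blob_size=(1, 16, 4, 8, 8))
+#   y_deploy = se(x)                              # deployable form, same output shape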
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d9d7c228c92e93dca3ca889ed62c9c8350ab955
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py
@@ -0,0 +1,556 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""
+This file contains helper classes for building conv3d efficient blocks.
+The helper classes are intended to be instantiated inside efficient blocks,
+not to be used by users to build networks.
+"""
+
+from copy import deepcopy
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+
+class _Reshape(nn.Module):
+ """
+ Helper class to implement data reshape as a module.
+ Args:
+ reshape_size (tuple): size of data after reshape.
+ """
+
+ def __init__(
+ self,
+ reshape_size: Tuple,
+ ):
+ super().__init__()
+ self.reshape_size = reshape_size
+
+ def forward(self, x):
+ return torch.reshape(x, self.reshape_size)
+
+
+class _SkipConnectMul(nn.Module):
+ """
+ Helper class to implement skip multiplication.
+ Args:
+ layer (nn.Module): layer for skip multiplication. With input x, _SkipConnectMul
+ implements layer(x)*x.
+ """
+
+ def __init__(
+ self,
+ layer: nn.Module,
+ ):
+ super().__init__()
+ self.layer = layer
+ self.mul_func = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ return self.mul_func.mul(x, self.layer(x))
+
+
+class _Conv3dTemporalKernel3Decomposed(nn.Module):
+ """
+ Helper class for decomposing conv3d with temporal kernel of 3 into equivalent conv2ds.
+ In conv3d with temporal kernel 3 and input I, for output temporal index of t (O[:,:,t,:,:]),
+ the conv can be expressed as:
+ O[:,:,t,:,:] = conv3d(I[:,:,t:t+3,:,:])
+ = conv2d_0(I[:,:,t,:,:]) + conv2d_1(I[:,:,t+1,:,:]) + conv2d_2(I[:,:,t+2,:,:])
+ If bias is considered:
+ O[:,:,t,:,:] = conv3d_w_bias(I[:,:,t:t+3,:,:])
+ = conv2d_0_wo_bias(I[:,:,t,:,:])
+ + conv2d_1_w_bias(I[:,:,t+1,:,:]) + conv2d_2_wo_bias(I[:,:,t+2,:,:])
+ The input Conv3d also needs zero padding of size 1 in temporal dimension.
+ """
+
+ def __init__(
+ self,
+ conv3d_in: nn.Conv3d,
+ input_THW_tuple: Tuple,
+ ):
+ """
+ Args:
+ conv3d_in (nn.Module): input nn.Conv3d module to be converted
+ into equivalent conv2d.
+ input_THW_tuple (tuple): input THW size for conv3d_in during forward.
+ """
+ super().__init__()
+ assert conv3d_in.padding[0] == 1, (
+ "_Conv3dTemporalKernel3Eq only support temporal padding of 1, "
+ f"but got {conv3d_in.padding[0]}"
+ )
+ assert conv3d_in.padding_mode == "zeros", (
+ "_Conv3dTemporalKernel3Eq only support zero padding, "
+ f"but got {conv3d_in.padding_mode}"
+ )
+ self._input_THW_tuple = input_THW_tuple
+ padding_2d = conv3d_in.padding[1:]
+ in_channels = conv3d_in.in_channels
+ out_channels = conv3d_in.out_channels
+ kernel_size = conv3d_in.kernel_size[1:]
+ groups = conv3d_in.groups
+ stride_2d = conv3d_in.stride[1:]
+ # Create 3 conv2d to emulate conv3d.
+ if (
+ self._input_THW_tuple[0] > 1
+ ): # Those two conv2d are needed only when temporal input > 1.
+ self._conv2d_3_3_0 = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=padding_2d,
+ stride=stride_2d,
+ groups=groups,
+ bias=False,
+ )
+ self._conv2d_3_3_2 = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=padding_2d,
+ stride=stride_2d,
+ groups=groups,
+ bias=False,
+ )
+ self._conv2d_3_3_1 = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=padding_2d,
+ stride=stride_2d,
+ groups=groups,
+ bias=(conv3d_in.bias is not None),
+ )
+
+ state_dict = conv3d_in.state_dict()
+ state_dict_1 = deepcopy(state_dict)
+ state_dict_1["weight"] = state_dict["weight"][:, :, 1]
+ self._conv2d_3_3_1.load_state_dict(state_dict_1)
+
+ if self._input_THW_tuple[0] > 1:
+ state_dict_0 = deepcopy(state_dict)
+ state_dict_0["weight"] = state_dict["weight"][:, :, 0]
+ if conv3d_in.bias is not None:
+ """
+ Don't need bias for other conv2d instances to avoid duplicated addition of bias.
+ """
+ state_dict_0.pop("bias")
+ self._conv2d_3_3_0.load_state_dict(state_dict_0)
+
+ state_dict_2 = deepcopy(state_dict)
+ state_dict_2["weight"] = state_dict["weight"][:, :, 2]
+ if conv3d_in.bias is not None:
+ state_dict_2.pop("bias")
+ self._conv2d_3_3_2.load_state_dict(state_dict_2)
+
+ self._add_funcs = nn.ModuleList(
+ [
+ nn.quantized.FloatFunctional()
+ for _ in range(2 * (self._input_THW_tuple[0] - 1))
+ ]
+ )
+ self._cat_func = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """
+ Use three conv2d to emulate conv3d.
+ This forward assumes zero padding of size 1 in temporal dimension.
+ """
+ if self._input_THW_tuple[0] > 1:
+ out_tensor_list = []
+ """
+ First output plane in temporal dimension,
+ conv2d_3_3_0 is skipped due to zero padding.
+ """
+ cur_tensor = (
+ self._add_funcs[0]
+ .add(self._conv2d_3_3_1(x[:, :, 0]), self._conv2d_3_3_2(x[:, :, 1]))
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ for idx in range(2, self._input_THW_tuple[0]):
+ cur_tensor = (
+ self._add_funcs[2 * idx - 3]
+ .add(
+ self._add_funcs[2 * idx - 2].add(
+ self._conv2d_3_3_0(x[:, :, idx - 2]),
+ self._conv2d_3_3_1(x[:, :, idx - 1]),
+ ),
+ self._conv2d_3_3_2(x[:, :, idx]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ """
+ Last output plane in temporal domain, conv2d_3_3_2 is skipped due to zero padding.
+ """
+ cur_tensor = (
+ self._add_funcs[-1]
+ .add(self._conv2d_3_3_0(x[:, :, -2]), self._conv2d_3_3_1(x[:, :, -1]))
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ return self._cat_func.cat(out_tensor_list, 2)
+ else: # Degenerated to simple conv2d
+ return self._conv2d_3_3_1(x[:, :, 0]).unsqueeze(2)
+
+
+class _Conv3dTemporalKernel5Decomposed(nn.Module):
+ """
+ Helper class for decomposing conv3d with kernel size of (5, k, k) into equivalent conv2ds.
+ In such conv3d and input I, for output temporal index of t (O[:,:,t,:,:]), the conv
+ can be expressed as:
+ O[:,:,t,:,:] = conv3d(I[:,:,t:t+5,:,:])
+ = conv2d_0(I[:,:,t,:,:]) + conv2d_1(I[:,:,t+1,:,:]) + conv2d_2(I[:,:,t+2,:,:])
+ + conv2d_3(I[:,:,t+3,:,:]) + conv2d_4(I[:,:,t+4,:,:])
+ If bias is considered:
+ O[:,:,t,:,:] = conv3d_w_bias(I[:,:,t:t+5,:,:])
+ = conv2d_0_wo_bias(I[:,:,t,:,:])
+ + conv2d_1_wo_bias(I[:,:,t+1,:,:]) + conv2d_2_w_bias(I[:,:,t+2,:,:])
+ + conv2d_3_wo_bias(I[:,:,t+3,:,:]) + conv2d_4_wo_bias(I[:,:,t+4,:,:])
+ The input Conv3d also needs zero padding of size 2 in temporal dimension at begin and end.
+ """
+
+ def __init__(
+ self,
+ conv3d_in: nn.Conv3d,
+ thw_shape: Tuple[int, int, int],
+ ):
+ """
+ Args:
+ conv3d_in (nn.Module): input nn.Conv3d module to be converted
+ into equivalent conv2d.
+ thw_shape (tuple): input THW size for conv3d_in during forward.
+ """
+ super().__init__()
+ assert conv3d_in.padding[0] == 2, (
+ "_Conv3dTemporalKernel5Eq only support temporal padding of 2, "
+ f"but got {conv3d_in.padding[0]}"
+ )
+ assert conv3d_in.padding_mode == "zeros", (
+ "_Conv3dTemporalKernel5Eq only support zero padding, "
+ f"but got {conv3d_in.padding_mode}"
+ )
+ self._thw_shape = thw_shape
+ padding_2d = conv3d_in.padding[1:]
+ in_channels = conv3d_in.in_channels
+ out_channels = conv3d_in.out_channels
+ kernel_size = conv3d_in.kernel_size[1:]
+ groups = conv3d_in.groups
+ stride_2d = conv3d_in.stride[1:]
+ # Create 5 conv2d to emulate conv3d.
+ t, h, w = self._thw_shape
+ args_dict = {
+ "in_channels": in_channels,
+ "out_channels": out_channels,
+ "kernel_size": kernel_size,
+ "padding": padding_2d,
+ "stride": stride_2d,
+ "groups": groups,
+ }
+
+ for iter_idx in range(5):
+ if iter_idx != 2:
+ if t > 1: # Those four conv2d are needed only when temporal input > 1.
+ self.add_module(
+ f"_conv2d_{iter_idx}", nn.Conv2d(**args_dict, bias=False)
+ )
+ else: # _conv2d_2 is needed for all circumstances.
+ self.add_module(
+ f"_conv2d_{iter_idx}",
+ nn.Conv2d(**args_dict, bias=(conv3d_in.bias is not None)),
+ )
+
+ # State dict for _conv2d_2
+ original_state_dict = conv3d_in.state_dict()
+ state_dict_to_load = deepcopy(original_state_dict)
+ state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 2]
+ self._conv2d_2.load_state_dict(state_dict_to_load)
+
+ if t > 1:
+ if conv3d_in.bias is not None:
+ # Don't need bias for other conv2d instances to avoid duplicated
+ # addition of bias.
+ state_dict_to_load.pop("bias")
+ # State dict for _conv2d_0, _conv2d_1, _conv2d_3, _conv2d_4
+ state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 0]
+ self._conv2d_0.load_state_dict(state_dict_to_load)
+
+ state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 1]
+ self._conv2d_1.load_state_dict(state_dict_to_load)
+
+ state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 3]
+ self._conv2d_3.load_state_dict(state_dict_to_load)
+
+ state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 4]
+ self._conv2d_4.load_state_dict(state_dict_to_load)
+ # Elementwise add are needed in forward function, use nn.quantized.FloatFunctional()
+ # for better quantization support. One convolution needs at most 4 elementwise adds
+ # without zero padding; for boundary planes fewer elementwise adds are needed.
+ # See forward() for more details.
+ self._add_funcs = nn.ModuleList(
+ [nn.quantized.FloatFunctional() for _ in range(4 * t - 6)]
+ )
+ self._cat_func = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """
+ Use five conv2d to emulate conv3d.
+ Args:
+ x (torch.Tensor): 5D tensor of (B, C, T, H, W)
+ """
+ t, h, w = self._thw_shape
+ out_tensor_list = []
+ if (
+ t == 1
+ ): # Degenerated to simple conv2d, but make sure output still has T dimension
+ return self._conv2d_2(x[:, :, 0]).unsqueeze(2)
+ elif t == 2:
+ # out_tensor_list[0]: conv2d_1_1_0, conv2d_1_1_1 and conv2d_1_1_4 are
+ # applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[0]
+ .add(self._conv2d_2(x[:, :, 0]), self._conv2d_3(x[:, :, 1]))
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ # out_tensor_list[1]: conv2d_1_1_0, conv2d_1_1_3 and conv2d_1_1_4 are
+ # applied to zero padding.
+
+ cur_tensor = (
+ self._add_funcs[1]
+ .add(self._conv2d_1(x[:, :, 0]), self._conv2d_2(x[:, :, 1]))
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ elif t == 3:
+ # out_tensor_list[0]: conv2d_1_1_0, conv2d_1_1_1 are applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[0]
+ .add(
+ self._add_funcs[1].add(
+ self._conv2d_2(x[:, :, 0]), self._conv2d_3(x[:, :, 1])
+ ),
+ self._conv2d_4(x[:, :, 2]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ # out_tensor_list[1]: conv2d_1_1_0, conv2d_1_1_4 are applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[2]
+ .add(
+ self._add_funcs[3].add(
+ self._conv2d_1(x[:, :, 0]), self._conv2d_2(x[:, :, 1])
+ ),
+ self._conv2d_3(x[:, :, 2]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ # out_tensor_list[2]: conv2d_1_1_3, conv2d_1_1_4 are applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[4]
+ .add(
+ self._add_funcs[5].add(
+ self._conv2d_0(x[:, :, 0]), self._conv2d_1(x[:, :, 1])
+ ),
+ self._conv2d_2(x[:, :, 2]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ elif t == 4:
+ # out_tensor_list[0]: conv2d_1_1_0, conv2d_1_1_1 are applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[0]
+ .add(
+ self._add_funcs[1].add(
+ self._conv2d_2(x[:, :, 0]), self._conv2d_3(x[:, :, 1])
+ ),
+ self._conv2d_4(x[:, :, 2]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ # out_tensor_list[1]: conv2d_1_1_0 is applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[2]
+ .add(
+ self._add_funcs[3].add(
+ self._add_funcs[4].add(
+ self._conv2d_1(x[:, :, 0]),
+ self._conv2d_2(x[:, :, 1]),
+ ),
+ self._conv2d_3(x[:, :, 2]),
+ ),
+ self._conv2d_4(x[:, :, 3]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ # out_tensor_list[2]: conv2d_1_1_4 is applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[5]
+ .add(
+ self._add_funcs[6].add(
+ self._add_funcs[7].add(
+ self._conv2d_0(x[:, :, 0]),
+ self._conv2d_1(x[:, :, 1]),
+ ),
+ self._conv2d_2(x[:, :, 2]),
+ ),
+ self._conv2d_3(x[:, :, 3]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ # out_tensor_list[3]: conv2d_1_1_3, conv2d_1_1_4 are applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[8]
+ .add(
+ self._add_funcs[9].add(
+ self._conv2d_0(x[:, :, 1]), self._conv2d_1(x[:, :, 2])
+ ),
+ self._conv2d_2(x[:, :, 3]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ else: # t >= 5
+ # out_tensor_list[0]: conv2d_1_1_0, conv2d_1_1_1 are applied to zero padding.
+ add_func_idx_base = 0
+ cur_tensor = (
+ self._add_funcs[add_func_idx_base]
+ .add(
+ self._add_funcs[add_func_idx_base + 1].add(
+ self._conv2d_2(x[:, :, 0]), self._conv2d_3(x[:, :, 1])
+ ),
+ self._conv2d_4(x[:, :, 2]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ add_func_idx_base += 2
+ # out_tensor_list[1]: conv2d_1_1_0 is applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[add_func_idx_base]
+ .add(
+ self._add_funcs[add_func_idx_base + 1].add(
+ self._add_funcs[add_func_idx_base + 2].add(
+ self._conv2d_1(x[:, :, 0]),
+ self._conv2d_2(x[:, :, 1]),
+ ),
+ self._conv2d_3(x[:, :, 2]),
+ ),
+ self._conv2d_4(x[:, :, 3]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ add_func_idx_base += 3
+ # out_tensor_list[2:-2]: zero padding has no effect.
+ for idx in range(4, t):
+ cur_tensor = (
+ self._add_funcs[add_func_idx_base]
+ .add(
+ self._add_funcs[add_func_idx_base + 1].add(
+ self._add_funcs[add_func_idx_base + 2].add(
+ self._add_funcs[add_func_idx_base + 3].add(
+ self._conv2d_0(x[:, :, idx - 4]),
+ self._conv2d_1(x[:, :, idx - 3]),
+ ),
+ self._conv2d_2(x[:, :, idx - 2]),
+ ),
+ self._conv2d_3(x[:, :, idx - 1]),
+ ),
+ self._conv2d_4(x[:, :, idx]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ add_func_idx_base += 4
+ # out_tensor_list[-2]: conv2d_1_1_4 is applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[add_func_idx_base]
+ .add(
+ self._add_funcs[add_func_idx_base + 1].add(
+ self._add_funcs[add_func_idx_base + 2].add(
+ self._conv2d_0(x[:, :, -4]),
+ self._conv2d_1(x[:, :, -3]),
+ ),
+ self._conv2d_2(x[:, :, -2]),
+ ),
+ self._conv2d_3(x[:, :, -1]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ add_func_idx_base += 3
+ # out_tensor_list[-1]: conv2d_1_1_3, conv2d_1_1_4 are applied to zero padding.
+ cur_tensor = (
+ self._add_funcs[add_func_idx_base]
+ .add(
+ self._add_funcs[add_func_idx_base + 1].add(
+ self._conv2d_0(x[:, :, -3]),
+ self._conv2d_1(x[:, :, -2]),
+ ),
+ self._conv2d_2(x[:, :, -1]),
+ )
+ .unsqueeze(2)
+ )
+ out_tensor_list.append(cur_tensor)
+ return self._cat_func.cat(out_tensor_list, 2)
+
+
+class _Conv3dTemporalKernel1Decomposed(nn.Module):
+ """
+ Helper class for decomposing conv3d with temporal kernel of 1 into conv2d on
+ multiple temporal planes.
+ In conv3d with temporal kernel 1 and input I, for output temporal index of t (O[:,:,t,:,:]),
+ the conv can be expressed as:
+ O[:,:,t,:,:] = conv3d(I[:,:,t,:,:])
+ = conv2d(I[:,:,t,:,:])
+ The full output can be obtained by concat O[:,:,t,:,:] for t in 0...T,
+ where T is the length of I in temporal dimension.
+ """
+
+ def __init__(
+ self,
+ conv3d_eq: nn.Conv3d,
+ input_THW_tuple: Tuple,
+ ):
+ """
+ Args:
+ conv3d_eq (nn.Module): input nn.Conv3d module to be converted
+ into equivalent conv2d.
+ input_THW_tuple (tuple): input THW size for conv3d_eq during forward.
+ """
+ super().__init__()
+ # create equivalent conv2d module
+ in_channels = conv3d_eq.in_channels
+ out_channels = conv3d_eq.out_channels
+ bias_flag = conv3d_eq.bias is not None
+ self.conv2d_eq = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=(conv3d_eq.kernel_size[1], conv3d_eq.kernel_size[2]),
+ stride=(conv3d_eq.stride[1], conv3d_eq.stride[2]),
+ groups=conv3d_eq.groups,
+ bias=bias_flag,
+ padding=(conv3d_eq.padding[1], conv3d_eq.padding[2]),
+ dilation=(conv3d_eq.dilation[1], conv3d_eq.dilation[2]),
+ )
+ state_dict = conv3d_eq.state_dict()
+ state_dict["weight"] = state_dict["weight"].squeeze(2)
+ self.conv2d_eq.load_state_dict(state_dict)
+ self.input_THW_tuple = input_THW_tuple
+
+ def forward(self, x):
+ out_tensor_list = []
+ for idx in range(self.input_THW_tuple[0]):
+ cur_tensor = self.conv2d_eq(x[:, :, idx]).unsqueeze(2)
+ out_tensor_list.append(cur_tensor)
+ return torch.cat(out_tensor_list, 2)
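+
+
+# A small equivalence sketch (illustrative): the decomposed module should reproduce the
+# original conv3d (up to floating point accumulation order) when the temporal padding
+# is the assumed zero padding of size 1.
+#
+#   conv3d = nn.Conv3d(4, 8, kernel_size=(3, 3, 3), padding=(1, 1, 1))
+#   decomposed = _Conv3dTemporalKernel3Decomposed(conv3d, input_THW_tuple=(4, 8, 8))
+#   x = torch.randn(1, 4, 4, 8, 8)                # (B, C, T, H, W)
+#   assert torch.allclose(conv3d(x), decomposed(x), atol=1e-5)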
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e29b074855f0a2f766c7495047e70b435bd953
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py
@@ -0,0 +1,629 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+from collections import OrderedDict
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+
+from .activation_functions import supported_act_functions
+from .conv_helper import (
+ _Conv3dTemporalKernel1Decomposed,
+ _Conv3dTemporalKernel3Decomposed,
+ _Conv3dTemporalKernel5Decomposed,
+ _Reshape,
+)
+
+
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 11):
+ from torch.ao.quantization import fuse_modules
+else:
+ from torch.quantization import fuse_modules
+
+
+class Conv3dPwBnAct(EfficientBlockBase):
+ """
+ Implements Conv3d + Bn + Activation for pointwise layers.
+ The conv layer has fixed kernel_size = (1,1,1),
+ groups = 1, padding = 0, stride = 1, dilation = 1.
+
+ Input
+ |
+ ↓
+ conv3d (1x1x1)
+ ↓
+ BatchNorm (optional)
+ ↓
+ Activation
+
+ Conv3dPwBnAct is in its original form (for training) once instantiated. Users can
+ call the convert() method to turn it into its deployable form.
+
+ convert_flag variable is to record whether the Conv3dPwBnAct instance
+ has been converted; Conv3dPwBnAct is in original form if convert_flag is false,
+ while it is in deployable form if convert_flag is true.
+
+ Current implementation of this layer in QNNPACK is very efficient.
+ Args:
+ in_channels (int): number of input channels for conv3d 1x1x1.
+ out_channels (int): number of output channels for conv3d 1x1x1.
+ bias (bool): if true, use bias for conv.
+ activation (str): applies selected activation from supported_act_functions.
+ See activation_functions.py for more info about supported activations.
+ Currently ReLU ('relu'), Swish ('swish'), Hardswish ('hswish'), Identity
+ ('identity') are supported.
+ use_bn (bool): if true, use batchnorm.
+ norm_eps (float): epsilon for batchnorm.
+ norm_momentum (float): momentum for batchnorm.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias=False,
+ activation: str = "relu",
+ use_bn=True,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ ):
+ super().__init__()
+ self._in_channels = in_channels
+ self._out_channels = out_channels
+ self.act = activation
+ kernel = OrderedDict()
+ kernel["conv"] = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=bias)
+ if use_bn:
+ kernel["bn"] = nn.BatchNorm3d(
+ out_channels, eps=norm_eps, momentum=norm_momentum
+ )
+ assert (
+ activation in supported_act_functions
+ ), f"Conv3dPwBnAct: {activation} is not in supported_act_functions."
+ kernel["act"] = supported_act_functions[activation]()
+ self.kernel = nn.Sequential(kernel)
+ self.convert_flag = False
+
+ def convert(
+ self,
+ input_blob_size: Tuple,
+ convert_for_quantize: bool = False,
+ native_conv3d_op_qnnpack: bool = False,
+ **kwargs,
+ ):
+ """
+ Converts the block into efficient form.
+        For fp32 operation, or for quantized operation with an older version of QNNPACK
+        without native int8 Conv3d support, this function converts Conv3d into an
+        equivalent Conv2d for Pytorch Mobile deployment.
+        The Conv3d -> Conv2d conversion is done by first fusing conv3d with bn, then
+        converting conv3d into an equivalent conv2d, and optionally fusing conv2d with relu.
+ After conversion, the forwarding of this module becomes:
+ Input (5d tensor) --> reshape (4d tensor) --> conv2d (4d tensor)
+ --> reshape (5d tensor) --> output (5d tensor)
+
+        For quantized operation on a newer version of QNNPACK with native int8 Conv3d,
+        this function only applies operator fusion.
+ Args:
+ input_blob_size (tuple): blob size at the input of Conv3dPwBnAct instance.
+ convert_for_quantize (bool): whether this module is intended to be quantized.
+ native_conv3d_op_qnnpack (bool): whether the QNNPACK version has native int8
+ Conv3d.
+ kwargs (any): any extra keyword arguments from upstream unused by convert().
+ """
+ assert (
+ self.convert_flag is False
+ ), "Conv3dPwBnAct: already converted, cannot be converted again"
+ self.kernel.eval()
+ # First fuse conv and bn if bn exists.
+ if hasattr(self.kernel, "bn"):
+ self.kernel = fuse_modules(self.kernel, ["conv", "bn"])
+ # If user intends to quantize the module and their QNNPACK comes with native int8 Conv3d,
+ # then we just need to do fusion.
+ if convert_for_quantize and native_conv3d_op_qnnpack:
+ if self.act == "relu":
+ self.kernel = fuse_modules(self.kernel, ["conv", "act.act"])
+ # Set new kernel in eval mode again
+ self.kernel.eval()
+ # Else, for fp32 operation or for int8 but with older version of QNNPACK w/o native int8 Conv3d,
+ # we need to unfold Conv3d into Conv2ds.
+ else:
+ batch_size = input_blob_size[0]
+ input_THW_tuple = input_blob_size[2:]
+ self._input_tensor_reshape_size = (
+ batch_size,
+ self._in_channels, # C
+ input_THW_tuple[0] * input_THW_tuple[1], # T*H
+ input_THW_tuple[2], # W
+ )
+ self._output_tensor_size = (
+ batch_size,
+ self._out_channels, # C
+ input_THW_tuple[0], # T
+ input_THW_tuple[1], # H
+ input_THW_tuple[2], # W
+ )
+ conv2d_eq = nn.Conv2d(
+ self._in_channels,
+ self._out_channels,
+ kernel_size=1,
+ bias=(self.kernel.conv.bias is not None),
+ )
+ conv_state_dict = self.kernel.conv.state_dict()
+ conv_state_dict["weight"] = conv_state_dict["weight"].squeeze(2)
+ conv2d_eq.load_state_dict(conv_state_dict)
+ self.kernel.conv = conv2d_eq
+            # Convert activation function
+ self.kernel.act.convert(input_blob_size, **kwargs)
+ # Fuse act with conv after conv3d -> conv2d if act is relu
+ if self.act == "relu":
+ self.kernel = fuse_modules(self.kernel, ["conv", "act.act"])
+ # Insert reshape layers before/after conv2d
+ self.kernel = nn.Sequential(
+ _Reshape(self._input_tensor_reshape_size),
+ self.kernel,
+ _Reshape(self._output_tensor_size),
+ )
+ # Set new kernel in eval mode again
+ self.kernel.eval()
+ self.convert_flag = True
+
+ def forward(self, x):
+ x = self.kernel(x)
+ return x
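+
+
+# Illustrative usage sketch (not part of the original pytorchvideo source), showing
+# the instantiate -> convert -> deploy flow described in the class docstring.
+# All names and sizes below are hypothetical.
+#
+#   block = Conv3dPwBnAct(in_channels=16, out_channels=32, activation="relu")
+#   x = torch.randn(1, 16, 4, 7, 7)                    # (B, C, T, H, W)
+#   block.convert(input_blob_size=(1, 16, 4, 7, 7))    # fp32 deployable form
+#   y = block(x)                                       # reshape -> conv2d -> reshape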
+
+
+class Conv3d3x3x3DwBnAct(EfficientBlockBase):
+ """
+ Implements Conv3d (3x3x3 dw) + (optional) Bn + Activation layers.
+ The conv layer has fixed kernel_size = (3,3,3), depthwise, zero padding size of
+ (1,1,1), temporal stride = 1, dilation = 1
+
+ Input
+ |
+ ↓
+ conv3d (3x3x3 dw)
+ ↓
+ BatchNorm (optional)
+ ↓
+ Activation
+
+ Current implementation of this layer in QNNPACK is reasonably efficient.
+
+    The convert_flag attribute records whether the Conv3d3x3x3DwBnAct instance has
+    been converted: the block is in original form while convert_flag is False, and
+    in deployable form once convert_flag is True.
+
+ Args:
+ in_channels (int): number of channels for conv3d 3x3x3 dw.
+        spatial_stride (int): spatial stride for conv.
+ bias (bool): if true, use bias for conv.
+ activation (str): applies selected activation from supported_act_functions.
+ See activation_functions.py for more info about supported activations.
+ Currently ReLU ('relu'), Swish ('swish'), Hardswish ('hswish'), Identity
+ ('identity') are supported.
+ use_bn (bool): if true, use batchnorm.
+ norm_eps (float): epsilon for batchnorm.
+ norm_momentum (float): momentum for batchnorm.
+
+ Current implementation of this layer in Pytorch Mobile is efficient.
+ Sidenote: QNNPACK has best support for dw with 3x3 spatial kernel.
+ For other spatial kernels like 7x7 dw, the efficiency may be lower.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ spatial_stride: int = 1,
+ bias=False,
+ activation: str = "relu",
+ use_bn=True,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ ):
+ super().__init__()
+ kernel = OrderedDict()
+ conv_stride = (1, spatial_stride, spatial_stride)
+ kernel["conv"] = nn.Conv3d(
+ in_channels,
+ in_channels,
+ kernel_size=(3, 3, 3),
+ stride=conv_stride,
+ groups=in_channels,
+ padding=1,
+ bias=bias,
+ )
+ if use_bn:
+ kernel["bn"] = nn.BatchNorm3d(
+ in_channels, eps=norm_eps, momentum=norm_momentum
+ )
+ assert (
+ activation in supported_act_functions
+ ), f"Conv3d3x3x3DwBnAct: {activation} is not in supported_act_functions."
+ kernel["act"] = supported_act_functions[activation]()
+ self.kernel = nn.Sequential(kernel)
+
+ self.convert_flag = False
+
+ def convert(
+ self,
+ input_blob_size: Tuple,
+ convert_for_quantize: bool = False,
+ native_conv3d_op_qnnpack: bool = False,
+ **kwargs,
+ ):
+ """
+ Converts the block into efficient form.
+        For fp32 operation, or for quantized operation with an older version of QNNPACK
+        without native int8 Conv3d support, this function converts Conv3d into an
+        equivalent Conv2d for Pytorch Mobile deployment.
+        For quantized operation on a newer version of QNNPACK with native int8 Conv3d,
+        this function only applies operator fusion.
+ Args:
+ input_blob_size (tuple): blob size at the input of Conv3d3x3x3DwBnAct
+ instance during forward.
+ convert_for_quantize (bool): whether this module is intended to be quantized.
+ native_conv3d_op_qnnpack (bool): whether the QNNPACK version has native int8
+ Conv3d.
+ kwargs (any): any keyword argument (unused).
+ """
+ assert (
+ self.convert_flag is False
+ ), "Conv3d3x3x3DwBnAct: already converted, cannot be converted twice."
+ self.kernel.eval()
+ # Fuse conv and bn if bn exists.
+ if hasattr(self.kernel, "bn"):
+ self.kernel = fuse_modules(self.kernel, ["conv", "bn"])
+ # Convert Conv3d into equivalent Conv2d if using fp32 operation (convert_for_quantize
+ # is False) or not using QNNPACK native conv3d (native_conv3d_op_qnnpack is False)
+ if (convert_for_quantize is False) or (native_conv3d_op_qnnpack is False):
+ self.kernel.conv = _Conv3dTemporalKernel3Decomposed(
+ self.kernel.conv, input_blob_size[2:]
+ )
+            # Convert activation function
+ self.kernel.act.convert(input_blob_size, **kwargs)
+ """
+ Since conv3d is converted into multiple conv2d,
+ will not fuse conv with act to keep arithmetic equivalency.
+ """
+ self.convert_flag = True
+ # Set new kernel in eval mode again
+ self.kernel.eval()
+
+ def forward(self, x):
+ x = self.kernel(x)
+ return x
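+
+
+# Illustrative sketch (not part of the original pytorchvideo source) of the two
+# convert() paths described above. Names and sizes are hypothetical.
+#
+#   dw = Conv3d3x3x3DwBnAct(in_channels=32)
+#   # fp32 / older QNNPACK: the Conv3d is unfolded into per-frame Conv2d ops
+#   dw.convert(input_blob_size=(1, 32, 4, 14, 14))
+#   # Quantized deployment with native int8 Conv3d would instead only fuse modules:
+#   #   dw.convert(input_blob_size=(1, 32, 4, 14, 14),
+#   #              convert_for_quantize=True, native_conv3d_op_qnnpack=True)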
+
+
+class Conv3dTemporalKernel1BnAct(EfficientBlockBase):
+ """
+ Implements Conv3d + Bn + Activation where Conv3d has temporal kernel of 1.
+ The conv layer has padding[0] = 0, stride[0] = 1, dilation[0] = 1.
+
+ Input
+ |
+ ↓
+ conv3d (1xkxk)
+ ↓
+ BatchNorm (optional)
+ ↓
+ Activation
+
+ Current implementation of this layer in QNNPACK is reasonably efficient
+ (not as efficient as Conv3dPwBnAct for 1x1x1 kernel).
+ Args:
+        in_channels (int): number of input channels for conv3d 1xkxk.
+        out_channels (int): number of output channels for conv3d 1xkxk.
+ bias (bool): if true, use bias for conv.
+ groups (int): number of groups for conv.
+        spatial_kernel (int): spatial kernel for conv3d.
+        spatial_stride (int): spatial stride for conv3d.
+ spatial_padding (int): spatial padding for conv3d.
+ spatial_dilation (int): spatial dilation for conv3d.
+ activation (str): applies selected activation from supported_act_functions.
+ See activation_functions.py for more info about supported activations.
+ Currently ReLU ('relu'), Swish ('swish'), Hardswish ('hswish'), Identity
+ ('identity') are supported.
+ use_bn (bool): if true, use batchnorm.
+ norm_eps (float): epsilon for batchnorm.
+ norm_momentum (float): momentum for batchnorm.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias=False,
+ groups: int = 1,
+ spatial_kernel: int = 1,
+ spatial_stride: int = 1,
+ spatial_padding: int = 0,
+ spatial_dilation: int = 1,
+ activation: str = "relu",
+ use_bn=True,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ ):
+ super().__init__()
+
+ kernel_size = (1, spatial_kernel, spatial_kernel)
+ stride = (1, spatial_stride, spatial_stride)
+ padding = (0, spatial_padding, spatial_padding)
+ dilation = (1, spatial_dilation, spatial_dilation)
+ kernel = OrderedDict()
+ kernel["conv"] = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ if use_bn:
+ kernel["bn"] = nn.BatchNorm3d(
+ out_channels, eps=norm_eps, momentum=norm_momentum
+ )
+ assert (
+ activation in supported_act_functions
+ ), f"Conv3dTemporalKernel1BnAct: {activation} is not in supported_act_functions."
+ kernel["act"] = supported_act_functions[activation]()
+ self.kernel = nn.Sequential(kernel)
+
+ self.convert_flag = False
+
+ def convert(
+ self,
+ input_blob_size: Tuple,
+ **kwargs,
+ ):
+ """
+ Converts Conv3d into equivalent Conv2d for QNNPACK deployment.
+        This conversion is done by first fusing conv3d with bn and then converting
+        conv3d into an equivalent conv2d.
+ Args:
+ input_blob_size (tuple): blob size at the input of
+ Conv3dTemporalKernel1BnAct instance during forward.
+ kwargs (any): any keyword argument (unused).
+ """
+ assert (
+ self.convert_flag is False
+ ), "Conv3dTemporalKernel1BnAct: already converted, cannot be converted again"
+ self.kernel.eval()
+ # First fuse conv and bn if bn exists.
+ if hasattr(self.kernel, "bn"):
+ self.kernel = fuse_modules(self.kernel, ["conv", "bn"])
+
+ self.kernel.conv = _Conv3dTemporalKernel1Decomposed(
+ self.kernel.conv, input_blob_size[2:]
+ )
+        # Convert activation function
+ self.kernel.act.convert(input_blob_size, **kwargs)
+
+ self.convert_flag = True
+ # Set new kernel in eval mode again
+ self.kernel.eval()
+
+ def forward(self, x):
+ x = self.kernel(x)
+ return x
+
+
+class Conv3d3x1x1BnAct(EfficientBlockBase):
+ """
+ Implements Conv3d (3x1x1) + (optional) Bn + Activation for pointwise layers.
+ The conv layer has fixed kernel of (3, 1, 1), zero padding size of
+ (1, 0, 0), stride = (1, 1, 1), dilation = 1.
+
+ Input
+ |
+ ↓
+ conv3d (3x1x1)
+ ↓
+ BatchNorm (optional)
+ ↓
+ Activation
+
+ For regular convolution (i.e., groups=1), current implementation of this layer in
+ QNNPACK is reasonably efficient.
+    For depthwise convolution (i.e., groups=out_channels), the current implementation of
+    this layer in QNNPACK is not as efficient as Conv3d3x3x3DwBnAct, as QNNPACK does not
+    have optimization for 1x1 depthwise convolution. The fp32 latencies of Conv3d3x1x1BnAct
+    and Conv3d3x3x3DwBnAct are similar, while with int8 operation Conv3d3x1x1BnAct is
+    about 1.5X slower than Conv3d3x3x3DwBnAct.
+
+ self.convert_flag property records whether the Conv3d3x1x1BnAct instance has been
+ converted; Conv3d3x1x1BnAct is in original form if convert_flag is false, while it
+ is in deployable form if convert_flag is true.
+
+ Args:
+ in_channels (int): number of input channels for conv3d 3x1x1.
+ out_channels (int): number of output channels for conv3d 3x1x1.
+ groups (int): number of groups for conv.
+ bias (bool): if true, use bias for conv.
+ activation (str): applies selected activation from supported_act_functions.
+ See activation_functions.py for more info about supported activations.
+ Currently ReLU ('relu'), Swish ('swish'), Hardswish ('hswish'), Identity
+ ('identity') are supported.
+ use_bn (bool): if true, use batchnorm.
+ norm_eps (float): epsilon for batchnorm.
+ norm_momentum (float): momentum for batchnorm.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ groups: int = 1,
+ bias=False,
+ activation: str = "relu",
+ use_bn=True,
+ norm_eps=1e-5,
+ norm_momentum=0.1,
+ ):
+ super().__init__()
+ kernel = OrderedDict()
+ kernel["conv"] = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=(3, 1, 1),
+ groups=groups,
+ padding=(1, 0, 0),
+ bias=bias,
+ )
+
+ if groups == out_channels:
+ logging.warning(
+ (
+ "Conv3d3x1x1BnAct has low efficiency for depthwise conv. "
+ "Consider using Conv3d3x3x3DwBnRelu instead."
+ )
+ )
+
+ if use_bn:
+ kernel["bn"] = nn.BatchNorm3d(
+ out_channels, eps=norm_eps, momentum=norm_momentum
+ )
+ assert (
+ activation in supported_act_functions
+ ), f"Conv3d3x1x1BnAct: {activation} is not in supported_act_functions."
+ kernel["act"] = supported_act_functions[activation]()
+ self.kernel = nn.Sequential(kernel)
+ self.convert_flag = False
+
+ def convert(
+ self,
+ input_blob_size,
+ **kwargs,
+ ):
+ """
+ Converts Conv3d into equivalent Conv2d for Pytorch Mobile deployment
+
+ """
+ assert (
+ self.convert_flag is False
+ ), "Conv3d3x1x1BnAct: already converted, cannot be converted twice"
+ self.kernel.eval()
+ # Fuse conv and bn if bn exists.
+ if hasattr(self.kernel, "bn"):
+ self.kernel = fuse_modules(self.kernel, ["conv", "bn"])
+ self.kernel.conv = _Conv3dTemporalKernel3Decomposed(
+ self.kernel.conv, input_blob_size[2:]
+ )
+ # Convert activation function
+ self.kernel.act.convert(input_blob_size, **kwargs)
+ # Since conv3d is converted into multiple conv2d, will not fuse conv with relu
+ # to keep arithmetic equivalency.
+ self.convert_flag = True
+ self.kernel.eval()
+
+ def forward(self, x):
+ x = self.kernel(x)
+ return x
+
+
+class Conv3d5x1x1BnAct(EfficientBlockBase):
+ """
+ Implements Conv3d (5x1x1) + (optional) Bn + Activation for pointwise layers.
+ The conv layer has fixed kernel of (5, 1, 1), zero padding size of
+ (2, 0, 0), stride = (1, 1, 1), dilation = 1.
+
+ Input
+ |
+ ↓
+ conv3d (5x1x1)
+ ↓
+ BatchNorm (optional)
+ ↓
+ Activation
+
+ For regular convolution (i.e., groups=1), current implementation of this layer in
+ QNNPACK is reasonably efficient.
+
+ self.convert_flag property records whether the Conv3d5x1x1BnAct instance has been
+ converted; Conv3d5x1x1BnAct is in original form if convert_flag is false, while it
+ is in deployable form if convert_flag is true.
+
+ Args:
+        in_channels (int): number of input channels for conv3d 5x1x1.
+        out_channels (int): number of output channels for conv3d 5x1x1.
+ groups (int): number of groups for conv.
+ bias (bool): if true, use bias for conv.
+ activation (str): applies selected activation from supported_act_functions.
+ See activation_functions.py for more info about supported activations.
+ Currently ReLU ('relu'), Swish ('swish'), Hardswish ('hswish'), Identity
+ ('identity') are supported.
+ use_bn (bool): if true, use batchnorm.
+ norm_eps (float): epsilon for batchnorm.
+ norm_momentum (float): momentum for batchnorm.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ groups: int = 1,
+ bias=False,
+ activation: str = "relu",
+ use_bn=True,
+ norm_eps=1e-5,
+ norm_momentum=0.1,
+ ):
+ super().__init__()
+ kernel = OrderedDict()
+ kernel["conv"] = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=(5, 1, 1),
+ groups=groups,
+ padding=(2, 0, 0),
+ bias=bias,
+ )
+
+ if use_bn:
+ kernel["bn"] = nn.BatchNorm3d(
+ out_channels, eps=norm_eps, momentum=norm_momentum
+ )
+ assert (
+ activation in supported_act_functions
+ ), f"Conv3d5x1x1BnAct: {activation} is not in supported_act_functions."
+ kernel["act"] = supported_act_functions[activation]()
+ self.kernel = nn.Sequential(kernel)
+ self.convert_flag = False
+
+ def convert(self, input_blob_size, **kwargs):
+ """
+ Converts Conv3d into equivalent Conv2d for Pytorch Mobile deployment
+
+ """
+ assert (
+ self.convert_flag is False
+ ), "Conv3d5x1x1BnAct: already converted, cannot be converted twice"
+ self.kernel.eval()
+ # Fuse conv and bn if bn exists.
+ if hasattr(self.kernel, "bn"):
+ self.kernel = fuse_modules(self.kernel, ["conv", "bn"])
+ self.kernel.conv = _Conv3dTemporalKernel5Decomposed(
+ self.kernel.conv, input_blob_size[2:]
+ )
+        # Convert activation function
+ self.kernel.act.convert(input_blob_size, **kwargs)
+ # Since conv3d is converted into multiple conv2d, will not fuse conv with relu
+ # to keep arithmetic equivalency.
+ self.convert_flag = True
+ self.kernel.eval()
+
+ def forward(self, x):
+ x = self.kernel(x)
+ return x
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py
new file mode 100644
index 0000000000000000000000000000000000000000..83421d4fca7bab3af1e69c4a4b44fbeadb92cb4f
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py
@@ -0,0 +1,26 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import torch.nn as nn
+from pytorchvideo.accelerator.efficient_blocks.no_op_convert_block import (
+ NoOpConvertBlock,
+)
+
+
+class FullyConnected(NoOpConvertBlock):
+ """
+ Implements fully connected layer. This operator is natively supported by QNNPACK for
+ mobile CPU with good efficiency, and no change is made upon convert().
+ Args:
+ in_features (int): input channels for FC layer.
+ out_features (int): output channels for FC layer.
+ bias (bool): if True, bias is applied
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ ):
+
+ super().__init__(model=nn.Linear(in_features, out_features, bias=bias))
diff --git a/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/pool.py b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e92ee9e4257f995b3243f5fce0eeeb2ad2b2536
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/accelerator/mobile_cpu/pool.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Tuple, Union
+
+import torch.nn as nn
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+from pytorchvideo.accelerator.efficient_blocks.no_op_convert_block import (
+ NoOpConvertBlock,
+)
+
+
+class AdaptiveAvgPool3dOutSize1(EfficientBlockBase):
+ """
+ Implements AdaptiveAvgPool3d with output (T, H, W) = (1, 1, 1). This operator has
+ better efficiency than AdaptiveAvgPool for mobile CPU.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.pool = nn.AdaptiveAvgPool3d(1)
+ self.convert_flag = False
+
+ def convert(self, input_blob_size: Tuple, **kwargs):
+ """
+ Converts AdaptiveAvgPool into AvgPool with constant kernel size for better
+ efficiency.
+ Args:
+ input_blob_size (tuple): blob size at the input of
+ AdaptiveAvgPool3dOutSize1 instance during forward.
+ kwargs (any): any keyword argument (unused).
+ """
+ assert (
+ self.convert_flag is False
+ ), "AdaptiveAvgPool3dOutSize1: already converted, cannot be converted again"
+ kernel_size = input_blob_size[2:]
+ self.pool = nn.AvgPool3d(kernel_size)
+ self.convert_flag = True
+
+ def forward(self, x):
+ return self.pool(x)
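+
+
+# Illustrative sketch (not part of the original pytorchvideo source): convert()
+# replaces the adaptive pool with a fixed-kernel AvgPool3d sized to the input blob.
+# Names and sizes are hypothetical.
+#
+#   pool = AdaptiveAvgPool3dOutSize1()
+#   pool.convert(input_blob_size=(1, 16, 4, 7, 7))   # self.pool becomes AvgPool3d((4, 7, 7))
+#   y = pool(torch.randn(1, 16, 4, 7, 7))            # y.shape == (1, 16, 1, 1, 1)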
+
+
+class AdaptiveAvgPool2dOutSize1(EfficientBlockBase):
+ """
+ Implements AdaptiveAvgPool2d with output (H, W) = (1, 1). This operator has
+ better efficiency than AdaptiveAvgPool for mobile CPU.
+ """
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ self.pool = nn.AdaptiveAvgPool2d(1)
+ self.convert_flag = False
+
+ def convert(self, input_blob_size: Tuple, **kwargs):
+ """
+ Converts AdaptiveAvgPool into AvgPool with constant kernel size for better
+ efficiency.
+ Args:
+ input_blob_size (tuple): blob size at the input of
+ AdaptiveAvgPool2dOutSize1 instance during forward.
+ kwargs (any): any keyword argument (unused).
+ """
+ assert (
+ self.convert_flag is False
+ ), "AdaptiveAvgPool2dOutSize1: already converted, cannot be converted again"
+ kernel_size = input_blob_size[2:]
+ self.pool = nn.AvgPool2d(kernel_size)
+ self.convert_flag = True
+
+ def forward(self, x):
+ return self.pool(x)
+
+
+class AdaptiveAvgPool3d(NoOpConvertBlock):
+ """
+ Implements AdaptiveAvgPool3d with any output (T, H, W) size. This operator is
+    supported by QNNPACK for mobile CPU with reasonable efficiency, and no change is
+ made upon convert(). If the output (T, H, W) = (1, 1, 1), use AdaptiveAvgPool3dOutSize1
+ for better efficiency.
+ Args:
+ output_size (int or tuple): when it is a tuple, the output (T, H, W) of pool
+ will be equal to output_size. When it is an int, the output (T, H, W)
+ will be equal to (output_size, output_size, output_size).
+ """
+
+ def __init__(
+ self,
+ output_size: Union[int, Tuple],
+ ):
+ super().__init__(model=nn.AdaptiveAvgPool3d(output_size))
+
+
+class AdaptiveAvgPool2d(NoOpConvertBlock):
+ """
+ Implements AdaptiveAvgPool2d with any output (H, W) size. This operator is
+    supported by QNNPACK for mobile CPU with reasonable efficiency, and no change is
+ made upon convert(). If the output (H, W) = (1, 1), use AdaptiveAvgPool2dOutSize1
+ for better efficiency.
+ Args:
+ output_size (int or tuple): when it is a tuple, the output (H, W) of pool
+ will be equal to output_size. When it is an int, the output (H, W)
+ will be equal to (output_size, output_size).
+ """
+
+ def __init__(
+ self,
+ output_size: Union[int, Tuple],
+ ):
+ super().__init__(model=nn.AdaptiveAvgPool2d(output_size))
diff --git a/code/pytorchvideo/pytorchvideo/layers/attention.py b/code/pytorchvideo/pytorchvideo/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9f65e2689c2827c1a61ee120641c398de08a78
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/attention.py
@@ -0,0 +1,757 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, List, Optional, Tuple
+
+import numpy
+import torch
+
+try:
+ import torch.fx
+except Exception as _:
+ pass
+import torch.nn as nn
+from torch.nn.common_types import _size_3_t
+
+from .drop_path import DropPath
+
+
+@torch.fx.wrap
+def _unsqueeze_dims_fx(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ tensor_dim = tensor.ndim
+ if tensor_dim == 4:
+ pass
+ elif tensor_dim == 3:
+ tensor = tensor.unsqueeze(1)
+ else:
+ raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
+ return tensor, tensor_dim
+
+
+@torch.jit.script
+def _unsqueeze_dims_jit(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ return _unsqueeze_dims_fx(tensor)
+
+
+@torch.fx.wrap
+def _squeeze_dims_fx(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor:
+ if tensor_dim == 4:
+ pass
+ elif tensor_dim == 3:
+ tensor = tensor.squeeze(1)
+ else:
+ raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
+ return tensor
+
+
+@torch.jit.script
+def _squeeze_dims_jit(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor:
+ return _squeeze_dims_fx(tensor, tensor_dim)
+
+
+class Mlp(nn.Module):
+ """
+    An MLP block that contains two linear layers with an activation in between. The MLP
+ block is used in a transformer model after the attention block.
+
+ ::
+
+ Linear (in_features, hidden_features)
+ ↓
+                   Activation (act_layer)
+ ↓
+ Dropout (p=dropout_rate)
+ ↓
+ Linear (hidden_features, out_features)
+ ↓
+ Dropout (p=dropout_rate)
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable = nn.GELU,
+ dropout_rate: float = 0.0,
+ bias_on: bool = True,
+ ) -> None:
+ """
+ Args:
+ in_features (int): Input feature dimension.
+ hidden_features (Optional[int]): Hidden feature dimension. By default,
+ hidden feature is set to input feature dimension.
+ out_features (Optional[int]): Output feature dimension. By default, output
+ features dimension is set to input feature dimension.
+ act_layer (Callable): Activation layer used after the first linear layer.
+ dropout_rate (float): Dropout rate after each linear layer. Dropout is not used
+ by default.
+ """
+ super().__init__()
+ self.dropout_rate = dropout_rate
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_on)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_on)
+
+ if self.dropout_rate > 0.0:
+ self.dropout = nn.Dropout(dropout_rate)
+ else:
+ self.dropout = nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (tensor): Input tensor.
+ """
+ x = self.fc1(x)
+ x = self.act(x)
+ if self.dropout_rate > 0.0:
+ x = self.dropout(x)
+ x = self.fc2(x)
+ if self.dropout_rate > 0.0:
+ x = self.dropout(x)
+ return x
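+
+
+# Illustrative sketch (not part of the original pytorchvideo source): the Mlp block
+# operates on flattened token sequences of shape (batch, tokens, channels).
+# Sizes below are hypothetical.
+#
+#   mlp = Mlp(in_features=96, hidden_features=384, dropout_rate=0.1)
+#   tokens = torch.randn(2, 197, 96)
+#   out = mlp(tokens)        # out.shape == (2, 197, 96)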
+
+
+class _AttentionPool(torch.nn.Module):
+ def __init__(
+ self,
+ pool: Optional[torch.nn.Module],
+ has_cls_embed: bool,
+ norm: Optional[torch.nn.Module],
+ ) -> None:
+ """Apply pool to a flattened input (given pool operation and the unflattened shape).
+
+
+ Input
+ ↓
+ Reshape
+ ↓
+ Pool
+ ↓
+ Reshape
+ ↓
+ Norm
+
+
+ Params:
+ pool (Optional[Callable]): Pool operation that is applied to the input tensor.
+ If pool is none, return the input tensor.
+ has_cls_embed (bool): Whether the input tensor contains cls token. Pool
+ operation excludes cls token.
+ norm: (Optional[Callable]): Optional normalization operation applied to
+ tensor after pool.
+ """
+ super().__init__()
+ self.has_pool = pool is not None
+ self.pool = pool if pool is not None else torch.nn.Identity()
+
+ self.has_cls_embed = has_cls_embed
+ if norm is not None:
+ self.norm_before_pool = isinstance(
+ norm, (torch.nn.BatchNorm3d, torch.nn.Identity)
+ )
+ self.has_norm = True
+ self.norm = norm
+ else:
+ self.norm_before_pool = False
+ self.has_norm = False
+ self.norm = torch.nn.Identity()
+
+ def forward(
+ self, tensor: torch.Tensor, thw_shape: List[int]
+ ) -> Tuple[torch.Tensor, List[int]]:
+ """
+ Args:
+ tensor (torch.Tensor): Input tensor.
+ thw_shape (List): The shape of the input tensor (before flattening).
+
+ Returns:
+ tensor (torch.Tensor): Input tensor after pool.
+ thw_shape (List[int]): Output tensor shape (before flattening).
+ """
+ if not self.has_pool:
+ return tensor, thw_shape
+ tensor_dim = tensor.ndim
+
+ if torch.jit.is_scripting():
+ tensor, tensor_dim = _unsqueeze_dims_jit(tensor)
+ else:
+ tensor, tensor_dim = _unsqueeze_dims_fx(tensor)
+
+ cls_tok: torch.Tensor = torch.tensor(0) # For typing/torchscriptability
+ if self.has_cls_embed:
+ cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
+
+ B, N, L, C = tensor.shape
+ T, H, W = thw_shape
+ tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
+
+ if self.norm_before_pool:
+ # If use BN, we apply norm before pooling instead of after pooling.
+ tensor = self.norm(tensor)
+ # We also empirically find that adding a GELU here is beneficial.
+ tensor = torch.nn.functional.gelu(tensor)
+
+ tensor = self.pool(tensor)
+
+ thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
+ L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
+ tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
+ if self.has_cls_embed:
+ tensor = torch.cat((cls_tok, tensor), dim=2)
+ if self.has_norm and not self.norm_before_pool:
+ tensor = self.norm(tensor)
+
+ if torch.jit.is_scripting():
+ tensor = _squeeze_dims_jit(tensor, tensor_dim)
+ else:
+ tensor = _squeeze_dims_fx(tensor, tensor_dim)
+
+ return tensor, thw_shape
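+
+
+# Illustrative sketch (not part of the original pytorchvideo source): pooling a
+# flattened (cls + T*H*W) token tensor, where thw_shape describes the unflattened
+# spatiotemporal layout. Names and sizes are hypothetical.
+#
+#   attn_pool = _AttentionPool(
+#       pool=torch.nn.MaxPool3d((1, 2, 2), (1, 2, 2)),
+#       has_cls_embed=True,
+#       norm=None,
+#   )
+#   tokens = torch.randn(2, 4, 1 + 2 * 4 * 4, 24)   # (B, num_heads, 1 + T*H*W, head_dim)
+#   pooled, thw = attn_pool(tokens, [2, 4, 4])      # pooled: (2, 4, 1 + 2*2*2, 24), thw: [2, 2, 2]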
+
+
+class MultiScaleAttention(nn.Module):
+ """
+    Implementation of a multiscale attention block. Compared to a conventional attention
+    block, a multiscale attention block optionally supports pooling (either
+ before or after qkv projection). If pooling is not used, a multiscale attention
+ block is equivalent to a conventional attention block.
+
+ ::
+ Input
+ |
+ |----------------|-----------------|
+ ↓ ↓ ↓
+ Linear Linear Linear
+ [dim expand] [dim expand] [dim expand]
+ & & &
+ Pool (Q) Pool (K) Pool (V)
+ → -------------- ← |
+ ↓ |
+ MatMul & Scale |
+ ↓ |
+ Softmax |
+ → ----------------------- ←
+ ↓
+ MatMul & Scale
+ ↓
+ DropOut
+ """
+
+ _version = 3
+
+ def __init__(
+ self,
+ dim: int,
+        dim_out: Optional[int] = None,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ dropout_rate: float = 0.0,
+ kernel_q: _size_3_t = (1, 1, 1),
+ kernel_kv: _size_3_t = (1, 1, 1),
+ stride_q: _size_3_t = (1, 1, 1),
+ stride_kv: _size_3_t = (1, 1, 1),
+ norm_layer: Callable = nn.LayerNorm,
+ has_cls_embed: bool = True,
+ pool_mode: str = "conv",
+ pool_first: bool = False,
+ residual_pool: bool = True,
+ depthwise_conv: bool = True,
+ bias_on: bool = True,
+ separate_qkv: bool = True,
+ ) -> None:
+ """
+ Args:
+ dim (int): Input feature dimension.
+ dim_out (int): Output feature dimension
+ num_heads (int): Number of heads in the attention layer.
+ qkv_bias (bool): If set to False, the qkv layer will not learn an additive
+ bias. Default: False.
+ dropout_rate (float): Dropout rate.
+ kernel_q (_size_3_t): Pooling kernel size for q. If both pooling kernel
+ size and pooling stride size are 1 for all the dimensions, pooling is
+ disabled.
+ kernel_kv (_size_3_t): Pooling kernel size for kv. If both pooling kernel
+ size and pooling stride size are 1 for all the dimensions, pooling is
+ disabled.
+ stride_q (_size_3_t): Pooling kernel stride for q.
+ stride_kv (_size_3_t): Pooling kernel stride for kv.
+ norm_layer (nn.Module): Normalization layer used after pooling.
+ has_cls_embed (bool): If set to True, the first token of the input tensor
+ should be a cls token. Otherwise, the input tensor does not contain a
+ cls token. Pooling is not applied to the cls token.
+ pool_mode (str): Pooling mode. Option includes "conv" (learned pooling), "avg"
+ (average pooling), and "max" (max pooling).
+ pool_first (bool): If set to True, pool is applied before qkv projection.
+ Otherwise, pool is applied after qkv projection. Default: False.
+ residual_pool (bool): If set to True, use Improved Multiscale Vision
+ Transformer's pooling residual connection.
+            depthwise_conv (bool): Whether to use depthwise or full convolution for pooling.
+            bias_on (bool): Whether to use biases for linear layers.
+            separate_qkv (bool): Whether to use separate layers or one layer for qkv projections.
+ """
+
+ super().__init__()
+ assert pool_mode in ["conv", "avg", "max"]
+
+ self.pool_first = pool_first
+ self.dropout_rate = dropout_rate
+ self.num_heads = num_heads
+ dim_out = dim if not dim_out else dim_out
+ self.dim_out = dim_out
+ head_dim = dim_out // num_heads
+ self.scale = head_dim**-0.5
+ self.has_cls_embed = has_cls_embed
+ self.residual_pool = residual_pool
+ self.separate_qkv = separate_qkv
+ padding_q = [int(q // 2) for q in kernel_q]
+ padding_kv = [int(kv // 2) for kv in kernel_kv]
+
+ # Set placeholders for torchscriptability, may not be actually used
+ self.q = self.k = self.v = self.qkv = nn.Identity()
+ if self.pool_first or self.separate_qkv:
+ self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
+ self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
+ self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
+ else:
+ self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim_out, dim_out, bias=True if bias_on else False)
+
+ if dropout_rate > 0.0:
+ self.proj_drop = nn.Dropout(dropout_rate)
+ else:
+ self.proj_drop = nn.Identity()
+
+ # Skip pooling with kernel and stride size of (1, 1, 1).
+ if (
+ kernel_q is not None
+ and self._prod(kernel_q) == 1
+ and self._prod(stride_q) == 1
+ ):
+ kernel_q = None
+ if (
+ kernel_kv is not None
+ and self._prod(kernel_kv) == 1
+ and self._prod(stride_kv) == 1
+ ):
+ kernel_kv = None
+
+ if pool_mode in ("avg", "max"):
+ pool_op = nn.MaxPool3d if pool_mode == "max" else nn.AvgPool3d
+ self.pool_q = (
+ pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
+ if kernel_q is not None
+ else None
+ )
+ self.pool_k = (
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
+ if kernel_kv is not None
+ else None
+ )
+ self.pool_v = (
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
+ if kernel_kv is not None
+ else None
+ )
+ elif pool_mode == "conv":
+ if self.pool_first:
+ dim_conv = dim // num_heads
+ else:
+ dim_conv = dim_out // num_heads
+ self.pool_q = (
+ nn.Conv3d(
+ dim_conv,
+ dim_conv,
+ kernel_q,
+ stride=stride_q,
+ padding=padding_q,
+ groups=dim_conv if depthwise_conv else 1,
+ bias=False,
+ )
+ if kernel_q is not None
+ else None
+ )
+ self.norm_q = norm_layer(dim_conv) if kernel_q is not None else None
+ self.pool_k = (
+ nn.Conv3d(
+ dim_conv,
+ dim_conv,
+ kernel_kv,
+ stride=stride_kv,
+ padding=padding_kv,
+ groups=dim_conv if depthwise_conv else 1,
+ bias=False,
+ )
+ if kernel_kv is not None
+ else None
+ )
+ self.norm_k = norm_layer(dim_conv) if kernel_kv is not None else None
+ self.pool_v = (
+ nn.Conv3d(
+ dim_conv,
+ dim_conv,
+ kernel_kv,
+ stride=stride_kv,
+ padding=padding_kv,
+ groups=dim_conv if depthwise_conv else 1,
+ bias=False,
+ )
+ if kernel_kv is not None
+ else None
+ )
+ self.norm_v = norm_layer(dim_conv) if kernel_kv is not None else None
+ else:
+ raise NotImplementedError(f"Unsupported model {pool_mode}")
+
+ # Will not be used if `separate_qkv == True`
+ self._attention_pool_q = _AttentionPool(
+ self.pool_q,
+ has_cls_embed=self.has_cls_embed,
+ norm=getattr(self, "norm_q", None),
+ )
+ self._attention_pool_k = _AttentionPool(
+ self.pool_k,
+ has_cls_embed=self.has_cls_embed,
+ norm=getattr(self, "norm_k", None),
+ )
+ self._attention_pool_v = _AttentionPool(
+ self.pool_v,
+ has_cls_embed=self.has_cls_embed,
+ norm=getattr(self, "norm_v", None),
+ )
+
+ def _qkv_proj(
+ self,
+ q: torch.Tensor,
+ q_size: int,
+ k: torch.Tensor,
+ k_size: int,
+ v: torch.Tensor,
+ v_size: int,
+ batch_size: int,
+ chan_size: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ q = (
+ self.q(q)
+ .reshape(batch_size, q_size, self.num_heads, chan_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ k = (
+ self.k(k)
+ .reshape(batch_size, k_size, self.num_heads, chan_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ v = (
+ self.v(v)
+ .reshape(batch_size, v_size, self.num_heads, chan_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ return q, k, v
+
+ def _qkv_pool(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ thw_shape: List[int],
+ ) -> Tuple[
+ torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int]
+ ]:
+ q, q_shape = self._attention_pool_q(q, thw_shape)
+ k, k_shape = self._attention_pool_k(k, thw_shape)
+ v, v_shape = self._attention_pool_v(v, thw_shape)
+ return q, q_shape, k, k_shape, v, v_shape
+
+ def _get_qkv_length(
+ self,
+ q_shape: List[int],
+ k_shape: List[int],
+ v_shape: List[int],
+ ) -> Tuple[int, int, int]:
+ q_N = self._prod(q_shape) + 1 if self.has_cls_embed else self._prod(q_shape)
+ k_N = self._prod(k_shape) + 1 if self.has_cls_embed else self._prod(k_shape)
+ v_N = self._prod(v_shape) + 1 if self.has_cls_embed else self._prod(v_shape)
+ return q_N, k_N, v_N
+
+ def _prod(self, shape: List[int]) -> int:
+ """Torchscriptable version of `numpy.prod`. Note that `_prod([]) == 1`"""
+ p: int = 1
+ for dim in shape:
+ p *= dim
+ return p
+
+ def _reshape_qkv_to_seq(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ q_N: int,
+ v_N: int,
+ k_N: int,
+ B: int,
+ C: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ q = q.permute(0, 2, 1, 3).reshape(B, q_N, C)
+ v = v.permute(0, 2, 1, 3).reshape(B, v_N, C)
+ k = k.permute(0, 2, 1, 3).reshape(B, k_N, C)
+ return q, k, v
+
+ def forward(
+ self, x: torch.Tensor, thw_shape: List[int]
+ ) -> Tuple[torch.Tensor, List[int]]:
+ """
+ Args:
+ x (torch.Tensor): Input tensor.
+ thw_shape (List): The shape of the input tensor (before flattening).
+ """
+
+ B, N, C = x.shape
+ if self.pool_first:
+ x = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ q = k = v = x
+ q, q_shape, k, k_shape, v, v_shape = self._qkv_pool(q, k, v, thw_shape)
+ q_N, k_N, v_N = self._get_qkv_length(q_shape, k_shape, v_shape)
+ q, k, v = self._reshape_qkv_to_seq(q, k, v, q_N, v_N, k_N, B, C)
+ q, k, v = self._qkv_proj(q, q_N, k, k_N, v, v_N, B, self.dim_out)
+ else:
+ if self.separate_qkv:
+ q = k = v = x
+ q, k, v = self._qkv_proj(q, N, k, N, v, N, B, self.dim_out)
+ else:
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, -1)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ q, q_shape, k, k_shape, v, v_shape = self._qkv_pool(q, k, v, thw_shape)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+
+ N = q.shape[2]
+
+ if self.residual_pool:
+ x = (attn @ v + q).transpose(1, 2).reshape(B, -1, self.dim_out)
+ else:
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dim_out)
+
+ x = self.proj(x)
+ if self.dropout_rate > 0.0:
+ x = self.proj_drop(x)
+ return x, q_shape
+
+ def _load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ version = local_metadata.get("version", None)
+
+ if version is None or version < 2:
+ for layer in ["pool", "norm"]:
+ for pattern in ["q", "k", "v"]:
+ for type in ["weight", "bias"]:
+ old_key = f"{prefix}{layer}_{pattern}.{type}"
+ new_key = f"{prefix}_attention_pool_{pattern}.{layer}.{type}"
+ if old_key in state_dict:
+ state_dict[new_key] = state_dict[old_key]
+
+ super()._load_from_state_dict(
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
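+
+
+# Illustrative sketch (not part of the original pytorchvideo source): multiscale
+# attention with a learned (conv) pooling of q that halves the spatial resolution.
+# Names and sizes are hypothetical.
+#
+#   attn = MultiScaleAttention(dim=96, num_heads=4, kernel_q=(3, 3, 3), stride_q=(1, 2, 2))
+#   x = torch.randn(2, 1 + 4 * 8 * 8, 96)    # cls token + flattened 4x8x8 patch grid
+#   out, out_thw = attn(x, [4, 8, 8])        # out: (2, 1 + 4*4*4, 96), out_thw: [4, 4, 4]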
+
+
+class MultiScaleBlock(nn.Module):
+ """
+ Implementation of a multiscale vision transformer block. Each block contains a
+ multiscale attention layer and a Mlp layer.
+
+ ::
+
+
+ Input
+ |-------------------+
+ ↓ |
+ Norm |
+ ↓ |
+ MultiScaleAttention [Proj: dim expand]
+ [dim expand] Pool
+ ↓ |
+ DropPath |
+ ↓ |
+ Summation ←-------------+
+ |
+ |-------------------+
+ ↓ |
+ Norm |
+ ↓ |
+ Mlp [Proj: dim expand]
+ [dim expand] |
+ ↓ |
+ DropPath |
+ ↓ |
+ Summation ←------------+
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ dropout_rate: float = 0.0,
+ droppath_rate: float = 0.0,
+ act_layer: nn.Module = nn.GELU,
+ norm_layer: nn.Module = nn.LayerNorm,
+ attn_norm_layer: nn.Module = nn.LayerNorm,
+ dim_mul_in_att: bool = False,
+ kernel_q: _size_3_t = (1, 1, 1),
+ kernel_kv: _size_3_t = (1, 1, 1),
+ stride_q: _size_3_t = (1, 1, 1),
+ stride_kv: _size_3_t = (1, 1, 1),
+ pool_mode: str = "conv",
+ has_cls_embed: bool = True,
+ pool_first: bool = False,
+ residual_pool: bool = False,
+ depthwise_conv: bool = True,
+ bias_on: bool = True,
+ separate_qkv: bool = True,
+ ) -> None:
+ """
+ Args:
+ dim (int): Input feature dimension.
+ dim_out (int): Output feature dimension.
+ num_heads (int): Number of heads in the attention layer.
+ mlp_ratio (float): Mlp ratio which controls the feature dimension in the
+ hidden layer of the Mlp block.
+ qkv_bias (bool): If set to False, the qkv layer will not learn an additive
+ bias. Default: False.
+ dropout_rate (float): DropOut rate. If set to 0, DropOut is disabled.
+ droppath_rate (float): DropPath rate. If set to 0, DropPath is disabled.
+ act_layer (nn.Module): Activation layer used in the Mlp layer.
+ norm_layer (nn.Module): Normalization layer.
+ attn_norm_layer (nn.Module): Normalization layer in the attention module.
+ dim_mul_in_att (bool): If set to True, dimension expansion happens inside
+ the attention module, otherwise it happens in the Mlp block. Default: False.
+ kernel_q (_size_3_t): Pooling kernel size for q. If pooling kernel size is
+ 1 for all the dimensions, pooling is not used (by default).
+ kernel_kv (_size_3_t): Pooling kernel size for kv. If pooling kernel size
+ is 1 for all the dimensions, pooling is not used. By default, pooling
+ is disabled.
+ stride_q (_size_3_t): Pooling kernel stride for q.
+ stride_kv (_size_3_t): Pooling kernel stride for kv.
+ pool_mode (str): Pooling mode. Option includes "conv" (learned pooling), "avg"
+ (average pooling), and "max" (max pooling).
+ has_cls_embed (bool): If set to True, the first token of the input tensor
+ should be a cls token. Otherwise, the input tensor does not contain a
+ cls token. Pooling is not applied to the cls token.
+ pool_first (bool): If set to True, pool is applied before qkv projection.
+ Otherwise, pool is applied after qkv projection. Default: False.
+ residual_pool (bool): If set to True, use Improved Multiscale Vision
+ Transformer's pooling residual connection.
+            depthwise_conv (bool): Whether to use depthwise or full convolution for pooling.
+            bias_on (bool): Whether to use biases for linear layers.
+            separate_qkv (bool): Whether to use separate layers or one layer for qkv projections.
+ """
+ super().__init__()
+ self.dim = dim
+ self.dim_out = dim_out
+ self.norm1 = norm_layer(dim)
+ self.dim_mul_in_att = dim_mul_in_att
+ self.norm1_is_batchnorm_1d = isinstance(self.norm1, nn.BatchNorm1d)
+ kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
+ stride_skip = stride_q
+ padding_skip = [int(skip // 2) for skip in kernel_skip]
+ att_dim = dim_out if dim_mul_in_att else dim
+ self.attn = MultiScaleAttention(
+ dim=dim,
+ dim_out=att_dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ dropout_rate=dropout_rate,
+ kernel_q=kernel_q,
+ kernel_kv=kernel_kv,
+ stride_q=stride_q,
+ stride_kv=stride_kv,
+ norm_layer=attn_norm_layer,
+ has_cls_embed=has_cls_embed,
+ pool_mode=pool_mode,
+ pool_first=pool_first,
+ residual_pool=residual_pool,
+ bias_on=bias_on,
+ depthwise_conv=depthwise_conv,
+ separate_qkv=separate_qkv,
+ )
+ self.drop_path = (
+ DropPath(droppath_rate) if droppath_rate > 0.0 else nn.Identity()
+ )
+ self.norm2 = norm_layer(att_dim)
+ self.norm2_is_batchnorm_1d = isinstance(self.norm2, nn.BatchNorm1d)
+ mlp_hidden_dim = int(att_dim * mlp_ratio)
+ self.has_cls_embed = has_cls_embed
+ self.mlp = Mlp(
+ in_features=att_dim,
+ hidden_features=mlp_hidden_dim,
+ out_features=dim_out,
+ act_layer=act_layer,
+ dropout_rate=dropout_rate,
+ bias_on=bias_on,
+ )
+ if dim != dim_out:
+ self.proj = nn.Linear(dim, dim_out, bias=bias_on)
+ else:
+ self.proj = nn.Identity()
+
+ self.pool_skip = (
+ nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False)
+ if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
+ else None
+ )
+ self._attention_pool = _AttentionPool(
+ self.pool_skip, has_cls_embed=self.has_cls_embed, norm=None
+ )
+
+ def forward(
+ self, x: torch.Tensor, thw_shape: List[int]
+ ) -> Tuple[torch.Tensor, List[int]]:
+ """
+ Args:
+ x (torch.Tensor): Input tensor.
+ thw_shape (List): The shape of the input tensor (before flattening).
+ """
+
+ x_norm = (
+ self.norm1(x.permute(0, 2, 1)).permute(0, 2, 1)
+ if self.norm1_is_batchnorm_1d
+ else self.norm1(x)
+ )
+ x_block, thw_shape_new = self.attn(x_norm, thw_shape)
+ if self.dim_mul_in_att and self.dim != self.dim_out:
+ x = self.proj(x_norm)
+ x_res, _ = self._attention_pool(x, thw_shape)
+ x = x_res + self.drop_path(x_block)
+ x_norm = (
+ self.norm2(x.permute(0, 2, 1)).permute(0, 2, 1)
+ if self.norm2_is_batchnorm_1d
+ else self.norm2(x)
+ )
+ x_mlp = self.mlp(x_norm)
+ if not self.dim_mul_in_att and self.dim != self.dim_out:
+ x = self.proj(x_norm)
+ x = x + self.drop_path(x_mlp)
+ return x, thw_shape_new
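+
+
+# Illustrative sketch (not part of the original pytorchvideo source): a block that
+# halves the spatial resolution via stride_q and expands channels from 96 to 192
+# (the expansion happens in the Mlp since dim_mul_in_att defaults to False).
+# Names and sizes are hypothetical.
+#
+#   block = MultiScaleBlock(dim=96, dim_out=192, num_heads=4,
+#                           kernel_q=(3, 3, 3), stride_q=(1, 2, 2))
+#   x = torch.randn(2, 1 + 4 * 8 * 8, 96)
+#   out, thw = block(x, [4, 8, 8])           # out: (2, 1 + 4*4*4, 192), thw: [4, 4, 4]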
diff --git a/code/pytorchvideo/pytorchvideo/layers/attention_torchscript.py b/code/pytorchvideo/pytorchvideo/layers/attention_torchscript.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bfb080f7a64a0a3d5470ba3f8253d2852b14a6e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/attention_torchscript.py
@@ -0,0 +1,634 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import List, Optional, Tuple
+
+import numpy
+import torch
+import torch.fx
+import torch.nn as nn
+from torch.nn.common_types import _size_3_t
+
+from .drop_path import DropPath
+
+
+class Mlp(nn.Module):
+ """
+    An MLP block that contains two linear layers with an activation in between. The MLP
+ block is used in a transformer model after the attention block.
+
+ ::
+
+ Linear (in_features, hidden_features)
+ ↓
+                   Activation (act_layer)
+ ↓
+ Dropout (p=dropout_rate)
+ ↓
+ Linear (hidden_features, out_features)
+ ↓
+ Dropout (p=dropout_rate)
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer=nn.GELU,
+ dropout_rate: float = 0.0,
+ bias_on: bool = True,
+ ) -> None:
+ """
+ Args:
+ in_features (int): Input feature dimension.
+ hidden_features (Optional[int]): Hidden feature dimension. By default,
+ hidden feature is set to input feature dimension.
+ out_features (Optional[int]): Output feature dimension. By default, output
+ features dimension is set to input feature dimension.
+ act_layer (Callable): Activation layer used after the first linear layer.
+ dropout_rate (float): Dropout rate after each linear layer. Dropout is not used
+ by default.
+ """
+ super().__init__()
+ self.dropout_rate = dropout_rate
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_on)
+
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_on)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (tensor): Input tensor.
+ """
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+
+@torch.fx.wrap
+def _unsqueeze_dims_fx(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ tensor_dim = tensor.ndim
+ if tensor_dim == 4:
+ pass
+ elif tensor_dim == 3:
+ tensor = tensor.unsqueeze(1)
+ else:
+ raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
+ return tensor, tensor_dim
+
+
+@torch.jit.script
+def _unsqueeze_dims_jit(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ return _unsqueeze_dims_fx(tensor)
+
+
+@torch.fx.wrap
+def _squeeze_dims_fx(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor:
+ if tensor_dim == 4:
+ pass
+ elif tensor_dim == 3:
+ tensor = tensor.squeeze(1)
+ else:
+ raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
+ return tensor
+
+
+@torch.jit.script
+def _squeeze_dims_jit(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor:
+ return _squeeze_dims_fx(tensor, tensor_dim)
+
+
+def _pre_attention_pool(
+ tensor: torch.Tensor,
+ thw_shape: List[int],
+) -> Tuple[torch.Tensor, Tuple[int, int, int, int, int, int, int, int]]:
+ """
+    Reshape a flattened attention tensor into (B*N, C, T, H, W) form so that a pool
+    operation can be applied to it, and collect the metadata needed by
+    _post_attention_pool to restore the flattened layout afterwards.
+
+
+                                   Input
+                                     ↓
+                                  Reshape
+                                     ↓
+                                    Pool
+                                     ↓
+                                  Reshape
+                                     ↓
+                                    Norm
+
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        thw_shape (List): The shape of the input tensor (before flattening).
+
+    Returns:
+        tensor (torch.Tensor): Reshaped tensor, ready to be pooled.
+        meta (Tuple[int, ...]): The metadata tuple (B, N, L, C, T, H, W, tensor_dim)
+            consumed by _post_attention_pool.
+ """
+ if torch.jit.is_scripting():
+ tensor, tensor_dim = _unsqueeze_dims_jit(tensor)
+ else:
+ tensor, tensor_dim = _unsqueeze_dims_fx(tensor)
+ B, N, L, C = tensor.shape
+ T, H, W = thw_shape
+ tensor = tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
+
+ return tensor, (B, N, L, C, T, H, W, tensor_dim)
+
+
+def _post_attention_pool(
+ tensor: torch.Tensor,
+ thw_shape: List[int],
+) -> Tuple[torch.Tensor, List[int]]:
+
+ B, N, L, C, T, H, W, tensor_dim = thw_shape
+ thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
+ L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
+ tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
+ if torch.jit.is_scripting():
+ tensor = _squeeze_dims_jit(tensor, tensor_dim)
+ else:
+ tensor = _squeeze_dims_fx(tensor, tensor_dim)
+
+ return tensor, thw_shape
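+
+
+# Illustrative sketch (not part of the original pytorchvideo source): a round trip
+# through the scriptable pre/post pooling helpers with a depthwise Conv3d pool in
+# between. Names and sizes are hypothetical.
+#
+#   pool = nn.Conv3d(24, 24, (3, 3, 3), stride=(1, 2, 2), padding=1, groups=24, bias=False)
+#   q = torch.randn(2, 4, 2 * 4 * 4, 24)        # (B, num_heads, T*H*W, head_dim)
+#   q, meta = _pre_attention_pool(q, [2, 4, 4]) # q reshaped to (B*num_heads, head_dim, T, H, W)
+#   q = pool(q)
+#   q, thw = _post_attention_pool(q, meta)      # q: (2, 4, 2*2*2, 24), thw: [2, 2, 2]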
+
+
+class MultiScaleAttention(nn.Module):
+ """
+    Implementation of a multiscale attention block. Compared to a conventional attention
+    block, a multiscale attention block optionally supports pooling (either
+ before or after qkv projection). If pooling is not used, a multiscale attention
+ block is equivalent to a conventional attention block.
+
+ ::
+ Input
+ |
+ |----------------|-----------------|
+ ↓ ↓ ↓
+ Linear Linear Linear
+ & & &
+ Pool (Q) Pool (K) Pool (V)
+ → -------------- ← |
+ ↓ |
+ MatMul & Scale |
+ ↓ |
+ Softmax |
+ → ----------------------- ←
+ ↓
+ MatMul & Scale
+ ↓
+ DropOut
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ dropout_rate: float = 0.0,
+ kernel_q: _size_3_t = (1, 1, 1),
+ kernel_kv: _size_3_t = (1, 1, 1),
+ stride_q: _size_3_t = (1, 1, 1),
+ stride_kv: _size_3_t = (1, 1, 1),
+ norm_layer=nn.LayerNorm,
+ has_cls_embed: bool = True,
+ pool_mode: str = "conv",
+ pool_first: bool = False,
+ residual_pool: bool = True,
+ depthwise_conv: bool = True,
+ bias_on: bool = True,
+ separate_qkv: bool = True,
+ ) -> None:
+ """
+ Args:
+ dim (int): Input feature dimension.
+ num_heads (int): Number of heads in the attention layer.
+ qkv_bias (bool): If set to False, the qkv layer will not learn an additive
+ bias. Default: False.
+ dropout_rate (float): Dropout rate.
+ kernel_q (_size_3_t): Pooling kernel size for q. If both pooling kernel
+ size and pooling stride size are 1 for all the dimensions, pooling is
+ disabled.
+ kernel_kv (_size_3_t): Pooling kernel size for kv. If both pooling kernel
+ size and pooling stride size are 1 for all the dimensions, pooling is
+ disabled.
+ stride_q (_size_3_t): Pooling kernel stride for q.
+ stride_kv (_size_3_t): Pooling kernel stride for kv.
+ norm_layer (nn.Module): Normalization layer used after pooling.
+ has_cls_embed (bool): If set to True, the first token of the input tensor
+ should be a cls token. Otherwise, the input tensor does not contain a
+ cls token. Pooling is not applied to the cls token.
+ pool_mode (str): Pooling mode. Option includes "conv" (learned pooling), "avg"
+ (average pooling), and "max" (max pooling).
+ pool_first (bool): If set to True, pool is applied before qkv projection.
+ Otherwise, pool is applied after qkv projection. Default: False.
+ residual_pool (bool): If set to True, use Improved Multiscale Vision
+ Transformer's pooling residual connection.
+            depthwise_conv (bool): Whether to use depthwise or full convolution for pooling.
+            bias_on (bool): Whether to use biases for linear layers.
+            separate_qkv (bool): Whether to use separate layers or one layer for qkv projections.
+ """
+
+ super().__init__()
+ assert pool_mode in ["conv", "avg", "max"]
+ assert not pool_first
+ assert not has_cls_embed
+ assert not separate_qkv
+ self.dropout_rate = dropout_rate
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+ self.has_cls_embed = has_cls_embed
+ self.residual_pool = residual_pool
+ self.separate_qkv = separate_qkv
+ padding_q = [int(q // 2) for q in kernel_q]
+ padding_kv = [int(kv // 2) for kv in kernel_kv]
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim, bias=True if bias_on else False)
+
+ # Skip pooling with kernel and stride size of (1, 1, 1).
+ if (
+ kernel_q is not None
+ and numpy.prod(kernel_q) == 1
+ and numpy.prod(stride_q) == 1
+ ):
+ kernel_q = None
+ if (
+ kernel_kv is not None
+ and numpy.prod(kernel_kv) == 1
+ and numpy.prod(stride_kv) == 1
+ ):
+ kernel_kv = None
+
+ if pool_mode in ("avg", "max"):
+ pool_op = nn.MaxPool3d if pool_mode == "max" else nn.AvgPool3d
+ self.pool_q = (
+ pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
+ if kernel_q is not None
+ else None
+ )
+ self.pool_k = (
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
+ if kernel_kv is not None
+ else None
+ )
+ self.pool_v = (
+ pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
+ if kernel_kv is not None
+ else None
+ )
+ elif pool_mode == "conv":
+ self.pool_q = (
+ nn.Conv3d(
+ head_dim,
+ head_dim,
+ kernel_q,
+ stride=stride_q,
+ padding=padding_q,
+ groups=head_dim if depthwise_conv else 1,
+ bias=False,
+ )
+ if kernel_q is not None
+ else None
+ )
+ self.pool_k = (
+ nn.Conv3d(
+ head_dim,
+ head_dim,
+ kernel_kv,
+ stride=stride_kv,
+ padding=padding_kv,
+ groups=head_dim if depthwise_conv else 1,
+ bias=False,
+ )
+ if kernel_kv is not None
+ else None
+ )
+ self.pool_v = (
+ nn.Conv3d(
+ head_dim,
+ head_dim,
+ kernel_kv,
+ stride=stride_kv,
+ padding=padding_kv,
+ groups=head_dim if depthwise_conv else 1,
+ bias=False,
+ )
+ if kernel_kv is not None
+ else None
+ )
+ else:
+ raise NotImplementedError(f"Unsupported model {pool_mode}")
+
+ def _qkv_proj(
+ self,
+ q: torch.Tensor,
+ q_size: List[int],
+ k: torch.Tensor,
+ k_size: List[int],
+ v: torch.Tensor,
+ v_size: List[int],
+ batch_size: List[int],
+ chan_size: List[int],
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ q = (
+ self.q(q)
+ .reshape(batch_size, q_size, self.num_heads, chan_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ k = (
+ self.k(k)
+ .reshape(batch_size, k_size, self.num_heads, chan_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ v = (
+ self.v(v)
+ .reshape(batch_size, v_size, self.num_heads, chan_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ return q, k, v
+
+ def _qkv_pool(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ thw_shape: List[int],
+ ) -> Tuple[
+ torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int]
+ ]:
+ if self.pool_q is None:
+ q_shape = thw_shape
+ else:
+ q, q_shape = _pre_attention_pool(
+ q, [thw_shape[0], thw_shape[1], thw_shape[2]]
+ )
+ q = nn.functional.gelu(q)
+ q = self.pool_q(q)
+ q, q_shape = _post_attention_pool(
+ q,
+ q_shape,
+ )
+
+ if self.pool_k is None:
+ k_shape = thw_shape
+ else:
+ k, k_shape = _pre_attention_pool(
+ k,
+ [thw_shape[0], thw_shape[1], thw_shape[2]],
+ )
+ k = nn.functional.gelu(k)
+ k = self.pool_k(k)
+ k, k_shape = _post_attention_pool(
+ k,
+ k_shape,
+ )
+ if self.pool_v is None:
+ v_shape = thw_shape
+ else:
+ v, v_shape = _pre_attention_pool(
+ v,
+ [thw_shape[0], thw_shape[1], thw_shape[2]],
+ )
+ v = nn.functional.gelu(v)
+ v = self.pool_v(v)
+ v, v_shape = _post_attention_pool(
+ v,
+ v_shape,
+ )
+ return q, q_shape, k, k_shape, v, v_shape
+
+ def _get_qkv_length(
+ self,
+ q_shape: List[int],
+ k_shape: List[int],
+ v_shape: List[int],
+    ) -> Tuple[int, int, int]:
+ q_N = numpy.prod(q_shape) + 1 if self.has_cls_embed else numpy.prod(q_shape)
+ k_N = numpy.prod(k_shape) + 1 if self.has_cls_embed else numpy.prod(k_shape)
+ v_N = numpy.prod(v_shape) + 1 if self.has_cls_embed else numpy.prod(v_shape)
+ return q_N, k_N, v_N
+
+ def _reshape_qkv_to_seq(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ q_N: int,
+ v_N: int,
+ k_N: int,
+ B: int,
+ C: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ q = q.permute(0, 2, 1, 3).reshape(B, q_N, C)
+ v = v.permute(0, 2, 1, 3).reshape(B, v_N, C)
+ k = k.permute(0, 2, 1, 3).reshape(B, k_N, C)
+ return q, k, v
+
+ def forward(
+ self, x: torch.Tensor, thw_shape: List[int]
+ ) -> Tuple[torch.Tensor, List[int]]:
+ """
+ Args:
+ x (torch.Tensor): Input tensor.
+ thw_shape (List): The shape of the input tensor (before flattening).
+ """
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ q, q_shape, k, k_shape, v, v_shape = self._qkv_pool(q, k, v, thw_shape)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+
+ N = q.shape[2]
+
+ if self.residual_pool:
+ x = (attn @ v + q).transpose(1, 2).reshape(B, N, C)
+ else:
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+
+ x = self.proj(x)
+ return x, q_shape
+
+
+class ScriptableMultiScaleBlock(nn.Module):
+ """
+ Implementation of a multiscale vision transformer block. Each block contains a
+ multiscale attention layer and a Mlp layer.
+
+ ::
+
+
+ Input
+ |-------------------+
+ ↓ |
+ Norm |
+ ↓ |
+ MultiScaleAttention Pool
+ ↓ |
+ DropPath |
+ ↓ |
+ Summation ←-------------+
+ |
+ |-------------------+
+ ↓ |
+ Norm |
+ ↓ |
+ Mlp Proj
+ ↓ |
+ DropPath |
+ ↓ |
+ Summation ←------------+
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ dropout_rate: float = 0.0,
+ droppath_rate: float = 0.0,
+ act_layer: nn.Module = nn.GELU,
+ norm_layer: nn.Module = nn.LayerNorm,
+ attn_norm_layer: nn.Module = nn.LayerNorm,
+ kernel_q: _size_3_t = (1, 1, 1),
+ kernel_kv: _size_3_t = (1, 1, 1),
+ stride_q: _size_3_t = (1, 1, 1),
+ stride_kv: _size_3_t = (1, 1, 1),
+ pool_mode: str = "conv",
+ has_cls_embed: bool = True,
+ pool_first: bool = False,
+ residual_pool: bool = False,
+ depthwise_conv: bool = True,
+ bias_on: bool = True,
+ separate_qkv: bool = True,
+ ) -> None:
+ """
+ Args:
+ dim (int): Input feature dimension.
+ dim_out (int): Output feature dimension.
+ num_heads (int): Number of heads in the attention layer.
+ mlp_ratio (float): Mlp ratio which controls the feature dimension in the
+ hidden layer of the Mlp block.
+ qkv_bias (bool): If set to False, the qkv layer will not learn an additive
+ bias. Default: False.
+ dropout_rate (float): DropOut rate. If set to 0, DropOut is disabled.
+ droppath_rate (float): DropPath rate. If set to 0, DropPath is disabled.
+ act_layer (nn.Module): Activation layer used in the Mlp layer.
+ norm_layer (nn.Module): Normalization layer.
+ attn_norm_layer (nn.Module): Normalization layer in the attention module.
+ kernel_q (_size_3_t): Pooling kernel size for q. If pooling kernel size is
+ 1 for all the dimensions, pooling is not used (by default).
+ kernel_kv (_size_3_t): Pooling kernel size for kv. If pooling kernel size
+ is 1 for all the dimensions, pooling is not used. By default, pooling
+ is disabled.
+ stride_q (_size_3_t): Pooling kernel stride for q.
+ stride_kv (_size_3_t): Pooling kernel stride for kv.
+ pool_mode (str): Pooling mode. Option includes "conv" (learned pooling), "avg"
+ (average pooling), and "max" (max pooling).
+ has_cls_embed (bool): If set to True, the first token of the input tensor
+ should be a cls token. Otherwise, the input tensor does not contain a
+ cls token. Pooling is not applied to the cls token.
+ pool_first (bool): If set to True, pool is applied before qkv projection.
+ Otherwise, pool is applied after qkv projection. Default: False.
+ residual_pool (bool): If set to True, use Improved Multiscale Vision
+ Transformer's pooling residual connection.
+ depthwise_conv (bool): Whether to use depthwise or full convolution for pooling.
+ bias_on (bool): Whether to use biases for linear layers.
+ separate_qkv (bool): Whether to use separate or one layer for qkv projections.
+ """
+ super().__init__()
+ assert not pool_first
+ assert not separate_qkv
+ self.dim = dim
+ self.dim_out = dim_out
+ kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
+ stride_skip = stride_q
+ padding_skip = [int(skip // 2) for skip in kernel_skip]
+ self.attn = MultiScaleAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ dropout_rate=dropout_rate,
+ kernel_q=kernel_q,
+ kernel_kv=kernel_kv,
+ stride_q=stride_q,
+ stride_kv=stride_kv,
+ norm_layer=attn_norm_layer,
+ has_cls_embed=has_cls_embed,
+ pool_mode=pool_mode,
+ pool_first=pool_first,
+ residual_pool=residual_pool,
+ bias_on=bias_on,
+ depthwise_conv=depthwise_conv,
+ separate_qkv=separate_qkv,
+ )
+ self.drop_path = (
+ DropPath(droppath_rate) if droppath_rate > 0.0 else nn.Identity()
+ )
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.has_cls_embed = has_cls_embed
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ out_features=dim_out,
+ act_layer=act_layer,
+ dropout_rate=dropout_rate,
+ bias_on=bias_on,
+ )
+ self.proj = (
+ nn.Linear(dim, dim_out, bias=bias_on) if dim != dim_out else nn.Identity()
+ )
+
+ self.pool_skip = (
+ nn.MaxPool3d(kernel_skip, stride_skip, padding_skip, ceil_mode=False)
+ if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
+ else None
+ )
+
+ def forward(
+ self, x: torch.Tensor, thw_shape: List[int]
+ ) -> Tuple[torch.Tensor, List[int]]:
+ """
+ Args:
+ x (torch.Tensor): Input tensor.
+ thw_shape (List): The shape of the input tensor (before flattening).
+ """
+
+ x_block, thw_shape_new = self.attn(
+ x,
+ thw_shape,
+ )
+
+ if self.pool_skip is None:
+ x_res = x
+ else:
+ x_res, res_shape = _pre_attention_pool(
+ x, [thw_shape[0], thw_shape[1], thw_shape[2]]
+ )
+ x_res = self.pool_skip(x_res)
+ x_res, _ = _post_attention_pool(x_res, res_shape)
+
+ x = x_res + self.drop_path(x_block)
+ x_norm = x
+ x_mlp = self.mlp(x_norm)
+ if self.dim != self.dim_out:
+ x = self.proj(x_norm)
+ x = x + self.drop_path(x_mlp)
+ return x, thw_shape_new
diff --git a/code/pytorchvideo/pytorchvideo/layers/batch_norm.py b/code/pytorchvideo/pytorchvideo/layers/batch_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9687e4bae16c5bd8eaa08cadf1c5b64a24c7a82b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/batch_norm.py
@@ -0,0 +1,229 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import pytorchvideo.layers.distributed as du
+import torch
+import torch.distributed as dist
+from fvcore.nn.distributed import differentiable_all_reduce
+from torch import nn
+
+
+class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
+ """
+ An implementation of 1D naive sync batch normalization. See details in
+ NaiveSyncBatchNorm2d below.
+
+ Args:
+ num_sync_devices (int): number of (local) devices to sync.
+ global_sync (bool): sync across all devices (on all machines).
+ args (dict): other keyword arguments passed to the BatchNorm constructor.
+ """
+
+ def __init__(self, num_sync_devices=None, global_sync=True, **args):
+
+ self.global_sync = global_sync
+ if self.global_sync and num_sync_devices is not None:
+ raise ValueError(
+ f"Cannot set num_sync_devices separately when global_sync = {self.global_sync}"
+ )
+ if not self.global_sync and num_sync_devices is None:
+ raise ValueError(
+ f"num_sync_devices cannot be None when global_sync = {self.global_sync}"
+ )
+
+ if not self.global_sync:
+ self.num_sync_devices = num_sync_devices
+ if self.num_sync_devices > 0:
+ assert du.get_local_size() % self.num_sync_devices == 0, (
+ du.get_local_size(),
+ self.num_sync_devices,
+ )
+ self.num_groups = du.get_local_size() // self.num_sync_devices
+ else:
+ self.num_sync_devices = du.get_local_size()
+ self.num_groups = 1
+ super(NaiveSyncBatchNorm1d, self).__init__(**args)
+
+ def forward(self, input):
+ if du.get_world_size() == 1 or not self.training:
+ return super().forward(input)
+
+ B, C = input.shape[0], input.shape[1]
+
+ assert B > 0, "SyncBatchNorm does not support zero batch size."
+
+ mean = torch.mean(input, dim=[0])
+ meansqr = torch.mean(input * input, dim=[0])
+
+ vec = torch.cat([mean, meansqr], dim=0)
+ # sync stats globally or locally
+ if self.global_sync:
+ vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
+ else:
+ vec = du.GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
+ 1.0 / self.num_sync_devices
+ )
+
+ mean, meansqr = torch.split(vec, C)
+ var = meansqr - mean * mean
+
+ invstd = torch.rsqrt(var + self.eps)
+ scale = self.weight * invstd
+ bias = self.bias - mean * scale
+ scale = scale.reshape(1, -1)
+ bias = bias.reshape(1, -1)
+
+ self.running_mean += self.momentum * (mean.detach() - self.running_mean)
+ self.running_var += self.momentum * (var.detach() - self.running_var)
+
+ return input * scale + bias
+
+
+class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
+ """
+ An implementation of 2D naive sync batch normalization.
+ In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
+ when the batch size on each worker is different.
+ (e.g., when scale augmentation is used, or when it is applied to mask head).
+
+ This is a slower but correct alternative to `nn.SyncBatchNorm`.
+
+ Args:
+ num_sync_devices (int): number of (local) devices to sync.
+ global_sync (bool): sync across all devices (on all machines).
+ args (dict): other keyword arguments passed to the BatchNorm constructor.
+
+ Note:
+ This module computes overall statistics by using
+ statistics of each worker with equal weight. The result is true statistics
+ of all samples (as if they are all on one worker) only when all workers
+ have the same (N, H, W). This mode does not support inputs with zero batch size.
+ """
+
+ def __init__(self, num_sync_devices=None, global_sync=True, **args):
+
+ self.global_sync = global_sync
+ if self.global_sync and num_sync_devices is not None:
+ raise ValueError(
+ f"Cannot set num_sync_devices separately when global_sync = {self.global_sync}"
+ )
+ if not self.global_sync and num_sync_devices is None:
+ raise ValueError(
+ f"num_sync_devices cannot be None when global_sync = {self.global_sync}"
+ )
+
+ if not self.global_sync:
+ self.num_sync_devices = num_sync_devices
+ if self.num_sync_devices > 0:
+ assert du.get_local_size() % self.num_sync_devices == 0, (
+ du.get_local_size(),
+ self.num_sync_devices,
+ )
+ self.num_groups = du.get_local_size() // self.num_sync_devices
+ else:
+ self.num_sync_devices = du.get_local_size()
+ self.num_groups = 1
+ super(NaiveSyncBatchNorm2d, self).__init__(**args)
+
+ def forward(self, input):
+ if du.get_world_size() == 1 or not self.training:
+ return super().forward(input)
+
+ B, C = input.shape[0], input.shape[1]
+
+ assert B > 0, "SyncBatchNorm does not support zero batch size."
+
+ mean = torch.mean(input, dim=[0, 2, 3])
+ meansqr = torch.mean(input * input, dim=[0, 2, 3])
+
+ vec = torch.cat([mean, meansqr], dim=0)
+ # sync stats globally or locally
+ if self.global_sync:
+ vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
+ else:
+ vec = du.GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
+ 1.0 / self.num_sync_devices
+ )
+
+ mean, meansqr = torch.split(vec, C)
+ var = meansqr - mean * mean
+
+ invstd = torch.rsqrt(var + self.eps)
+ scale = self.weight * invstd
+ bias = self.bias - mean * scale
+ scale = scale.reshape(1, -1, 1, 1)
+ bias = bias.reshape(1, -1, 1, 1)
+
+ self.running_mean += self.momentum * (mean.detach() - self.running_mean)
+ self.running_var += self.momentum * (var.detach() - self.running_var)
+
+ return input * scale + bias
+
+
+class NaiveSyncBatchNorm3d(nn.BatchNorm3d):
+ """
+ Naive version of Synchronized 3D BatchNorm. See details in
+ NaiveSyncBatchNorm2d above.
+ Args:
+ num_sync_devices (int): number of (local) devices to sync.
+ global_sync (bool): sync across all devices (on all machines).
+ args (dict): other keyword arguments passed to the BatchNorm constructor.
+ """
+
+ def __init__(self, num_sync_devices=None, global_sync=True, **args):
+
+ self.global_sync = global_sync
+ if self.global_sync and num_sync_devices is not None:
+ raise ValueError(
+ f"Cannot set num_sync_devices separately when global_sync = {self.global_sync}"
+ )
+ if not self.global_sync and num_sync_devices is None:
+ raise ValueError(
+ f"num_sync_devices cannot be None when global_sync = {self.global_sync}"
+ )
+
+ if not self.global_sync:
+ self.num_sync_devices = num_sync_devices
+ if self.num_sync_devices > 0:
+ assert du.get_local_size() % self.num_sync_devices == 0, (
+ du.get_local_size(),
+ self.num_sync_devices,
+ )
+ self.num_groups = du.get_local_size() // self.num_sync_devices
+ else:
+ self.num_sync_devices = du.get_local_size()
+ self.num_groups = 1
+ super(NaiveSyncBatchNorm3d, self).__init__(**args)
+
+ def forward(self, input):
+ if du.get_world_size() == 1 or not self.training:
+ return super().forward(input)
+
+ B, C = input.shape[0], input.shape[1]
+
+ assert B > 0, "SyncBatchNorm does not support zero batch size."
+
+ mean = torch.mean(input, dim=[0, 2, 3, 4])
+ meansqr = torch.mean(input * input, dim=[0, 2, 3, 4])
+
+ vec = torch.cat([mean, meansqr], dim=0)
+ # sync stats globally or locally
+ if self.global_sync:
+ vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
+ else:
+ vec = du.GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * (
+ 1.0 / self.num_sync_devices
+ )
+
+ mean, meansqr = torch.split(vec, C)
+ var = meansqr - mean * mean
+
+ invstd = torch.rsqrt(var + self.eps)
+ scale = self.weight * invstd
+ bias = self.bias - mean * scale
+ scale = scale.reshape(1, -1, 1, 1, 1)
+ bias = bias.reshape(1, -1, 1, 1, 1)
+
+ self.running_mean += self.momentum * (mean.detach() - self.running_mean)
+ self.running_var += self.momentum * (var.detach() - self.running_var)
+
+ return input * scale + bias
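+
+
+# Minimal usage sketch (illustrative; the shapes below are assumptions, not values
+# used elsewhere in this repo). In a single-process run (world size 1) the naive
+# sync layers fall back to the plain BatchNorm forward, so they can be smoke-tested
+# without initializing torch.distributed.
+if __name__ == "__main__":
+    bn = NaiveSyncBatchNorm3d(num_features=8)  # global_sync=True by default
+    x = torch.randn(4, 8, 16, 32, 32)  # (B, C, T, H, W)
+    print(bn(x).shape)  # torch.Size([4, 8, 16, 32, 32])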
diff --git a/code/pytorchvideo/pytorchvideo/layers/convolutions.py b/code/pytorchvideo/pytorchvideo/layers/convolutions.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d2ebc9b9f78694a33df22f9db436a532dbf1e5
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/convolutions.py
@@ -0,0 +1,237 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.utils import set_attributes
+from torch.nn.common_types import _size_3_t
+
+
+class ConvReduce3D(nn.Module):
+ """
+ Builds a list of convolutional operators and performs summation on the outputs.
+
+ ::
+
+ Conv3d, Conv3d, ..., Conv3d
+ ↓
+ Sum
+ """
+
+ def __init__(
+ self,
+ *,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Tuple[_size_3_t],
+ stride: Optional[Tuple[_size_3_t]] = None,
+ padding: Optional[Tuple[_size_3_t]] = None,
+ padding_mode: Optional[Tuple[str]] = None,
+ dilation: Optional[Tuple[_size_3_t]] = None,
+ groups: Optional[Tuple[int]] = None,
+ bias: Optional[Tuple[bool]] = None,
+ reduction_method: str = "sum",
+ ) -> None:
+ """
+ Args:
+ in_channels int: number of input channels.
+ out_channels int: number of output channels produced by the convolution(s).
+ kernel_size tuple(_size_3_t): Tuple of sizes of the convolutional kernels.
+ stride tuple(_size_3_t): Tuple of strides of the convolutions.
+ padding tuple(_size_3_t): Tuple of paddings added to all three sides of the
+ input.
+ padding_mode tuple(string): Tuple of padding modes, one for each convolution.
+ Options include `zeros`, `reflect`, `replicate` or `circular`.
+ dilation tuple(_size_3_t): Tuple of spacings between kernel elements.
+ groups tuple(int): Tuple of numbers of blocked connections from input
+ channels to output channels.
+ bias tuple(bool): If `True`, adds a learnable bias to the output.
+ reduction_method str: Options include `sum` and `cat`.
+ """
+ super().__init__()
+ assert reduction_method in ("sum", "cat")
+ self.reduction_method = reduction_method
+ conv_list = []
+ for ind in range(len(kernel_size)):
+ conv_param = {
+ "in_channels": in_channels,
+ "out_channels": out_channels,
+ "kernel_size": kernel_size[ind],
+ }
+ if stride is not None and stride[ind] is not None:
+ conv_param["stride"] = stride[ind]
+ if padding is not None and padding[ind] is not None:
+ conv_param["padding"] = padding[ind]
+ if dilation is not None and dilation[ind] is not None:
+ conv_param["dilation"] = dilation[ind]
+ if groups is not None and groups[ind] is not None:
+ conv_param["groups"] = groups[ind]
+ if bias is not None and bias[ind] is not None:
+ conv_param["bias"] = bias[ind]
+ if padding_mode is not None and padding_mode[ind] is not None:
+ conv_param["padding_mode"] = padding_mode[ind]
+ conv_list.append(nn.Conv3d(**conv_param))
+ self.convs = nn.ModuleList(conv_list)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ output = []
+ for ind in range(len(self.convs)):
+ output.append(self.convs[ind](x))
+ if self.reduction_method == "sum":
+ output = torch.stack(output, dim=0).sum(dim=0, keepdim=False)
+ elif self.reduction_method == "cat":
+ output = torch.cat(output, dim=1)
+ return output
+
+
+def create_conv_2plus1d(
+ *,
+ # Conv configs.
+ in_channels: int,
+ out_channels: int,
+ inner_channels: int = None,
+ conv_xy_first: bool = False,
+ kernel_size: Tuple[int] = (3, 3, 3),
+ stride: Tuple[int] = (2, 2, 2),
+ padding: Tuple[int] = (1, 1, 1),
+ bias: bool = False,
+ dilation: Tuple[int] = (1, 1, 1),
+ groups: int = 1,
+ # BN configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Create a 2plus1d conv layer. It factorizes a 3D convolution into a temporal
+ convolution and a spatial convolution, with normalization and activation in between.
+
+ ::
+
+ Conv_t (or Conv_xy if conv_xy_first = True)
+ ↓
+ Normalization
+ ↓
+ Activation
+ ↓
+ Conv_xy (or Conv_t if conv_xy_first = True)
+
+ Normalization options include: BatchNorm3d and None (no normalization).
+ Activation options include: ReLU, Softmax, Sigmoid, and None (no activation).
+
+ Args:
+ in_channels (int): input channel size of the convolution.
+ out_channels (int): output channel size of the convolution.
+ kernel_size (tuple): convolutional kernel size(s).
+ stride (tuple): convolutional stride size(s).
+ padding (tuple): convolutional padding size(s).
+ bias (bool): convolutional bias. If true, adds a learnable bias to the
+ output.
+ groups (int): number of groups in the convolution layers. Values > 1 are unsupported.
+ dilation (tuple): dilation value in the convolution layers. Values > 1 are unsupported.
+ conv_xy_first (bool): If True, the spatial convolution comes before the temporal convolution.
+
+ norm (callable): a callable that constructs normalization layer, options
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer, options
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): 2plus1d conv layer.
+ """
+ if inner_channels is None:
+ inner_channels = out_channels
+
+ assert (
+ groups == 1
+ ), "Support for groups is not implemented in R2+1 convolution layer"
+ assert (
+ max(dilation) == 1 and min(dilation) == 1
+ ), "Support for dillaiton is not implemented in R2+1 convolution layer"
+
+ conv_t_module = nn.Conv3d(
+ in_channels=in_channels if not conv_xy_first else inner_channels,
+ out_channels=inner_channels if not conv_xy_first else out_channels,
+ kernel_size=(kernel_size[0], 1, 1),
+ stride=(stride[0], 1, 1),
+ padding=(padding[0], 0, 0),
+ bias=bias,
+ )
+ norm_module = (
+ None
+ if norm is None
+ else norm(num_features=inner_channels, eps=norm_eps, momentum=norm_momentum)
+ )
+ activation_module = None if activation is None else activation()
+ conv_xy_module = nn.Conv3d(
+ in_channels=inner_channels if not conv_xy_first else in_channels,
+ out_channels=out_channels if not conv_xy_first else inner_channels,
+ kernel_size=(1, kernel_size[1], kernel_size[2]),
+ stride=(1, stride[1], stride[2]),
+ padding=(0, padding[1], padding[2]),
+ bias=bias,
+ )
+
+ return Conv2plus1d(
+ conv_t=conv_t_module,
+ norm=norm_module,
+ activation=activation_module,
+ conv_xy=conv_xy_module,
+ conv_xy_first=conv_xy_first,
+ )
+
+
+class Conv2plus1d(nn.Module):
+ """
+ Implementation of 2+1d Convolution by factorizing 3D Convolution into a 1D temporal
+ Convolution and a 2D spatial Convolution with Normalization and Activation module
+ in between:
+
+ ::
+
+ Conv_t (or Conv_xy if conv_xy_first = True)
+ ↓
+ Normalization
+ ↓
+ Activation
+ ↓
+ Conv_xy (or Conv_t if conv_xy_first = True)
+
+ The 2+1d Convolution is used to build the R(2+1)D network.
+ """
+
+ def __init__(
+ self,
+ *,
+ conv_t: nn.Module = None,
+ norm: nn.Module = None,
+ activation: nn.Module = None,
+ conv_xy: nn.Module = None,
+ conv_xy_first: bool = False,
+ ) -> None:
+ """
+ Args:
+ conv_t (torch.nn.modules): temporal convolution module.
+ norm (torch.nn.modules): normalization module.
+ activation (torch.nn.modules): activation module.
+ conv_xy (torch.nn.modules): spatial convolution module.
+ conv_xy_first (bool): If True, spatial convolution comes before temporal conv
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.conv_t is not None
+ assert self.conv_xy is not None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv_xy(x) if self.conv_xy_first else self.conv_t(x)
+ x = self.norm(x) if self.norm else x
+ x = self.activation(x) if self.activation else x
+ x = self.conv_t(x) if self.conv_xy_first else self.conv_xy(x)
+ return x
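+
+
+# Minimal usage sketch (illustrative; channel sizes and shapes are assumptions,
+# not values used elsewhere in this repo).
+if __name__ == "__main__":
+    x = torch.randn(1, 3, 4, 32, 32)  # (B, C, T, H, W)
+
+    # Factorized (2+1)D convolution with the default stride of (2, 2, 2).
+    conv = create_conv_2plus1d(in_channels=3, out_channels=8)
+    print(conv(x).shape)  # torch.Size([1, 8, 2, 16, 16])
+
+    # Two parallel 3D convolutions whose outputs are summed element-wise.
+    reduce_conv = ConvReduce3D(
+        in_channels=3,
+        out_channels=8,
+        kernel_size=((3, 3, 3), (1, 1, 1)),
+        padding=((1, 1, 1), (0, 0, 0)),
+    )
+    print(reduce_conv(x).shape)  # torch.Size([1, 8, 4, 32, 32])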
diff --git a/code/pytorchvideo/pytorchvideo/layers/distributed.py b/code/pytorchvideo/pytorchvideo/layers/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e309b387bf17331bf6718319d98b5437aa6969e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/distributed.py
@@ -0,0 +1,146 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""Distributed helpers."""
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import ProcessGroup
+from torch.autograd.function import Function
+
+_LOCAL_PROCESS_GROUP = None
+
+
+def get_world_size() -> int:
+ """
+ Simple wrapper for correctly getting the world size in both distributed
+ and non-distributed settings.
+ """
+ return (
+ torch.distributed.get_world_size()
+ if torch.distributed.is_available() and torch.distributed.is_initialized()
+ else 1
+ )
+
+
+def cat_all_gather(tensors, local=False):
+ """Performs the concatenated all_reduce operation on the provided tensors."""
+ if local:
+ gather_sz = get_local_size()
+ else:
+ gather_sz = torch.distributed.get_world_size()
+ tensors_gather = [torch.ones_like(tensors) for _ in range(gather_sz)]
+ torch.distributed.all_gather(
+ tensors_gather,
+ tensors,
+ async_op=False,
+ group=_LOCAL_PROCESS_GROUP if local else None,
+ )
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+def init_distributed_training(num_gpus, shard_id):
+ """
+ Initialize variables needed for distributed training.
+ """
+ if num_gpus <= 1:
+ return
+ num_gpus_per_machine = num_gpus
+ num_machines = dist.get_world_size() // num_gpus_per_machine
+ for i in range(num_machines):
+ ranks_on_i = list(
+ range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
+ )
+ pg = dist.new_group(ranks_on_i)
+ if i == shard_id:
+ global _LOCAL_PROCESS_GROUP
+ _LOCAL_PROCESS_GROUP = pg
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_process_group() -> ProcessGroup:
+ assert _LOCAL_PROCESS_GROUP is not None
+ return _LOCAL_PROCESS_GROUP
+
+
+class GroupGather(Function):
+ """
+ GroupGather performs an all_gather on each of the local process/GPU groups.
+ """
+
+ @staticmethod
+ def forward(ctx, input, num_sync_devices, num_groups):
+ """
+ Perform the forward pass, gathering the stats across the processes/GPUs in
+ the group.
+ """
+ ctx.num_sync_devices = num_sync_devices
+ ctx.num_groups = num_groups
+
+ input_list = [torch.zeros_like(input) for k in range(get_local_size())]
+ dist.all_gather(
+ input_list, input, async_op=False, group=get_local_process_group()
+ )
+
+ inputs = torch.stack(input_list, dim=0)
+ if num_groups > 1:
+ rank = get_local_rank()
+ group_idx = rank // num_sync_devices
+ inputs = inputs[
+ group_idx * num_sync_devices : (group_idx + 1) * num_sync_devices
+ ]
+ inputs = torch.sum(inputs, dim=0)
+ return inputs
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ Perform the backward pass, gathering the gradients across the processes/GPUs in
+ the group.
+ """
+ grad_output_list = [
+ torch.zeros_like(grad_output) for k in range(get_local_size())
+ ]
+ dist.all_gather(
+ grad_output_list,
+ grad_output,
+ async_op=False,
+ group=get_local_process_group(),
+ )
+
+ grads = torch.stack(grad_output_list, dim=0)
+ if ctx.num_groups > 1:
+ rank = get_local_rank()
+ group_idx = rank // ctx.num_sync_devices
+ grads = grads[
+ group_idx
+ * ctx.num_sync_devices : (group_idx + 1)
+ * ctx.num_sync_devices
+ ]
+ grads = torch.sum(grads, dim=0)
+ return grads, None, None
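+
+
+# Minimal usage sketch (illustrative): without an initialized process group the
+# helpers degrade gracefully, which is how they behave in single-GPU runs.
+if __name__ == "__main__":
+    print(get_world_size())  # 1 when torch.distributed is not initialized
+    print(get_local_size())  # 1 when torch.distributed is not initialized
+    print(get_local_rank())  # 0 when torch.distributed is not initialized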
diff --git a/code/pytorchvideo/pytorchvideo/layers/drop_path.py b/code/pytorchvideo/pytorchvideo/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..8023e10b8721a3df08bb623113dc6ba977cd51a8
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/drop_path.py
@@ -0,0 +1,48 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import torch
+import torch.nn as nn
+
+
+def drop_path(
+ x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
+) -> torch.Tensor:
+ """
+ Stochastic Depth per sample.
+
+ Args:
+ x (tensor): Input tensor.
+ drop_prob (float): Probability to apply drop path.
+ training (bool): If True, apply drop path to the input. Otherwise (testing), return the input unchanged.
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ mask.floor_() # binarize
+ output = x.div(keep_prob) * mask
+ return output
+
+
+class DropPath(nn.Module):
+ """
+ Drop paths (Stochastic Depth) per sample.
+ """
+
+ def __init__(self, drop_prob: float = 0.0) -> None:
+ """
+ Args:
+ drop_prob (float): Probability to apply drop path.
+ """
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (tensor): Input tensor.
+ """
+ return drop_path(x, self.drop_prob, self.training)
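+
+
+# Minimal usage sketch (illustrative): DropPath is the identity in eval mode; in
+# train mode each sample in the batch is either zeroed or rescaled by 1 / keep_prob.
+if __name__ == "__main__":
+    layer = DropPath(drop_prob=0.2)
+    x = torch.ones(8, 4, 16)
+
+    layer.eval()
+    assert torch.equal(layer(x), x)
+
+    layer.train()
+    y = layer(x)
+    print(y[:, 0, 0])  # each sample is either 0.0 or 1 / 0.8 = 1.25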
diff --git a/code/pytorchvideo/pytorchvideo/layers/fusion.py b/code/pytorchvideo/pytorchvideo/layers/fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9656bec4dab3320ae8ae79f0c49ba509bfdb5ca0
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/fusion.py
@@ -0,0 +1,149 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, List
+
+import torch
+import torch.nn as nn
+
+
+"""
+Fusion layers are nn.Modules that take a list of Tensors (e.g. from a multi-stream
+architecture), and return a single fused Tensor. This file has several
+different types of fusion layers and a factory function "make_fusion_layer" to
+construct them.
+"""
+
+
+def make_fusion_layer(method: str, feature_dims: List[int]):
+ """
+ Args:
+ method (str): the fusion method to be constructed. Options:
+ - 'concat'
+ - 'temporal_concat'
+ - 'max'
+ - 'sum'
+ - 'prod'
+
+ feature_dims (List[int]): the first argument of all fusion layers. It holds a list
+ of required feature_dims for each tensor input (where the tensor inputs are of
+ shape (batch_size, seq_len, feature_dim)). The list order must correspond to
+ the tensor order passed to forward(...).
+ """
+ if method == "concat":
+ return ConcatFusion(feature_dims)
+ elif method == "temporal_concat":
+ return TemporalConcatFusion(feature_dims)
+ elif method == "max":
+ return ReduceFusion(feature_dims, lambda x: torch.max(x, dim=0).values)
+ elif method == "sum":
+ return ReduceFusion(feature_dims, lambda x: torch.sum(x, dim=0))
+ elif method == "prod":
+ return ReduceFusion(feature_dims, lambda x: torch.prod(x, dim=0))
+ else:
+ raise NotImplementedError(f"Fusion {method} not available.")
+
+
+class ConcatFusion(nn.Module):
+ """
+ Concatenates all inputs along their last dimension. The last dimension of the resulting
+ tensor is the sum of the last dimensions of all input tensors.
+ """
+
+ def __init__(self, feature_dims: List[int]):
+ super().__init__()
+ _verify_feature_dim(feature_dims)
+ self._output_dim = sum(feature_dims)
+
+ @property
+ def output_dim(self):
+ """
+ Last dimension size of forward(..) tensor output.
+ """
+ return self._output_dim
+
+ def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Args:
+ input_list (List[torch.Tensor]): a list of tensors of shape
+ (batch_size, seq_len, feature_dim).
+
+ Returns:
+ Tensor of shape (batch_size, seq_len, sum(feature_dims)) where sum(feature_dims)
+ is the sum of all input feature_dims.
+ """
+ return torch.cat(input_list, dim=-1)
+
+
+class TemporalConcatFusion(nn.Module):
+ """
+ Concatenates all inputs by their temporal dimension which is assumed to be dim=1.
+ """
+
+ def __init__(self, feature_dims: List[int]):
+ super().__init__()
+ _verify_feature_dim(feature_dims)
+
+ # All input dimensions must be the same
+ self._output_dim = max(feature_dims)
+ assert self._output_dim == min(feature_dims)
+
+ @property
+ def output_dim(self):
+ """
+ Last dimension size of forward(..) tensor output.
+ """
+ return self._output_dim
+
+ def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Args:
+ input_list (List[torch.Tensor]): a list of tensors of shape
+ (batch_size, seq_len, feature_dim)
+
+ Returns:
+ Tensor of shape (batch_size, sum(seq_len), feature_dim) where sum(seq_len) is
+ the sum of the sequence lengths of all input tensors.
+ """
+ return torch.cat(input_list, dim=1)
+
+
+class ReduceFusion(nn.Module):
+ """
+ Generic fusion method which applies a callable to the stacked input tensors and
+ returns a single fused tensor. This class can be used to implement fusion
+ methods like "sum", "max" and "prod".
+ """
+
+ def __init__(
+ self, feature_dims: List[int], reduce_fn: Callable[[torch.Tensor], torch.Tensor]
+ ):
+ super().__init__()
+ _verify_feature_dim(feature_dims)
+ self.reduce_fn = reduce_fn
+
+ # All input dimensions must be the same
+ self._output_dim = max(feature_dims)
+ assert self._output_dim == min(feature_dims)
+
+ @property
+ def output_dim(self):
+ """
+ Last dimension size of forward(..) tensor output.
+ """
+ return self._output_dim
+
+ def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Args:
+ input_list (List[torch.Tensor]): a list of tensors of shape
+ (batch_size, seq_len, feature_dim).
+
+ Returns:
+ Tensor of shape (batch_size, seq_len, feature_dim).
+ """
+ return self.reduce_fn(torch.stack(input_list))
+
+
+def _verify_feature_dim(feature_dims: List[int]):
+ assert isinstance(feature_dims, list)
+ assert all(x > 0 for x in feature_dims)
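+
+
+# Minimal usage sketch (illustrative, assumed shapes): fusing two streams whose
+# feature dimensions match.
+if __name__ == "__main__":
+    a = torch.randn(2, 10, 64)
+    b = torch.randn(2, 10, 64)
+
+    concat = make_fusion_layer("concat", [64, 64])
+    print(concat([a, b]).shape, concat.output_dim)  # torch.Size([2, 10, 128]) 128
+
+    summed = make_fusion_layer("sum", [64, 64])
+    print(summed([a, b]).shape)  # torch.Size([2, 10, 64])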
diff --git a/code/pytorchvideo/pytorchvideo/layers/mlp.py b/code/pytorchvideo/pytorchvideo/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..78556e77f78408da7d8f7c751d4561bd83e85077
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/mlp.py
@@ -0,0 +1,62 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import Callable, List, Optional, Tuple
+
+from torch import nn
+
+
+def make_multilayer_perceptron(
+ fully_connected_dims: List[int],
+ norm: Optional[Callable] = None,
+ mid_activation: Callable = nn.ReLU,
+ final_activation: Optional[Callable] = nn.ReLU,
+ dropout_rate: float = 0.0,
+) -> Tuple[nn.Module, int]:
+ """
+ Factory function for a Multi-Layer Perceptron. It is constructed as repeated
+ blocks of the following format, where each fc[i] sets a block's input/output dimension.
+
+ ::
+
+ Linear (in=fc[i-1], out=fc[i])
+ ↓
+ Normalization (norm)
+ ↓
+ Activation (mid_activation)
+ ↓
+ After the repeated Perceptron blocks,
+ a final dropout and activation layer is applied:
+ ↓
+ Dropout (p=dropout_rate)
+ ↓
+ Activation (final_activation)
+
+ """
+ assert isinstance(fully_connected_dims, list)
+ assert len(fully_connected_dims) > 1
+ assert all(_is_pos_int(x) for x in fully_connected_dims)
+
+ layers = []
+ cur_dim = fully_connected_dims[0]
+ for dim in fully_connected_dims[1:-1]:
+ layers.append(nn.Linear(cur_dim, dim))
+ if norm is not None:
+ layers.append(norm(dim))
+ layers.append(mid_activation())
+ cur_dim = dim
+ layers.append(nn.Linear(cur_dim, fully_connected_dims[-1]))
+ if dropout_rate > 0:
+ layers.append(nn.Dropout(p=dropout_rate))
+ if final_activation is not None:
+ layers.append(final_activation())
+
+ mlp = nn.Sequential(*layers)
+ output_dim = fully_connected_dims[-1]
+ return mlp, output_dim
+
+
+def _is_pos_int(number: int) -> bool:
+ """
+ Returns True if a number is a positive integer.
+ """
+ return type(number) == int and number > 0
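+
+
+# Minimal usage sketch (illustrative, assumed sizes): a 32 -> 64 -> 10 MLP with
+# BatchNorm1d after the hidden layer and dropout before the final activation.
+if __name__ == "__main__":
+    import torch
+
+    mlp, out_dim = make_multilayer_perceptron(
+        fully_connected_dims=[32, 64, 10],
+        norm=nn.BatchNorm1d,
+        dropout_rate=0.1,
+    )
+    x = torch.randn(4, 32)
+    print(mlp(x).shape, out_dim)  # torch.Size([4, 10]) 10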
diff --git a/code/pytorchvideo/pytorchvideo/layers/nonlocal_net.py b/code/pytorchvideo/pytorchvideo/layers/nonlocal_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ae91f535adcf4914dfb13092e5b5d64754cd7b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/nonlocal_net.py
@@ -0,0 +1,153 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.utils import set_attributes
+
+
+class NonLocal(nn.Module):
+ """
+ Builds Non-local Neural Networks as a generic family of building
+ blocks for capturing long-range dependencies. Non-local Network
+ computes the response at a position as a weighted sum of the
+ features at all positions. This building block can be plugged into
+ many computer vision architectures.
+ More details in the paper:
+ Wang, Xiaolong, Ross Girshick, Abhinav Gupta, and Kaiming He.
+ "Non-local neural networks."
+ In Proceedings of the IEEE conference on CVPR, 2018.
+ """
+
+ def __init__(
+ self,
+ *,
+ conv_theta: nn.Module,
+ conv_phi: nn.Module,
+ conv_g: nn.Module,
+ conv_out: nn.Module,
+ pool: Optional[nn.Module] = None,
+ norm: Optional[nn.Module] = None,
+ instantiation: str = "dot_product",
+ ) -> None:
+ super().__init__()
+ set_attributes(self, locals())
+ assert None not in (conv_theta, conv_phi, conv_g, conv_out)
+ assert instantiation in (
+ "dot_product",
+ "softmax",
+ ), "Unknown norm type {}".format(instantiation)
+ assert (
+ len(
+ {
+ self.conv_theta.out_channels,
+ self.conv_phi.out_channels,
+ self.conv_g.out_channels,
+ self.conv_out.in_channels,
+ }
+ )
+ == 1
+ ), "Nonlocal convolution's input/ output dimension mismatch."
+
+ def forward(self, x) -> torch.Tensor:
+ dim_inner = self.conv_theta.out_channels
+
+ x_identity = x
+ N, C, T, H, W = x.size()
+
+ theta = self.conv_theta(x)
+ # Perform temporal-spatial pooling to reduce the computation.
+ if self.pool is not None:
+ x = self.pool(x)
+
+ phi = self.conv_phi(x)
+ g = self.conv_g(x)
+
+ theta = theta.view(N, dim_inner, -1)
+ phi = phi.view(N, dim_inner, -1)
+ g = g.view(N, dim_inner, -1)
+
+ # (N, C, TxHxW) x (N, C, TxHxW) => (N, TxHxW, TxHxW).
+ theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
+ # For original Non-local paper, there are two main ways to normalize
+ # the affinity tensor:
+ # 1) Softmax normalization (norm on exp).
+ # 2) dot_product normalization.
+ if self.instantiation == "softmax":
+ # Normalizing the affinity tensor theta_phi before softmax.
+ theta_phi = theta_phi * (dim_inner**-0.5)
+ theta_phi = nn.functional.softmax(theta_phi, dim=2)
+ elif self.instantiation == "dot_product":
+ spatial_temporal_dim = theta_phi.shape[2]
+ theta_phi = theta_phi / spatial_temporal_dim
+
+ # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW).
+ theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
+ # (N, C, TxHxW) => (N, C, T, H, W).
+ theta_phi_g = theta_phi_g.view(N, dim_inner, T, H, W)
+ p = self.conv_out(theta_phi_g)
+ if self.norm is not None:
+ p = self.norm(p)
+ return x_identity + p
+
+
+def create_nonlocal(
+ *,
+ # Nonlocal configs.
+ dim_in: int,
+ dim_inner: int,
+ pool_size: Optional[Tuple[int]] = (1, 1, 1),
+ instantiation: str = "softmax",
+ # Norm configs.
+ norm: Optional[Callable] = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+):
+ """
+ Builds Non-local Neural Networks as a generic family of building
+ blocks for capturing long-range dependencies. Non-local Network
+ computes the response at a position as a weighted sum of the
+ features at all positions. This building block can be plugged into
+ many computer vision architectures.
+ More details in the paper: https://arxiv.org/pdf/1711.07971
+ Args:
+ dim_in (int): number of dimension for the input.
+ dim_inner (int): number of dimension inside of the Non-local block.
+ pool_size (tuple[int]): the kernel size of the spatio-temporal pooling applied
+ before phi and g, given as (temporal kernel size, spatial kernel size, spatial
+ kernel size). If pool_size is None or all ones (the default), no pooling is used.
+ instantiation (string): supports two different instantiation methods:
+ "dot_product": normalizing the correlation matrix by the number of
+ spatio-temporal positions.
+ "softmax": normalizing the correlation matrix with Softmax.
+ norm (nn.Module): nn.Module for the normalization layer. The default is
+ nn.BatchNorm3d.
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+ """
+ if pool_size is None:
+ pool_size = (1, 1, 1)
+ assert isinstance(pool_size, Iterable)
+
+ if norm is None:
+ norm_model = None
+ else:
+ norm_model = norm(num_features=dim_in, eps=norm_eps, momentum=norm_momentum)
+
+ if any(size > 1 for size in pool_size):
+ pool_model = nn.MaxPool3d(
+ kernel_size=pool_size, stride=pool_size, padding=[0, 0, 0]
+ )
+ else:
+ pool_model = None
+
+ return NonLocal(
+ conv_theta=nn.Conv3d(dim_in, dim_inner, kernel_size=1, stride=1, padding=0),
+ conv_phi=nn.Conv3d(dim_in, dim_inner, kernel_size=1, stride=1, padding=0),
+ conv_g=nn.Conv3d(dim_in, dim_inner, kernel_size=1, stride=1, padding=0),
+ conv_out=nn.Conv3d(dim_inner, dim_in, kernel_size=1, stride=1, padding=0),
+ pool=pool_model,
+ norm=norm_model,
+ instantiation=instantiation,
+ )
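+
+
+# Minimal usage sketch (illustrative, assumed shapes): a non-local block with 2x
+# spatial pooling on the phi/g branches; the residual output keeps the input shape.
+if __name__ == "__main__":
+    block = create_nonlocal(dim_in=16, dim_inner=8, pool_size=(1, 2, 2))
+    x = torch.randn(1, 16, 4, 8, 8)  # (N, C, T, H, W)
+    print(block(x).shape)  # torch.Size([1, 16, 4, 8, 8])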
diff --git a/code/pytorchvideo/pytorchvideo/layers/positional_encoding.py b/code/pytorchvideo/pytorchvideo/layers/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f67e9e381e1825cefd6c2a313f8e98d12b627ae
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/positional_encoding.py
@@ -0,0 +1,244 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+import math
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import nn
+
+
+class PositionalEncoding(nn.Module):
+ """
+ Applies a positional encoding to a tensor with shape (batch_size x seq_len x embed_dim).
+
+ The positional encoding is computed as follows:
+ PE(pos,2i) = sin(pos/10000^(2i/dmodel))
+ PE(pos,2i+1) = cos(pos/10000^(2i/dmodel))
+
+ where pos = position, pos in [0, seq_len)
+ dmodel = data embedding dimension = embed_dim
+ i = dimension index, i in [0, embed_dim)
+
+ Reference: "Attention Is All You Need" https://arxiv.org/abs/1706.03762
+ Implementation Reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+ """
+
+ def __init__(self, embed_dim: int, seq_len: int = 1024) -> None:
+ super().__init__()
+ pe = torch.zeros(seq_len, embed_dim, dtype=torch.float)
+ position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, embed_dim, 2).float() * (-(math.log(10000.0)) / embed_dim)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ assert self.pe.size(1) >= x.size(1), (
+ "Cannot apply position encoding of size "
+ + f"{self.pe.size()} when input has size {x.size()}"
+ )
+ return x + self.pe[:, : x.size(1), :]
+
+
+class SpatioTemporalClsPositionalEncoding(nn.Module):
+ """
+ Add a cls token and apply a spatiotemporal encoding to a tensor.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ patch_embed_shape: Tuple[int, int, int],
+ sep_pos_embed: bool = False,
+ has_cls: bool = True,
+ ) -> None:
+ """
+ Args:
+ embed_dim (int): Embedding dimension for input sequence.
+ patch_embed_shape (Tuple): The number of patches in each dimension
+ (T, H, W) after patch embedding.
+ sep_pos_embed (bool): If set to true, one positional encoding is used for
+ spatial patches and another positional encoding is used for temporal
+ sequence. Otherwise, only one positional encoding is used for all the
+ patches.
+ has_cls (bool): If set to true, a cls token is added in the beginning of each
+ input sequence.
+ """
+ super().__init__()
+ assert (
+ len(patch_embed_shape) == 3
+ ), "Patch_embed_shape should be in the form of (T, H, W)."
+ self.cls_embed_on = has_cls
+ self.sep_pos_embed = sep_pos_embed
+ self._patch_embed_shape = tuple(patch_embed_shape)
+ self.num_spatial_patch = patch_embed_shape[1] * patch_embed_shape[2]
+ self.num_temporal_patch = patch_embed_shape[0]
+
+ if self.cls_embed_on:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ num_patches = self.num_spatial_patch * self.num_temporal_patch + 1
+ else:
+ self.cls_token = torch.tensor(0)
+ num_patches = self.num_spatial_patch * self.num_temporal_patch
+
+ if self.sep_pos_embed:
+ self.pos_embed_spatial = nn.Parameter(
+ torch.zeros(1, self.num_spatial_patch, embed_dim)
+ )
+ self.pos_embed_temporal = nn.Parameter(
+ torch.zeros(1, self.num_temporal_patch, embed_dim)
+ )
+ if self.cls_embed_on:
+ self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ else:
+ self.pos_embed_class = torch.tensor([]) # for torchscriptability
+ self.pos_embed = torch.tensor([])
+
+ else:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ # Placeholders for torchscriptability, won't be used
+ self.pos_embed_spatial = torch.tensor([])
+ self.pos_embed_temporal = torch.tensor([])
+ self.pos_embed_class = torch.tensor([])
+
+ @torch.jit.export
+ def patch_embed_shape(self) -> Tuple[int, int, int]:
+ return self._patch_embed_shape
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): Input tensor.
+ """
+ B, N, C = x.shape
+ if self.cls_embed_on:
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if self.sep_pos_embed:
+ pos_embed = self.pos_embed_spatial.repeat(
+ 1, self.num_temporal_patch, 1
+ ) + torch.repeat_interleave(
+ self.pos_embed_temporal,
+ self.num_spatial_patch,
+ dim=1,
+ )
+ if self.cls_embed_on:
+ pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1)
+ x = x + pos_embed
+ else:
+ x = x + self.pos_embed
+
+ return x
+
+
+def get_3d_sincos_pos_embed(
+ embed_dim: int, grid_size: int, t_size: int, cls_token: bool = False
+) -> torch.Tensor:
+ """
+ Get 3D sine-cosine positional embedding.
+ Args:
+ grid_size: int of the grid height and width
+ t_size: int of the temporal size
+ cls_token: bool, whether to contain CLS token
+ Returns:
+ (torch.Tensor): [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ assert embed_dim % 4 == 0
+ embed_dim_spatial = embed_dim // 4 * 3
+ embed_dim_temporal = embed_dim // 4
+
+ # spatial
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h)
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+ # temporal
+ grid_t = np.arange(t_size, dtype=np.float32)
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(pos_embed_temporal, grid_size**2, axis=1)
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)
+
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
+ pos_embed = pos_embed.reshape([-1, embed_dim])
+
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim: int, grid_size: int, cls_token: bool = False
+) -> torch.Tensor:
+ """
+ Get 2D sine-cosine positional embedding.
+ Args:
+ grid_size: int of the grid height and width
+ cls_token: bool, whether to contain CLS token
+ Returns:
+ (torch.Tensor): [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h)
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> torch.Tensor:
+ """
+ Get 2D sine-cosine positional embedding from grid.
+ Args:
+ embed_dim: embedding dimension.
+ grid: positions
+ Returns:
+ (torch.Tensor): [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+
+ """
+ assert embed_dim % 2 == 0
+
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+
+ emb = np.concatenate([emb_h, emb_w], axis=1)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> torch.Tensor:
+ """
+ Get 1D sine-cosine positional embedding.
+ Args:
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ Returns:
+ (torch.Tensor): tensor of shape (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega
+
+ pos = pos.reshape(-1)
+ out = np.einsum("m,d->md", pos, omega)
+
+ emb_sin = np.sin(out)
+ emb_cos = np.cos(out)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1)
+ return emb
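+
+
+# Minimal usage sketch (illustrative, assumed sizes): a fixed sinusoidal sequence
+# encoding and a 3D sin-cos table with a leading cls slot.
+if __name__ == "__main__":
+    pe = PositionalEncoding(embed_dim=64, seq_len=128)
+    x = torch.randn(2, 100, 64)
+    print(pe(x).shape)  # torch.Size([2, 100, 64])
+
+    table = get_3d_sincos_pos_embed(embed_dim=64, grid_size=7, t_size=4, cls_token=True)
+    print(table.shape)  # (197, 64) = (1 + 4 * 7 * 7, 64), returned as a numpy array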
diff --git a/code/pytorchvideo/pytorchvideo/layers/positional_encoding_torchscript.py b/code/pytorchvideo/pytorchvideo/layers/positional_encoding_torchscript.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f3d328da4feef45a84b44af25933265fd5f030f
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/positional_encoding_torchscript.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class ScriptableSpatioTemporalClsPositionalEncoding(nn.Module):
+ """
+ Add a cls token and apply a spatiotemporal encoding to a tensor.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ patch_embed_shape: Tuple[int, int, int],
+ sep_pos_embed: bool = False,
+ has_cls: bool = True,
+ ) -> None:
+ """
+ Args:
+ embed_dim (int): Embedding dimension for input sequence.
+ patch_embed_shape (Tuple): The number of patches in each dimension
+ (T, H, W) after patch embedding.
+ sep_pos_embed (bool): If set to true, one positional encoding is used for
+ spatial patches and another positional encoding is used for temporal
+ sequence. Otherwise, only one positional encoding is used for all the
+ patches.
+ has_cls (bool): If set to true, a cls token is added in the beginning of each
+ input sequence.
+ """
+ super().__init__()
+ assert (
+ len(patch_embed_shape) == 3
+ ), "Patch_embed_shape should be in the form of (T, H, W)."
+ assert not has_cls
+ self.sep_pos_embed = sep_pos_embed
+ self._patch_embed_shape = patch_embed_shape
+ self.num_spatial_patch = patch_embed_shape[1] * patch_embed_shape[2]
+ self.num_temporal_patch = patch_embed_shape[0]
+
+ self.pos_embed_spatial = nn.Parameter(
+ torch.zeros(1, self.num_spatial_patch, embed_dim)
+ )
+ self.pos_embed_temporal = nn.Parameter(
+ torch.zeros(1, self.num_temporal_patch, embed_dim)
+ )
+
+ @property
+ def patch_embed_shape(self):
+ return self._patch_embed_shape
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): Input tensor.
+ """
+ B, N, C = x.shape
+
+ assert self.sep_pos_embed
+ pos_embed = self.pos_embed_spatial.repeat(
+ 1, self.num_temporal_patch, 1
+ ) + torch.repeat_interleave(
+ self.pos_embed_temporal,
+ self.num_spatial_patch,
+ dim=1,
+ )
+ x = x + pos_embed
+ return x
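+
+
+# Minimal usage sketch (illustrative, assumed sizes): the scriptable variant requires
+# sep_pos_embed=True and no cls token.
+if __name__ == "__main__":
+    pos = ScriptableSpatioTemporalClsPositionalEncoding(
+        embed_dim=96, patch_embed_shape=(4, 7, 7), sep_pos_embed=True, has_cls=False
+    )
+    x = torch.randn(2, 4 * 7 * 7, 96)
+    print(pos(x).shape)  # torch.Size([2, 196, 96])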
diff --git a/code/pytorchvideo/pytorchvideo/layers/squeeze_excitation.py b/code/pytorchvideo/pytorchvideo/layers/squeeze_excitation.py
new file mode 100644
index 0000000000000000000000000000000000000000..47858e0a48cbbbd36a4645aa7921ed3c06b327f5
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/squeeze_excitation.py
@@ -0,0 +1,182 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+from pytorchvideo.models.resnet import ResBlock
+
+
+class SqueezeAndExcitationLayer2D(nn.Module):
+ """2D Squeeze and excitation layer, as per https://arxiv.org/pdf/1709.01507.pdf"""
+
+ def __init__(
+ self,
+ in_planes: int,
+ reduction_ratio: Optional[int] = 16,
+ reduced_planes: Optional[int] = None,
+ ):
+
+ """
+ Args:
+ in_planes (int): input channel dimension.
+ reduction_ratio (int): factor by which in_planes should be reduced to
+ get the output channel dimension.
+ reduced_planes (int): Output channel dimension. Only one of reduction_ratio
+ or reduced_planes should be defined.
+ """
+ super().__init__()
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ # Exactly one of reduction_ratio or reduced_planes must be defined
+ assert bool(reduction_ratio) != bool(
+ reduced_planes
+ ), "Only one of reduction_ratio or reduced_planes should be defined for SE layer"
+
+ reduced_planes = (
+ in_planes // reduction_ratio if reduced_planes is None else reduced_planes
+ )
+ self.excitation = nn.Sequential(
+ nn.Conv2d(in_planes, reduced_planes, kernel_size=1, stride=1, bias=True),
+ nn.ReLU(),
+ nn.Conv2d(reduced_planes, in_planes, kernel_size=1, stride=1, bias=True),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (tensor): 2D feature map of shape (N, C, H, W).
+ """
+ x_squeezed = self.avgpool(x)
+ x_excited = self.excitation(x_squeezed)
+ x_scaled = x * x_excited
+ return x_scaled
+
+
+def create_audio_2d_squeeze_excitation_block(
+ dim_in: int,
+ dim_out: int,
+ use_se=False,
+ se_reduction_ratio=16,
+ branch_fusion: Callable = lambda x, y: x + y,
+ # Conv configs.
+ conv_a_kernel_size: int = 3,
+ conv_a_stride: int = 1,
+ conv_a_padding: int = 1,
+ conv_b_kernel_size: int = 3,
+ conv_b_stride: int = 1,
+ conv_b_padding: int = 1,
+ # Norm configs.
+ norm: Callable = nn.BatchNorm2d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+
+ """
+ 2-D Residual block with squeeze excitation (SE2D). Performs a summation between an
+ identity shortcut in branch1 and a main block in branch2. When the input and
+ output dimensions are different, a convolution followed by a normalization
+ will be performed.
+
+ ::
+
+ Input
+ |-------+
+ ↓ |
+ conv2d |
+ ↓ |
+ Norm |
+ ↓ |
+ activation |
+ ↓ |
+ conv2d |
+ ↓ |
+ Norm |
+ ↓ |
+ SE2D |
+ ↓ |
+ Summation ←-+
+ ↓
+ Activation
+
+ Normalization examples include: BatchNorm3d and None (no normalization).
+ Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
+ Transform examples include: BottleneckBlock.
+
+ Args:
+ dim_in (int): input channel size to the bottleneck block.
+ dim_out (int): output channel size of the bottleneck.
+ use_se (bool): if true, use squeeze excitation layer in the bottleneck.
+ se_reduction_ratio (int): factor by which input channels should be reduced to
+ get the output channel dimension in SE layer.
+ branch_fusion (callable): a callable that constructs summation layer.
+ Examples include: lambda x, y: x + y, OctaveSum.
+
+ conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ conv_a_stride (tuple): convolutional stride size(s) for conv_a.
+ conv_a_padding (tuple): convolutional padding(s) for conv_a.
+ conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_b_stride (tuple): convolutional stride size(s) for conv_b.
+ conv_b_padding (tuple): convolutional padding(s) for conv_b.
+
+ norm (callable): a callable that constructs normalization layer. Examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer in
+ bottleneck and block. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid,
+ and None (not performing activation).
+
+ Returns:
+ (nn.Module): 2-D residual block with optional squeeze excitation.
+ """
+
+ branch2 = [
+ nn.Conv2d(
+ dim_in,
+ dim_out,
+ kernel_size=conv_a_kernel_size,
+ stride=conv_a_stride,
+ padding=conv_a_padding,
+ bias=False,
+ ),
+ norm(dim_out, norm_eps, norm_momentum),
+ activation() if activation else nn.Identity(),
+ nn.Conv2d(
+ dim_out,
+ dim_out,
+ kernel_size=conv_b_kernel_size,
+ stride=conv_b_stride,
+ padding=conv_b_padding,
+ bias=False,
+ ),
+ norm(dim_out, norm_eps, norm_momentum),
+ ]
+ if use_se:
+ branch2.append(
+ SqueezeAndExcitationLayer2D(dim_out, reduction_ratio=se_reduction_ratio)
+ )
+ branch2 = nn.Sequential(*branch2)
+
+ branch1_conv, branch1_norm = None, None
+ if conv_a_stride * conv_b_stride != 1 or dim_in != dim_out:
+ branch1_conv = nn.Conv2d(
+ dim_in,
+ dim_out,
+ kernel_size=1,
+ stride=conv_a_stride * conv_b_stride,
+ bias=False,
+ )
+ branch1_norm = norm(dim_out, norm_eps, norm_momentum)
+
+ return ResBlock(
+ branch1_conv=branch1_conv,
+ branch1_norm=branch1_norm,
+ branch2=branch2,
+ activation=activation() if activation else None,
+ branch_fusion=branch_fusion,
+ )
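+
+
+# Minimal usage sketch (illustrative, assumed shapes): channel-wise rescaling of a
+# 2D feature map by the squeeze-and-excitation layer.
+if __name__ == "__main__":
+    se = SqueezeAndExcitationLayer2D(in_planes=32, reduction_ratio=8)
+    x = torch.randn(2, 32, 14, 14)  # (N, C, H, W)
+    print(se(x).shape)  # torch.Size([2, 32, 14, 14])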
diff --git a/code/pytorchvideo/pytorchvideo/layers/swish.py b/code/pytorchvideo/pytorchvideo/layers/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..21bdcece546aa7ce5f6c0fa666de0c7e4fb99409
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/swish.py
@@ -0,0 +1,34 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import torch
+import torch.nn as nn
+
+
+class Swish(nn.Module):
+ """
+ Wrapper for the Swish activation function.
+ """
+
+ def forward(self, x):
+ return SwishFunction.apply(x)
+
+
+class SwishFunction(torch.autograd.Function):
+ """
+ Implementation of the Swish activation function: x * sigmoid(x).
+
+ Searching for activation functions. Ramachandran, Prajit and Zoph, Barret
+ and Le, Quoc V. 2017
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ result = x * torch.sigmoid(x)
+ ctx.save_for_backward(x)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ sigmoid_x = torch.sigmoid(x)
+ return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
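+
+
+# Minimal usage sketch (illustrative): Swish matches x * sigmoid(x) and backpropagates
+# through its custom autograd Function.
+if __name__ == "__main__":
+    x = torch.randn(5, requires_grad=True)
+    y = Swish()(x)
+    assert torch.allclose(y, x * torch.sigmoid(x))
+    y.sum().backward()
+    print(x.grad)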
diff --git a/code/pytorchvideo/pytorchvideo/layers/utils.py b/code/pytorchvideo/pytorchvideo/layers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..15593d61e44286033a999c6ff2729038df406516
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/layers/utils.py
@@ -0,0 +1,49 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import math
+from typing import Dict
+
+
+def set_attributes(self, params: Dict[str, object] = None) -> None:
+ """
+ A utility function used in classes to set attributes from the input dict of parameters
+ (typically ``locals()``).
+ Args:
+ params (dict): dictionary of parameter names and values.
+ """
+ if params:
+ for k, v in params.items():
+ if k != "self":
+ setattr(self, k, v)
+
+
+def round_width(width, multiplier, min_width=8, divisor=8, ceil=False):
+ """
+ Round width of filters based on width multiplier
+ Args:
+ width (int): the channel dimensions of the input.
+ multiplier (float): the multiplication factor.
+ min_width (int): the minimum width after multiplication.
+ divisor (int): the new width should be divisible by divisor.
+ ceil (bool): If True, use ceiling as the rounding method.
+ """
+ if not multiplier:
+ return width
+
+ width *= multiplier
+ min_width = min_width or divisor
+ if ceil:
+ width_out = max(min_width, int(math.ceil(width / divisor)) * divisor)
+ else:
+ width_out = max(min_width, int(width + divisor / 2) // divisor * divisor)
+ if width_out < 0.9 * width:
+ width_out += divisor
+ return int(width_out)
+
+
+def round_repeats(repeats, multiplier):
+ """
+ Round number of layers based on depth multiplier.
+ """
+ if not multiplier:
+ return repeats
+ return int(math.ceil(multiplier * repeats))
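+
+
+# Minimal usage sketch (illustrative): widths are scaled by the multiplier and then
+# snapped to a multiple of `divisor`; repeats are scaled and rounded up.
+if __name__ == "__main__":
+    print(round_width(64, multiplier=1.5))   # 96
+    print(round_width(64, multiplier=1.4))   # 88
+    print(round_repeats(3, multiplier=2.3))  # 7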
diff --git a/code/pytorchvideo/pytorchvideo/losses/__init__.py b/code/pytorchvideo/pytorchvideo/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/code/pytorchvideo/pytorchvideo/losses/soft_target_cross_entropy.py b/code/pytorchvideo/pytorchvideo/losses/soft_target_cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..d51259a45d67592467420f88bca322753d0ebea1
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/losses/soft_target_cross_entropy.py
@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorchvideo.layers.utils import set_attributes
+from pytorchvideo.transforms.functional import convert_to_one_hot
+
+
+class SoftTargetCrossEntropyLoss(nn.Module):
+ """
+ Adapted from Classy Vision: ./classy_vision/losses/soft_target_cross_entropy_loss.py.
+ This allows the targets for the cross entropy loss to be multi-label.
+ """
+
+ def __init__(
+ self,
+ ignore_index: int = -100,
+ reduction: str = "mean",
+ normalize_targets: bool = True,
+ ) -> None:
+ """
+ Args:
+ ignore_index (int): sample should be ignored for loss if the class is this value.
+ reduction (str): specifies reduction to apply to the output.
+ normalize_targets (bool): whether the targets should be normalized to a sum of 1
+ based on the total count of positive targets for a given sample.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert isinstance(self.normalize_targets, bool)
+ if self.reduction not in ["mean", "none"]:
+ raise NotImplementedError(
+ 'reduction type "{}" not implemented'.format(self.reduction)
+ )
+ self.eps = torch.finfo(torch.float32).eps
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ input (torch.Tensor): the shape of the tensor is N x C, where N is the number of
+ samples and C is the number of classes. The tensor is raw input without
+ softmax/sigmoid.
+ target (torch.Tensor): the shape of the tensor is N x C or N. If the shape is N, we
+ will convert the target to one hot vectors.
+ """
+ # Check if targets are inputted as class integers
+ if target.ndim == 1:
+ assert (
+ input.shape[0] == target.shape[0]
+ ), "SoftTargetCrossEntropyLoss requires input and target to have same batch size!"
+ target = convert_to_one_hot(target.view(-1, 1), input.shape[1])
+
+ assert input.shape == target.shape, (
+ "SoftTargetCrossEntropyLoss requires input and target to be same "
+ f"shape: {input.shape} != {target.shape}"
+ )
+
+ # Samples where the targets are ignore_index do not contribute to the loss
+ N, C = target.shape
+ valid_mask = torch.ones((N, 1), dtype=torch.float).to(input.device)
+ if 0 <= self.ignore_index <= C - 1:
+ drop_idx = target[:, self.ignore_index] > 0
+ valid_mask[drop_idx] = 0
+
+ valid_targets = target.float() * valid_mask
+ if self.normalize_targets:
+ valid_targets /= self.eps + valid_targets.sum(dim=1, keepdim=True)
+ per_sample_per_target_loss = -valid_targets * F.log_softmax(input, -1)
+
+ per_sample_loss = torch.sum(per_sample_per_target_loss, -1)
+ # Perform reduction
+ if self.reduction == "mean":
+ # Normalize based on the number of samples with > 0 non-ignored targets
+ loss = per_sample_loss.sum() / torch.sum(
+ (torch.sum(valid_mask, -1) > 0)
+ ).clamp(min=1)
+ elif self.reduction == "none":
+ loss = per_sample_loss
+
+ return loss
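A usage sketch for the loss above: integer class targets are expanded to one-hot vectors internally, while multi-label (soft) targets are normalized per sample before the cross entropy is computed.

```python
import torch
from pytorchvideo.losses.soft_target_cross_entropy import SoftTargetCrossEntropyLoss

loss_fn = SoftTargetCrossEntropyLoss()
logits = torch.randn(2, 5)

hard_targets = torch.tensor([1, 3])  # converted to one-hot internally
soft_targets = torch.tensor([[0.0, 1.0, 0.0, 1.0, 0.0],
                             [0.2, 0.0, 0.8, 0.0, 0.0]])  # normalized to sum to 1

print(loss_fn(logits, hard_targets))
print(loss_fn(logits, soft_targets))
```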
diff --git a/code/pytorchvideo/pytorchvideo/models/__init__.py b/code/pytorchvideo/pytorchvideo/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a8c4a4a44685d6a79b81f781be09bfb40fd32e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .csn import create_csn
+from .head import create_res_basic_head, ResNetBasicHead
+from .masked_multistream import (
+ LearnMaskedDefault,
+ LSTM,
+ MaskedMultiPathWay,
+ MaskedSequential,
+ MaskedTemporalPooling,
+ TransposeMultiheadAttention,
+ TransposeTransformerEncoder,
+)
+from .net import MultiPathWayWithFuse, Net
+from .resnet import BottleneckBlock, create_bottleneck_block, create_resnet
+from .slowfast import create_slowfast
+from .stem import create_conv_patch_embed, create_res_basic_stem, ResNetBasicStem
+from .vision_transformers import create_multiscale_vision_transformers
+from .weight_init import init_net_weights
diff --git a/code/pytorchvideo/pytorchvideo/models/accelerator/__init__.py b/code/pytorchvideo/pytorchvideo/models/accelerator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/accelerator/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/__init__.py b/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/efficient_x3d.py b/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/efficient_x3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e3525856babc7a4d650df05a0fe49f60e58590
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/efficient_x3d.py
@@ -0,0 +1,206 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from collections import OrderedDict
+
+import torch.nn as nn
+from pytorchvideo.layers.accelerator.mobile_cpu.activation_functions import (
+ supported_act_functions,
+)
+from pytorchvideo.layers.accelerator.mobile_cpu.convolutions import (
+ Conv3d5x1x1BnAct,
+ Conv3dPwBnAct,
+ Conv3dTemporalKernel1BnAct,
+)
+from pytorchvideo.layers.accelerator.mobile_cpu.fully_connected import FullyConnected
+from pytorchvideo.layers.accelerator.mobile_cpu.pool import AdaptiveAvgPool3dOutSize1
+
+from .residual_blocks import X3dBottleneckBlock
+
+
+class EfficientX3d(nn.Module):
+ """
+ This class implements an X3D network for classification with efficient blocks.
+ Args:
+ num_classes (int): Number of classes in classification.
+ dropout (float): Dropout rate used for training the network.
+ expansion (str): Expansion for X3D. Possible options: 'XS', 'S', 'M', 'L'.
+ head_act (str): The activation function to be applied in the head; should be a key
+ in dict supported_act_functions (see activation_functions.py for more info
+ about supported activations).
+ enable_head (bool): Whether the model includes a classification head.
+ """
+
+ def __init__(
+ self,
+ num_classes: int = 400,
+ dropout: float = 0.5,
+ expansion: str = "XS",
+ head_act: str = "identity",
+ enable_head: bool = True,
+ ):
+ super().__init__()
+ assert expansion in (
+ "XS",
+ "S",
+ "M",
+ "L",
+ ), f"Expansion {expansion} not supported."
+ # s1 - stem
+ s1 = OrderedDict()
+ s1["pathway0_stem_conv_xy"] = Conv3dTemporalKernel1BnAct(
+ 3,
+ 24,
+ bias=False,
+ groups=1,
+ spatial_kernel=3,
+ spatial_stride=2,
+ spatial_padding=1,
+ activation="identity",
+ use_bn=False,
+ )
+ s1["pathway0_stem_conv"] = Conv3d5x1x1BnAct(
+ 24,
+ 24,
+ bias=False,
+ groups=24,
+ use_bn=True,
+ )
+ self.s1 = nn.Sequential(s1)
+ # s2 - res2
+ s2 = OrderedDict()
+ depth_s2 = 5 if expansion == "L" else 3
+ for i_block in range(depth_s2):
+ cur_block = X3dBottleneckBlock(
+ in_channels=24,
+ mid_channels=54,
+ out_channels=24,
+ use_residual=True,
+ spatial_stride=(2 if i_block == 0 else 1),
+ se_ratio=(0.0625 if (i_block % 2) == 0 else 0),
+ act_functions=("relu", "swish", "relu"),
+ use_bn=(True, True, True),
+ )
+ s2[f"pathway0_res{i_block}"] = cur_block
+ self.s2 = nn.Sequential(s2)
+ # s3 - res3
+ s3 = OrderedDict()
+ depth_s3 = 10 if expansion == "L" else 5
+ for i_block in range(depth_s3):
+ cur_block = X3dBottleneckBlock(
+ in_channels=(24 if i_block == 0 else 48),
+ mid_channels=108,
+ out_channels=48,
+ use_residual=True,
+ spatial_stride=(2 if i_block == 0 else 1),
+ se_ratio=(0.0625 if (i_block % 2) == 0 else 0),
+ act_functions=("relu", "swish", "relu"),
+ use_bn=(True, True, True),
+ )
+ s3[f"pathway0_res{i_block}"] = cur_block
+ self.s3 = nn.Sequential(s3)
+ # s4 - res4
+ s4 = OrderedDict()
+ depth_s4 = 25 if expansion == "L" else 11
+ for i_block in range(depth_s4):
+ cur_block = X3dBottleneckBlock(
+ in_channels=(48 if i_block == 0 else 96),
+ mid_channels=216,
+ out_channels=96,
+ use_residual=True,
+ spatial_stride=(2 if i_block == 0 else 1),
+ se_ratio=(0.0625 if (i_block % 2) == 0 else 0),
+ act_functions=("relu", "swish", "relu"),
+ use_bn=(True, True, True),
+ )
+ s4[f"pathway0_res{i_block}"] = cur_block
+ self.s4 = nn.Sequential(s4)
+ # s5 - res5
+ s5 = OrderedDict()
+ depth_s5 = 15 if expansion == "L" else 7
+ for i_block in range(depth_s5):
+ cur_block = X3dBottleneckBlock(
+ in_channels=(96 if i_block == 0 else 192),
+ mid_channels=432,
+ out_channels=192,
+ use_residual=True,
+ spatial_stride=(2 if i_block == 0 else 1),
+ se_ratio=(0.0625 if (i_block % 2) == 0 else 0),
+ act_functions=("relu", "swish", "relu"),
+ use_bn=(True, True, True),
+ )
+ s5[f"pathway0_res{i_block}"] = cur_block
+ self.s5 = nn.Sequential(s5)
+ self.enable_head = enable_head
+ if enable_head:
+ # head
+ head = OrderedDict()
+ head["conv_5"] = Conv3dPwBnAct(
+ in_channels=192,
+ out_channels=432,
+ bias=False,
+ use_bn=True,
+ )
+ head["avg_pool"] = AdaptiveAvgPool3dOutSize1()
+ head["lin_5"] = Conv3dPwBnAct(
+ in_channels=432,
+ out_channels=2048,
+ bias=False,
+ use_bn=False,
+ )
+ self.head = nn.Sequential(head)
+ if dropout > 0:
+ self.dropout = nn.Dropout(dropout)
+ self.projection = FullyConnected(2048, num_classes, bias=True)
+ assert head_act in supported_act_functions, f"{head_act} is not supported."
+ self.act = supported_act_functions[head_act]()
+
+ def forward(self, x):
+ x = self.s1(x)
+ x = self.s2(x)
+ x = self.s3(x)
+ x = self.s4(x)
+ x = self.s5(x)
+ if self.enable_head:
+ x = self.head(x)
+ # (N, C, T, H, W) -> (N, T, H, W, C).
+ x = x.permute((0, 2, 3, 4, 1))
+ if hasattr(self, "dropout"):
+ x = self.dropout(x)
+ x = self.projection(x)
+ # Performs fully convolutional inference.
+ if not self.training:
+ x = self.act(x)
+ x = x.mean([1, 2, 3])
+ x = x.view(x.shape[0], -1)
+
+ return x
+
+
+def create_x3d(
+ *,
+ # EfficientX3d model arguments.
+ num_classes: int = 400,
+ dropout: float = 0.5,
+ expansion: str = "XS",
+ head_act: str = "identity",
+ enable_head: bool = True,
+):
+ """
+ This function builds a X3D network with efficient blocks.
+ Args:
+ num_classes (int): Number of classes in classification.
+ dropout (float): Dropout rate used for training the network.
+ expansion (str): Expansion for X3D. Possible options: 'XS', 'S', 'M', 'L'.
+ head_act (str): The activation function to be applied in the head; should be a key
+ in dict supported_act_functions (see activation_functions.py for more info
+ about supported activations). Currently ReLU ('relu'), Swish ('swish'),
+ Hardswish ('hswish'), and Identity ('identity') are supported.
+ enable_head (bool): Whether the model includes a classification head.
+ """
+ return EfficientX3d(
+ num_classes=num_classes,
+ dropout=dropout,
+ expansion=expansion,
+ head_act=head_act,
+ enable_head=enable_head,
+ )
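A smoke test for the builder above (a sketch; the 4-frame 160x160 clip size is an assumption matching the usual X3D-XS setting, not a requirement of the builder):

```python
import torch
from pytorchvideo.models.accelerator.mobile_cpu.efficient_x3d import create_x3d

model = create_x3d(num_classes=400, expansion="XS").eval()
clip = torch.randn(1, 3, 4, 160, 160)  # (N, C, T, H, W)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)  # torch.Size([1, 400])
```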
diff --git a/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/residual_blocks.py b/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/residual_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..557dad73321238a7e37c99ca0454921e1418bbf4
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/accelerator/mobile_cpu/residual_blocks.py
@@ -0,0 +1,239 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from collections import OrderedDict
+from typing import Optional, Tuple
+
+import torch.nn as nn
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+from pytorchvideo.layers.accelerator.mobile_cpu.activation_functions import (
+ supported_act_functions,
+)
+from pytorchvideo.layers.accelerator.mobile_cpu.attention import SqueezeExcitation
+from pytorchvideo.layers.accelerator.mobile_cpu.convolutions import (
+ Conv3d3x3x3DwBnAct,
+ Conv3dPwBnAct,
+ Conv3dTemporalKernel1BnAct,
+)
+from pytorchvideo.layers.utils import round_width
+
+
+class X3dBottleneckBlock(EfficientBlockBase):
+ """
+ Implements a X3D style residual block with optional squeeze-excite (SE)
+ using efficient blocks.
+
+ Input +----------------------+
+ | |
+ v |
+ conv3d[0] (1x1x1) |
+ | |
+ v |
+ batchNorm (optional) |
+ | |
+ v |
+ activation[0] |
+ | |
+ v |
+ conv3d[1] (3x3x3 dw) |
+ | |
+ v |
+ batchNorm (optional) |
+ | |
+ v |
+ Squeeze-Excite (optional) |
+ | |
+ v |
+ activation[1] |
+ | |
+ v |
+ conv3d[2] (1x1x1) |
+ | |
+ v |
+ batchNorm (optional) |
+ | |
+ v |
+ sum <-----------------------+
+ |
+ v
+ activation[2]
+
+ Args:
+ in_channels (int): input channels for 1x1x1 conv3d[0].
+ mid_channels (int): channels for 3x3x3 dw conv3d[1].
+ out_channels (int): output channels for 1x1x1 conv3d[2].
+ spatial_stride (int): spatial stride for 3x3x3 dw conv3d[1].
+ se_ratio (float): if > 0, apply SE to the 3x3x3 dw conv3d[1], with the SE
+ channel dimensionality being se_ratio times the 3x3x3 conv dim.
+ bias (tuple of bool): if bias[i] is true, use bias for conv3d[i].
+ act_functions (tuple of str): act_functions[i] is the activation function after
+ conv3d[i]. act_functions[i] should be a key in dict supported_act_functions
+ (see activation_functions.py for more info about supported activations).
+ Currently ReLU ('relu'), Swish ('swish'), Hardswish ('hswish'), Identity
+ ('identity') are supported.
+ use_bn (tuple of bool): if use_bn[i] is true, use batchnorm after conv3d[i].
+ norm_eps (float): epsilon for batchnorm.
+ norm_momentum (float): momentum for batchnorm.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ mid_channels: int,
+ out_channels: int,
+ use_residual: bool = True,
+ spatial_stride: int = 1,
+ se_ratio: float = 0.0625,
+ act_functions: Optional[Tuple[str]] = ("relu", "relu", "relu"),
+ bias: Optional[Tuple[bool]] = (False, False, False),
+ use_bn: Optional[Tuple[bool]] = (True, True, True),
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ ):
+ super().__init__()
+
+ # Residual projection
+ self._use_residual = use_residual
+ self._res_proj = None
+ if self._use_residual:
+ self._residual_add_func = nn.quantized.FloatFunctional()
+ if (spatial_stride != 1) or (in_channels != out_channels):
+ self._res_proj = Conv3dTemporalKernel1BnAct(
+ in_channels,
+ out_channels,
+ bias=False,
+ groups=1,
+ spatial_kernel=1,
+ spatial_stride=spatial_stride,
+ spatial_padding=0,
+ spatial_dilation=1,
+ activation="identity",
+ use_bn=True,
+ )
+
+ layers = OrderedDict()
+
+ # 1x1x1 pointwise layer conv[0]
+ assert (
+ act_functions[0] in supported_act_functions
+ ), f"{act_functions[0]} is not supported."
+ layers["conv_0"] = Conv3dPwBnAct(
+ in_channels,
+ mid_channels,
+ bias=bias[0],
+ # If activation function is relu, just include that in convBnRelu block.
+ activation=act_functions[0],
+ use_bn=use_bn[0],
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ )
+
+ # 3x3x3 dw layer conv[1]
+ self._spatial_stride = spatial_stride
+ self._mid_channels = mid_channels
+ assert (
+ act_functions[1] in supported_act_functions
+ ), f"{act_functions[1]} is not supported."
+ layers["conv_1"] = Conv3d3x3x3DwBnAct(
+ mid_channels,
+ spatial_stride=self._spatial_stride,
+ bias=bias[1],
+ activation="identity", # Will apply activation after SE.
+ use_bn=use_bn[1],
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ )
+ if se_ratio > 0:
+ layers["se"] = SqueezeExcitation(
+ num_channels=mid_channels,
+ num_channels_reduced=round_width(mid_channels, se_ratio),
+ is_3d=True,
+ )
+ # Add activation function if act_functions[1].
+ layers["act_func_1"] = supported_act_functions[act_functions[1]]()
+
+ # Second 1x1x1 pointwise layer conv[2]
+ self._out_channels = out_channels
+ assert (
+ act_functions[2] in supported_act_functions
+ ), f"{act_functions[2]} is not supported."
+ layers["conv_2"] = Conv3dPwBnAct(
+ mid_channels,
+ out_channels,
+ bias=bias[2],
+ # With residual, apply activation function externally after residual sum.
+ activation="identity",
+ use_bn=use_bn[2],
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ )
+ self.final_act = supported_act_functions[act_functions[2]]()
+
+ self.layers = nn.Sequential(layers)
+
+ self.convert_flag = False
+
+ def forward(self, x):
+ out = self.layers(x)
+ if self._use_residual:
+ if self._res_proj is not None:
+ x = self._res_proj(x)
+ out = self._residual_add_func.add(x, out)
+ out = self.final_act(out)
+ return out
+
+ def convert(
+ self,
+ input_blob_size,
+ *args,
+ convert_for_quantize=False,
+ native_conv3d_op_qnnpack=False,
+ **kwargs,
+ ):
+ assert (
+ self.convert_flag is False
+ ), "X3dBottleneckBlock: already converted, cannot be converted twice"
+
+ # Convert self.layers
+ batch_size = input_blob_size[0]
+ THW_size = tuple(input_blob_size[2:])
+ if self._res_proj is not None:
+ self._res_proj.convert(
+ input_blob_size,
+ convert_for_quantize=convert_for_quantize,
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ self.layers.conv_0.convert(
+ input_blob_size,
+ convert_for_quantize=convert_for_quantize,
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ # Update input_blob_size when necessary after each layer
+ input_blob_size = (batch_size, self._mid_channels) + THW_size
+
+ self.layers.conv_1.convert(
+ input_blob_size,
+ convert_for_quantize=convert_for_quantize,
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ THW_size = (
+ THW_size[0],
+ THW_size[1] // self._spatial_stride,
+ THW_size[2] // self._spatial_stride,
+ )
+ input_blob_size = (batch_size, self._mid_channels) + THW_size
+ if hasattr(self.layers, "se"):
+ self.layers.se.convert(
+ input_blob_size
+ ) # No need to change as SE is using linear
+ self.layers.act_func_1.convert(input_blob_size)
+ self.layers.conv_2.convert(
+ input_blob_size,
+ convert_for_quantize=convert_for_quantize,
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ input_blob_size = (batch_size, self._out_channels) + THW_size
+ self.final_act.convert(input_blob_size)
+ self.convert_flag = True
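A minimal sketch of the block above with squeeze-excite enabled and a spatial stride of 2, so the residual path goes through the 1x1x1 projection defined in the constructor:

```python
import torch
from pytorchvideo.models.accelerator.mobile_cpu.residual_blocks import X3dBottleneckBlock

block = X3dBottleneckBlock(
    in_channels=24,
    mid_channels=54,
    out_channels=48,
    spatial_stride=2,
    se_ratio=0.0625,
    act_functions=("relu", "swish", "relu"),
)
x = torch.randn(1, 24, 4, 32, 32)  # (N, C, T, H, W)
print(block(x).shape)              # torch.Size([1, 48, 4, 16, 16])
```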
diff --git a/code/pytorchvideo/pytorchvideo/models/audio_visual_slowfast.py b/code/pytorchvideo/pytorchvideo/models/audio_visual_slowfast.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0dd78119b6cac775bbcb27ae71f1f83bad91170
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/audio_visual_slowfast.py
@@ -0,0 +1,418 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Tuple, Union
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.utils import set_attributes
+from pytorchvideo.models.resnet import (
+ create_acoustic_bottleneck_block,
+ create_bottleneck_block,
+)
+from pytorchvideo.models.slowfast import create_slowfast
+from pytorchvideo.models.stem import (
+ create_acoustic_res_basic_stem,
+ create_res_basic_stem,
+)
+
+
+# Note we expect audio data as (Time, 1, Frequency)
+def create_audio_visual_slowfast(
+ *,
+ # SlowFast configs.
+ slowfast_channel_reduction_ratio: Tuple[int] = (8, 2),
+ slowfast_conv_channel_fusion_ratio: int = 2,
+ fusion_builder: Callable[
+ [int, int], nn.Module
+ ] = None, # Args: fusion_dim_in, stage_idx
+ # Input clip configs.
+ input_channels: Tuple[int] = (3, 3, 1),
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 400,
+ dropout_rate: float = 0.5,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_dim_outs: Tuple[int] = (64, 8, 32),
+ stem_conv_kernel_sizes: Tuple[Tuple[int]] = ((1, 7, 7), (5, 7, 7), (9, 1, 9)),
+ stem_conv_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2), (1, 1, 1)),
+ stem_pool: Tuple[Callable] = (nn.MaxPool3d, nn.MaxPool3d, None),
+ stem_pool_kernel_sizes: Tuple[Tuple[int]] = ((1, 3, 3), (1, 3, 3), (1, 3, 3)),
+ stem_pool_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2), (1, 1, 1)),
+ # Stage configs.
+ stage_conv_a_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 1, 1), (1, 1, 1), (3, 1, 1), (3, 1, 1)),
+ ((3, 1, 1), (3, 1, 1), (3, 1, 1), (3, 1, 1)),
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
+ ),
+ stage_conv_b_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
+ ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
+ ((3, 1, 3), (3, 1, 3), (3, 1, 3), (3, 1, 3)),
+ ),
+ stage_conv_b_num_groups: Tuple[Tuple[int]] = (
+ (1, 1, 1, 1),
+ (1, 1, 1, 1),
+ (1, 1, 1, 1),
+ ),
+ stage_conv_b_dilations: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
+ ),
+ stage_spatial_strides: Tuple[Tuple[int]] = (
+ (1, 2, 2, 2),
+ (1, 2, 2, 2),
+ (1, 2, 2, 2),
+ ),
+ stage_temporal_strides: Tuple[Tuple[int]] = (
+ (1, 1, 1, 1),
+ (1, 1, 1, 1),
+ (1, 2, 2, 2),
+ ),
+ bottleneck: Tuple[Tuple[Callable]] = (
+ (
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ (
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ (
+ create_acoustic_bottleneck_block,
+ create_acoustic_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ ),
+ # Head configs.
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_sizes: Tuple[Tuple[int]] = ((8, 7, 7), (32, 7, 7), (16, 1, 10)),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = None,
+ head_output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Model builder for AVSlowFast network.
+ Fanyi Xiao, Yong Jae Lee, Kristen Grauman, Jitendra Malik, Christoph Feichtenhofer.
+ "Audiovisual SlowFast Networks for Video Recognition."
+ https://arxiv.org/abs/2001.08740
+
+ Slow Input Fast Input Audio Input
+ ↓ ↓ ↓
+ Stem Stem Stem
+ ↓ ⭠ Fusion- ↓ ⭠ Fusion- ↓
+ Stage 1 Stage 1 Stage 1
+ ↓ ⭠ Fusion- ↓ ⭠ Fusion- ↓
+ . . .
+ ↓ ↓ ↓
+ Stage N Stage N Stage N
+ ↓ ⭠ Fusion- ↓ ⭠ Fusion- ↓
+ ↓
+ Head
+
+ Args:
+ SlowFast configs:
+ slowfast_channel_reduction_ratio (tuple): A pair of inverse channel reduction
+ ratios: the first entry is $\beta$F between the Slow and Fast pathways, and
+ the second entry is $\beta$A between the Slow and Audio pathways.
+ slowfast_conv_channel_fusion_ratio (int): Ratio of channel dimensions
+ between the Slow and Fast pathways.
+ fusion_builder (Callable[[int, int], nn.Module]): Builder function for generating
+ the fusion modules based on stage dimension and index
+
+ Input clip configs:
+ input_channels (tuple): number of channels for the input video clip.
+
+ Model configs:
+ model_depth (int): the depth of the resnet.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+ Normalization configs:
+ norm (callable): a callable that constructs normalization layer.
+
+ Activation configs:
+ activation (callable): a callable that constructs activation layer.
+
+ Stem configs:
+ stem_function (Tuple[Callable]): a callable that constructs stem layer.
+ Examples include create_res_basic_stem. Indexed by pathway
+ stem_dim_outs (tuple): output channel size to stem.
+ stem_conv_kernel_sizes (tuple): convolutional kernel size(s) of stem.
+ stem_conv_strides (tuple): convolutional stride size(s) of stem.
+ stem_pool (Tuple[Callable]): a callable that constructs resnet head pooling layer.
+ Indexed by pathway
+ stem_pool_kernel_sizes (tuple): pooling kernel size(s).
+ stem_pool_strides (tuple): pooling stride size(s).
+
+ Stage configs:
+ stage_conv_a_kernel_sizes (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_sizes (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
+ for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
+ stage_conv_b_dilations (tuple): dilation for 3D convolution for conv_b.
+ stage_spatial_strides (tuple): the spatial stride for each stage.
+ stage_temporal_strides (tuple): the temporal stride for each stage.
+ bottleneck (Tuple[Tuple[Callable]]): a callable that constructs bottleneck
+ block layer. Examples include: create_bottleneck_block.
+ Indexed by pathway and stage index
+
+ Head configs:
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_output_sizes (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+ Returns:
+ (nn.Module): SlowFast model.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_audio_visual_slowfast")
+
+ # Number of blocks for different stages given the model depth.
+ # 3 pathways, first is slow, second is fast, third is audio
+ if fusion_builder is None:
+ fusion_builder = AudioToSlowFastFusionBuilder(
+ slowfast_channel_reduction_ratio=slowfast_channel_reduction_ratio[0],
+ slowfast_audio_reduction_ratio=slowfast_channel_reduction_ratio[1],
+ conv_fusion_channel_ratio=slowfast_conv_channel_fusion_ratio,
+ conv_kernel_size=(7, 1, 1),
+ conv_kernel_size_a=(5, 1, 1),
+ conv_stride=(4, 1, 1),
+ conv_stride_a=((16, 1, 1), (16, 1, 1), (8, 1, 1), (4, 1, 1), (2, 1, 1)),
+ norm=norm,
+ activation=activation,
+ ).create_module
+
+ return create_slowfast(
+ slowfast_channel_reduction_ratio=slowfast_channel_reduction_ratio,
+ slowfast_conv_channel_fusion_ratio=slowfast_conv_channel_fusion_ratio,
+ fusion_builder=fusion_builder,
+ # Input clip configs.
+ input_channels=input_channels,
+ # Model configs.
+ model_depth=model_depth,
+ model_num_class=model_num_class,
+ dropout_rate=dropout_rate,
+ # Normalization configs.
+ norm=norm,
+ # Activation configs.
+ activation=activation,
+ # Stem configs.
+ stem_function=(
+ create_res_basic_stem,
+ create_res_basic_stem,
+ create_acoustic_res_basic_stem,
+ ),
+ stem_dim_outs=stem_dim_outs,
+ stem_conv_kernel_sizes=stem_conv_kernel_sizes,
+ stem_conv_strides=stem_conv_strides,
+ stem_pool=stem_pool,
+ stem_pool_kernel_sizes=stem_pool_kernel_sizes,
+ stem_pool_strides=stem_pool_strides,
+ # Stage configs.
+ stage_conv_a_kernel_sizes=stage_conv_a_kernel_sizes,
+ stage_conv_b_kernel_sizes=stage_conv_b_kernel_sizes,
+ stage_conv_b_num_groups=stage_conv_b_num_groups,
+ stage_conv_b_dilations=stage_conv_b_dilations,
+ stage_spatial_strides=stage_spatial_strides,
+ stage_temporal_strides=stage_temporal_strides,
+ bottleneck=bottleneck,
+ # Head configs.
+ head_pool=head_pool,
+ head_pool_kernel_sizes=head_pool_kernel_sizes,
+ head_output_size=head_output_size,
+ head_activation=head_activation,
+ head_output_with_global_average=head_output_with_global_average,
+ )
+
+
+class AudioToSlowFastFusionBuilder:
+ def __init__(
+ self,
+ slowfast_channel_reduction_ratio: int,
+ slowfast_audio_reduction_ratio: int,
+ conv_fusion_channel_ratio: float,
+ conv_kernel_size: Tuple[int],
+ conv_kernel_size_a: Tuple[int],
+ conv_stride: Union[Tuple[int], Tuple[Tuple[int]]],
+ conv_stride_a: Union[Tuple[int], Tuple[Tuple[int]]],
+ conv_fusion_channel_interm_dim: Union[int, float] = 0.25, # also, 64
+ conv_num_a: int = 2,
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ activation: Callable = nn.ReLU,
+ max_stage_idx: int = 3,
+ ) -> None:
+ """
+ Builds modules that, given the list of tensors from the Slow, Fast and Audio
+ pathways, fuse information from the Fast and Audio pathways into the Slow
+ pathway through convolutions followed by a concatenation, then return the
+ fused list of tensors from the Slow, Fast and Audio pathways in order.
+ Args:
+ slowfast_channel_reduction_ratio (int): Reduction ratio from the stage dimension.
+ Used to compute conv_dim_in = fusion_dim_in // slowfast_channel_reduction_ratio
+ slowfast_audio_reduction_ratio (int): Audio Reduction ratio from the stage dimension.
+ Used to compute conv_dim_in_a = fusion_dim_in // slowfast_audio_reduction_ratio
+ conv_fusion_channel_ratio (int): channel ratio for the convolution used to fuse
+ from Fast pathway to Slow pathway.
+ conv_kernel_size (int): kernel size of the convolution used to fuse from Fast
+ pathway to Slow pathway.
+ conv_kernel_size_a (int): kernel size of the convolution used to fuse from Audio
+ pathway to FastSlow pathway.
+ conv_stride (int): stride size of the convolution used to fuse from Fast pathway
+ to Slow pathway. Optionally indexed by stage.
+ conv_stride_a (int): stride size of the convolution used to fuse from Audio pathway
+ to FastSlow pathway. Optionally indexed by stage.
+ conv_fusion_channel_interm_dim (Union[int, float]): When conv_num_a > 1 this value
+ controls the dimensions of the intermediate conv
+ conv_num_a (int): Number of intermediate conv for audio channel
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+ max_stage_idx (int): Returns identity module if we exceed this
+ """
+ set_attributes(self, locals())
+
+ def create_module(self, fusion_dim_in: int, stage_idx: int) -> nn.Module:
+ """
+ Creates the module for the given stage
+ Args:
+ fusion_dim_in (int): input stage dimension
+ stage_idx (int): which stage this is
+ """
+ if stage_idx > self.max_stage_idx:
+ return nn.Identity()
+
+ conv_stride = (
+ self.conv_stride[stage_idx]
+ if isinstance(self.conv_stride[0], Tuple)
+ else self.conv_stride
+ )
+ conv_stride_a = (
+ self.conv_stride_a[stage_idx]
+ if isinstance(self.conv_stride_a[0], Tuple)
+ else self.conv_stride_a
+ )
+
+ conv_dim_in = fusion_dim_in // self.slowfast_channel_reduction_ratio
+ conv_dim_in_a = fusion_dim_in // self.slowfast_audio_reduction_ratio
+ fastslow_module = []
+ fastslow_module.append(
+ nn.Conv3d(
+ conv_dim_in,
+ int(conv_dim_in * self.conv_fusion_channel_ratio),
+ kernel_size=self.conv_kernel_size,
+ stride=conv_stride,
+ padding=[k_size // 2 for k_size in self.conv_kernel_size],
+ bias=False,
+ )
+ )
+ if self.norm is not None:
+ fastslow_module.append(
+ self.norm(
+ num_features=int(conv_dim_in * self.conv_fusion_channel_ratio),
+ eps=self.norm_eps,
+ momentum=self.norm_momentum,
+ )
+ )
+ if self.activation is not None:
+ fastslow_module.append(self.activation())
+
+ if isinstance(self.conv_fusion_channel_interm_dim, int):
+ afs_fusion_interm_dim = self.conv_fusion_channel_interm_dim
+ else:
+ afs_fusion_interm_dim = int(
+ conv_dim_in_a * self.conv_fusion_channel_interm_dim
+ )
+
+ block_audio_to_fastslow = []
+ cur_dim_in = conv_dim_in_a
+ for idx in range(self.conv_num_a):
+ if idx == self.conv_num_a - 1:
+ cur_stride = conv_stride_a
+ cur_dim_out = int(
+ conv_dim_in * self.conv_fusion_channel_ratio + fusion_dim_in
+ )
+ else:
+ cur_stride = (1, 1, 1)
+ cur_dim_out = afs_fusion_interm_dim
+
+ block_audio_to_fastslow.append(
+ nn.Conv3d(
+ cur_dim_in,
+ cur_dim_out,
+ kernel_size=self.conv_kernel_size_a,
+ stride=cur_stride,
+ padding=[k_size // 2 for k_size in self.conv_kernel_size_a],
+ bias=False,
+ )
+ )
+ if self.norm is not None:
+ block_audio_to_fastslow.append(
+ self.norm(
+ num_features=cur_dim_out,
+ eps=self.norm_eps,
+ momentum=self.norm_momentum,
+ )
+ )
+ if self.activation is not None:
+ block_audio_to_fastslow.append(self.activation())
+ cur_dim_in = cur_dim_out
+
+ return FuseAudioToFastSlow(
+ block_fast_to_slow=nn.Sequential(*fastslow_module),
+ block_audio_to_fastslow=nn.Sequential(*block_audio_to_fastslow),
+ )
+
+
+class FuseAudioToFastSlow(nn.Module):
+ """
+ Given a list of three tensors from the Slow, Fast and Audio pathways, fuses
+ information from the Fast and Audio pathways into the Slow pathway through
+ convolutions followed by a concatenation, then returns the fused list of
+ tensors from the Slow, Fast and Audio pathways in order.
+ """
+
+ def __init__(
+ self,
+ block_fast_to_slow: nn.Module,
+ block_audio_to_fastslow: nn.Module,
+ ) -> None:
+ """
+ Args:
+ block_fast_to_slow (nn.Module): convolution block that fuses Fast into Slow.
+ block_audio_to_fastslow (nn.Module): convolution block that fuses Audio into
+ the concatenated FastSlow features.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+
+ def forward(self, x):
+ x_s = x[0]
+ x_f = x[1]
+ x_a = x[2]
+ fuse = self.block_fast_to_slow(x_f)
+
+ # Reduce frequency dim
+ average_a = torch.mean(x_a, dim=-1, keepdim=True)
+ fuse_a = self.block_audio_to_fastslow(average_a)
+
+ x_s_fuse = torch.cat([x_s, fuse], 1)
+ return [fuse_a + x_s_fuse, x_f, x_a]
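A sketch of how the default fusion is wired (the values mirror the defaults passed by create_audio_visual_slowfast above; the stage-0 dimension of 64 is an illustrative assumption): the Fast-to-Slow path uses a single strided temporal convolution, while the Audio-to-FastSlow path stacks `conv_num_a` convolutions whose final stride is indexed per stage.

```python
import torch.nn as nn
from pytorchvideo.models.audio_visual_slowfast import AudioToSlowFastFusionBuilder

builder = AudioToSlowFastFusionBuilder(
    slowfast_channel_reduction_ratio=8,
    slowfast_audio_reduction_ratio=2,
    conv_fusion_channel_ratio=2,
    conv_kernel_size=(7, 1, 1),
    conv_kernel_size_a=(5, 1, 1),
    conv_stride=(4, 1, 1),
    conv_stride_a=((16, 1, 1), (16, 1, 1), (8, 1, 1), (4, 1, 1), (2, 1, 1)),
    norm=nn.BatchNorm3d,
    activation=nn.ReLU,
)
fusion_stage0 = builder.create_module(fusion_dim_in=64, stage_idx=0)
print(fusion_stage0)
```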
diff --git a/code/pytorchvideo/pytorchvideo/models/byol.py b/code/pytorchvideo/pytorchvideo/models/byol.py
new file mode 100644
index 0000000000000000000000000000000000000000..f419d948a5c63e9b2f3d1f66be89df88bfc67675
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/byol.py
@@ -0,0 +1,143 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import copy
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BYOL(nn.Module):
+ """
+ Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning
+ Details can be found in:
+ https://arxiv.org/pdf/2006.07733.pdf
+ """
+
+ def __init__(
+ self,
+ backbone: nn.Module,
+ projector: Optional[nn.Module] = None,
+ predictor: Optional[nn.Module] = None,
+ feature_dim: int = 2048,
+ predictor_inner: int = 4096,
+ mmt: float = 0.99,
+ norm: Callable = nn.SyncBatchNorm,
+ ) -> None:
+ """
+ Args:
+ backbone (nn.Module): backbone for byol, input shape depends on the forward
+ input size. Standard inputs include `B x C`, `B x C x H x W`, and
+ `B x C x T x H x W`.
+ projector (nn.Module): the standard projector is an MLP with 2 to 3 hidden layers,
+ with (synchronized) BatchNorm and ReLU activation.
+ predictor (nn.Module): predictor MLP of BYOL of similar structure as the
+ projector MLP.
+ feature_dim (int): output feature dimension.
+ predictor_inner (int): inner channel size for predictor.
+ mmt (float): momentum update ratio for the momentum backbone.
+ norm (callable): normalization to be used in projector, default is
+ synchronized batchnorm.
+ """
+ super().__init__()
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.BYOL.__init__")
+
+ self.mmt = mmt
+ self.feature_dim = feature_dim
+ if projector is not None:
+ backbone = nn.Sequential(
+ backbone,
+ projector,
+ )
+ self.backbone = backbone
+ self.backbone_mmt = copy.deepcopy(backbone)
+ for p in self.backbone_mmt.parameters():
+ p.requires_grad = False
+ if predictor is None:
+ self.predictor = nn.Sequential(
+ nn.Linear(feature_dim, predictor_inner, bias=False),
+ norm(predictor_inner),
+ nn.ReLU(inplace=True),
+ nn.Linear(predictor_inner, feature_dim, bias=True),
+ )
+ else:
+ self.predictor = predictor
+
+ def sim_loss(self, q, k):
+ """
+ Similarity loss for byol.
+ Args:
+ q and k (nn.tensor): inputs to calculate the similarity, both expected to
+ have shape `N x C`.
+ """
+ similarity = torch.einsum("nc,nc->n", [q, k])
+ loss = -similarity.mean()
+ return loss
+
+ def update_mmt(self, mmt: float):
+ """
+ Update the momentum. This function can be used to perform momentum annealing.
+ Args:
+ mmt (float): update the momentum.
+ """
+ self.mmt = mmt
+
+ def get_mmt(self) -> float:
+ """
+ Get the momentum. This function can be used to perform momentum annealing.
+ """
+ return self.mmt
+
+ @torch.no_grad()
+ def _momentum_update_backbone(self):
+ """
+ Momentum update on the backbone.
+ """
+ for param, param_mmt in zip(
+ self.backbone.parameters(), self.backbone_mmt.parameters()
+ ):
+ param_mmt.data = param_mmt.data * self.mmt + param.data * (1.0 - self.mmt)
+
+ @torch.no_grad()
+ def forward_backbone_mmt(self, x):
+ """
+ Forward momentum backbone.
+ Args:
+ x (tensor): input to be forwarded.
+ """
+ with torch.no_grad():
+ proj = self.backbone_mmt(x)
+ return F.normalize(proj, dim=1)
+
+ def forward_backbone(self, x):
+ """
+ Forward backbone.
+ Args:
+ x (tensor): input to be forwarded.
+ """
+ proj = self.backbone(x)
+ pred = self.predictor(proj)
+ return F.normalize(pred, dim=1)
+
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x1 (torch.tensor): a batch of images with one augmentation applied. The
+ tensor shape should be one the backbone can consume.
+ x2 (torch.tensor): the same batch of images with a different augmentation
+ applied. The tensor shape should be one the backbone can consume.
+ """
+ pred_1 = self.forward_backbone(x1)
+ pred_2 = self.forward_backbone(x2)
+
+ with torch.no_grad():
+ self._momentum_update_backbone()
+ proj_mmt_1 = self.forward_backbone_mmt(x1)
+ proj_mmt_2 = self.forward_backbone_mmt(x2)
+
+ loss = (
+ self.sim_loss(pred_1, proj_mmt_2) + self.sim_loss(pred_2, proj_mmt_1)
+ ) / 2
+ return loss
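A toy usage sketch of the module above; the tiny linear backbone and `nn.BatchNorm1d` (instead of the default `SyncBatchNorm`) are assumptions so the example runs outside a distributed job:

```python
import torch
import torch.nn as nn
from pytorchvideo.models.byol import BYOL

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128))
byol = BYOL(
    backbone=backbone,
    feature_dim=128,
    predictor_inner=256,
    norm=nn.BatchNorm1d,
)
x1 = torch.randn(4, 3, 32, 32)  # one augmented view of the batch
x2 = torch.randn(4, 3, 32, 32)  # a second augmented view of the same batch
loss = byol(x1, x2)
loss.backward()
print(loss.item())
```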
diff --git a/code/pytorchvideo/pytorchvideo/models/csn.py b/code/pytorchvideo/pytorchvideo/models/csn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a154c09ba935f12d5dd646b6dd8ba620e7e661e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/csn.py
@@ -0,0 +1,191 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.models.head import create_res_basic_head
+from pytorchvideo.models.resnet import create_bottleneck_block, create_res_stage, Net
+from pytorchvideo.models.stem import create_res_basic_stem
+
+
+def create_csn(
+ *,
+ # Input clip configs.
+ input_channel: int = 3,
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 400,
+ dropout_rate: float = 0,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_dim_out: int = 64,
+ stem_conv_kernel_size: Tuple[int] = (3, 7, 7),
+ stem_conv_stride: Tuple[int] = (1, 2, 2),
+ stem_pool: Callable = None,
+ stem_pool_kernel_size: Tuple[int] = (1, 3, 3),
+ stem_pool_stride: Tuple[int] = (1, 2, 2),
+ # Stage configs.
+ stage_conv_a_kernel_size: Tuple[int] = (1, 1, 1),
+ stage_conv_b_kernel_size: Tuple[int] = (3, 3, 3),
+ stage_conv_b_width_per_group: int = 1,
+ stage_spatial_stride: Tuple[int] = (1, 2, 2, 2),
+ stage_temporal_stride: Tuple[int] = (1, 2, 2, 2),
+ bottleneck: Callable = create_bottleneck_block,
+ bottleneck_ratio: int = 4,
+ # Head configs.
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_size: Tuple[int] = (1, 7, 7),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = None,
+ head_output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Build Channel-Separated Convolutional Networks (CSN):
+ Video classification with channel-separated convolutional networks.
+ Du Tran, Heng Wang, Lorenzo Torresani, Matt Feiszli. ICCV 2019.
+
+ CSN follows the ResNet style architecture including three parts: Stem,
+ Stages and Head. The three parts are assembled in the following order:
+
+ ::
+
+ Input
+ ↓
+ Stem
+ ↓
+ Stage 1
+ ↓
+ .
+ .
+ .
+ ↓
+ Stage N
+ ↓
+ Head
+
+ CSN uses depthwise convolution. To further reduce the computational cost, it uses
+ low resolution (112x112), short clips (4 frames), different striding and kernel
+ size, etc.
+
+ Args:
+
+ input_channel (int): number of channels for the input video clip.
+
+ model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+ norm (callable): a callable that constructs normalization layer.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_dim_out (int): output channel size to stem.
+ stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
+ stem_conv_stride (tuple): convolutional stride size(s) of stem.
+ stem_pool (callable): a callable that constructs resnet head pooling layer.
+ stem_pool_kernel_size (tuple): pooling kernel size(s).
+ stem_pool_stride (tuple): pooling stride size(s).
+
+ stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_width_per_group(int): the width of each group for conv_b. Set
+ it to 1 for depthwise convolution.
+ stage_spatial_stride (tuple): the spatial stride for each stage.
+ stage_temporal_stride (tuple): the temporal stride for each stage.
+ bottleneck (callable): a callable that constructs bottleneck block layer.
+ Examples include: create_bottleneck_block.
+ bottleneck_ratio (int): the ratio between inner and outer dimensions for
+ the bottleneck block.
+
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_pool_kernel_size (tuple): the pooling kernel size.
+ head_output_size (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+
+ Returns:
+ (nn.Module): the csn model.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_csn")
+
+ # Number of blocks for different stages given the model depth.
+ _MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
+
+ # Given a model depth, get the number of blocks for each stage.
+ assert (
+ model_depth in _MODEL_STAGE_DEPTH.keys()
+ ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
+ stage_depths = _MODEL_STAGE_DEPTH[model_depth]
+
+ blocks = []
+ # Create stem for CSN.
+ stem = create_res_basic_stem(
+ in_channels=input_channel,
+ out_channels=stem_dim_out,
+ conv_kernel_size=stem_conv_kernel_size,
+ conv_stride=stem_conv_stride,
+ conv_padding=[size // 2 for size in stem_conv_kernel_size],
+ pool=stem_pool,
+ pool_kernel_size=stem_pool_kernel_size,
+ pool_stride=stem_pool_stride,
+ pool_padding=[size // 2 for size in stem_pool_kernel_size],
+ norm=norm,
+ activation=activation,
+ )
+ blocks.append(stem)
+
+ stage_dim_in = stem_dim_out
+ stage_dim_out = stage_dim_in * 4
+
+ # Create each stage for CSN.
+ for idx in range(len(stage_depths)):
+ stage_dim_inner = stage_dim_out // bottleneck_ratio
+ depth = stage_depths[idx]
+
+ stage_conv_b_stride = (
+ stage_temporal_stride[idx],
+ stage_spatial_stride[idx],
+ stage_spatial_stride[idx],
+ )
+
+ stage = create_res_stage(
+ depth=depth,
+ dim_in=stage_dim_in,
+ dim_inner=stage_dim_inner,
+ dim_out=stage_dim_out,
+ bottleneck=bottleneck,
+ conv_a_kernel_size=stage_conv_a_kernel_size,
+ conv_a_stride=(1, 1, 1),
+ conv_a_padding=[size // 2 for size in stage_conv_a_kernel_size],
+ conv_b_kernel_size=stage_conv_b_kernel_size,
+ conv_b_stride=stage_conv_b_stride,
+ conv_b_padding=[size // 2 for size in stage_conv_b_kernel_size],
+ conv_b_num_groups=(stage_dim_inner // stage_conv_b_width_per_group),
+ conv_b_dilation=(1, 1, 1),
+ norm=norm,
+ activation=activation,
+ )
+
+ blocks.append(stage)
+ stage_dim_in = stage_dim_out
+ stage_dim_out = stage_dim_out * 2
+
+ # Create head for CSN.
+ head = create_res_basic_head(
+ in_features=stage_dim_in,
+ out_features=model_num_class,
+ pool=head_pool,
+ output_size=head_output_size,
+ pool_kernel_size=head_pool_kernel_size,
+ dropout_rate=dropout_rate,
+ activation=head_activation,
+ output_with_global_average=head_output_with_global_average,
+ )
+ blocks.append(head)
+ return Net(blocks=nn.ModuleList(blocks))
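A smoke test for the builder above (a sketch; the 4-frame 112x112 clip matches the low-resolution, short-clip setting described in the docstring and is an assumption, not a requirement):

```python
import torch
from pytorchvideo.models.csn import create_csn

model = create_csn(model_depth=50, model_num_class=400).eval()
clip = torch.randn(1, 3, 4, 112, 112)  # (N, C, T, H, W)
with torch.no_grad():
    preds = model(clip)
print(preds.shape)  # torch.Size([1, 400])
```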
diff --git a/code/pytorchvideo/pytorchvideo/models/head.py b/code/pytorchvideo/pytorchvideo/models/head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b7f7a741e2facc898b62fb47750344ec10afb07
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/head.py
@@ -0,0 +1,535 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.utils import set_attributes
+from torchvision.ops import RoIAlign
+
+
+class SequencePool(nn.Module):
+ """
+ Sequence pool produces a single embedding from a sequence of embeddings. Currently
+ it supports "mean" and "cls".
+
+ """
+
+ def __init__(self, mode: str) -> None:
+ """
+ Args:
+ mode (str): Options include "cls" and "mean". If set to "cls", it assumes
+ the first element in the input is the cls token and returns it. If set
+ to "mean", it returns the mean of the entire sequence.
+ """
+ super().__init__()
+ assert mode in ["cls", "mean"], "Unsupported mode for SequencePool."
+ self.mode = mode
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.mode == "cls":
+ x = x[:, 0]
+ elif self.mode == "mean":
+ x = x.mean(1)
+ else:
+ raise NotImplementedError
+ return x
+
+
+def create_res_basic_head(
+ *,
+ # Projection configs.
+ in_features: int,
+ out_features: int,
+ # Pooling configs.
+ pool: Callable = nn.AvgPool3d,
+ output_size: Tuple[int] = (1, 1, 1),
+ pool_kernel_size: Tuple[int] = (1, 7, 7),
+ pool_stride: Tuple[int] = (1, 1, 1),
+ pool_padding: Tuple[int] = (0, 0, 0),
+ # Dropout configs.
+ dropout_rate: float = 0.5,
+ # Activation configs.
+ activation: Callable = None,
+ # Output configs.
+ output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Creates ResNet basic head. This layer performs an optional pooling operation
+ followed by an optional dropout, a fully-connected projection, an activation layer
+ and a global spatiotemporal averaging.
+
+ ::
+
+
+ Pooling
+ ↓
+ Dropout
+ ↓
+ Projection
+ ↓
+ Activation
+ ↓
+ Averaging
+
+ Activation examples include: ReLU, Softmax, Sigmoid, and None.
+ Pool3d examples include: AvgPool3d, MaxPool3d, AdaptiveAvgPool3d, and None.
+
+ Args:
+
+ in_features: input channel size of the resnet head.
+ out_features: output channel size of the resnet head.
+
+ pool (callable): a callable that constructs resnet head pooling layer,
+ examples include: nn.AvgPool3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d, and
+ None (not applying pooling).
+ pool_kernel_size (tuple): pooling kernel size(s) when not using adaptive
+ pooling.
+ pool_stride (tuple): pooling stride size(s) when not using adaptive pooling.
+ pool_padding (tuple): pooling padding size(s) when not using adaptive
+ pooling.
+ output_size (tuple): spatial temporal output size when using adaptive
+ pooling.
+
+ activation (callable): a callable that constructs resnet head activation
+ layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not
+ applying activation).
+
+ dropout_rate (float): dropout rate.
+
+ output_with_global_average (bool): if True, perform global averaging on temporal
+ and spatial dimensions and reshape output to batch_size x out_features.
+ """
+
+ if activation is None:
+ activation_model = None
+ elif activation == nn.Softmax:
+ activation_model = activation(dim=1)
+ else:
+ activation_model = activation()
+
+ if pool is None:
+ pool_model = None
+ elif pool == nn.AdaptiveAvgPool3d:
+ pool_model = pool(output_size)
+ else:
+ pool_model = pool(
+ kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
+ )
+
+ if output_with_global_average:
+ output_pool = nn.AdaptiveAvgPool3d(1)
+ else:
+ output_pool = None
+
+ return ResNetBasicHead(
+ proj=nn.Linear(in_features, out_features),
+ activation=activation_model,
+ pool=pool_model,
+ dropout=nn.Dropout(dropout_rate) if dropout_rate > 0 else None,
+ output_pool=output_pool,
+ )
+
+
+def create_vit_basic_head(
+ *,
+ # Projection configs.
+ in_features: int,
+ out_features: int,
+ # Pooling configs.
+ seq_pool_type: str = "cls",
+ # Dropout configs.
+ dropout_rate: float = 0.5,
+ # Activation configs.
+ activation: Callable = None,
+) -> nn.Module:
+ """
+ Creates vision transformer basic head.
+
+ ::
+
+
+ Pooling
+ ↓
+ Dropout
+ ↓
+ Projection
+ ↓
+ Activation
+
+
+ Activation examples include: ReLU, Softmax, Sigmoid, and None.
+ Pool type examples include: cls, mean and none.
+
+ Args:
+
+ in_features: input channel size of the resnet head.
+ out_features: output channel size of the resnet head.
+
+ seq_pool_type (str): Pooling type. It supports "cls", "mean", and "none". If set to
+ "cls", it assumes the first element in the input is the cls token and
+ returns it. If set to "mean", it returns the mean of the entire sequence.
+
+ activation (callable): a callable that constructs vision transformer head
+ activation layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and
+ None (not applying activation).
+
+ dropout_rate (float): dropout rate.
+ """
+ assert seq_pool_type in ["cls", "mean", "none"]
+
+ if seq_pool_type in ["cls", "mean"]:
+ seq_pool_model = SequencePool(seq_pool_type)
+ elif seq_pool_type == "none":
+ seq_pool_model = None
+ else:
+ raise NotImplementedError
+
+ if activation is None:
+ activation_model = None
+ elif activation == nn.Softmax:
+ activation_model = activation(dim=1)
+ else:
+ activation_model = activation()
+
+ return VisionTransformerBasicHead(
+ sequence_pool=seq_pool_model,
+ dropout=nn.Dropout(dropout_rate) if dropout_rate > 0.0 else None,
+ proj=nn.Linear(in_features, out_features),
+ activation=activation_model,
+ )
+
+
+def create_res_roi_pooling_head(
+ *,
+ # Projection configs.
+ in_features: int,
+ out_features: int,
+ # RoI configs.
+ resolution: Tuple,
+ spatial_scale: float,
+ sampling_ratio: int = 0,
+ roi: Callable = RoIAlign,
+ # Pooling configs.
+ pool: Callable = nn.AvgPool3d,
+ output_size: Tuple[int] = (1, 1, 1),
+ pool_kernel_size: Tuple[int] = (1, 7, 7),
+ pool_stride: Tuple[int] = (1, 1, 1),
+ pool_padding: Tuple[int] = (0, 0, 0),
+ pool_spatial: Callable = nn.MaxPool2d,
+ # Dropout configs.
+ dropout_rate: float = 0.5,
+ # Activation configs.
+ activation: Callable = None,
+ # Output configs.
+ output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Creates ResNet RoI head. This layer performs an optional pooling operation
+ followed by an RoI projection, an optional 2D spatial pool, an optional dropout,
+ a fully-connected projection, an activation layer
+ and a global spatiotemporal averaging.
+
+ ::
+
+ Pool3d
+ ↓
+ RoI Align
+ ↓
+ Pool2d
+ ↓
+ Dropout
+ ↓
+ Projection
+ ↓
+ Activation
+ ↓
+ Averaging
+
+ Activation examples include: ReLU, Softmax, Sigmoid, and None.
+ Pool3d examples include: AvgPool3d, MaxPool3d, AdaptiveAvgPool3d, and None.
+ RoI examples include: detectron2.layers.ROIAlign, detectron2.layers.ROIAlignRotated,
+ torchvision.ops.RoIAlign, and None.
+ Pool2d examples include: MaxPool2d, AvgPool2d, and None.
+
+ Args:
+ Projection related configs:
+ in_features: input channel size of the resnet head.
+ out_features: output channel size of the resnet head.
+
+ RoI layer related configs:
+ resolution (tuple): h, w sizes of the RoI interpolation.
+ spatial_scale (float): scale the input boxes by this number
+ sampling_ratio (int): number of inputs samples to take for each output
+ sample interpolation. 0 to take samples densely.
+ roi (callable): a callable that constructs the roi interpolation layer,
+ examples include detectron2.layers.ROIAlign,
+ detectron2.layers.ROIAlignRotated, and None.
+
+ Pooling related configs:
+ pool (callable): a callable that constructs resnet head pooling layer,
+ examples include: nn.AvgPool3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d, and
+ None (not applying pooling).
+ pool_kernel_size (tuple): pooling kernel size(s) when not using adaptive
+ pooling.
+ pool_stride (tuple): pooling stride size(s) when not using adaptive pooling.
+ pool_padding (tuple): pooling padding size(s) when not using adaptive
+ pooling.
+ output_size (tuple): spatial temporal output size when using adaptive
+ pooling.
+ pool_spatial (callable): a callable that constructs the 2d pooling layer which
+ follows the RoI layer, examples include: nn.AvgPool2d, nn.MaxPool2d, and
+ None (not applying spatial pooling).
+
+ Activation related configs:
+ activation (callable): a callable that constructs resnet head activation
+ layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not
+ applying activation).
+
+ Dropout related configs:
+ dropout_rate (float): dropout rate.
+
+ Output related configs:
+ output_with_global_average (bool): if True, perform global averaging on temporal
+ and spatial dimensions and reshape output to batch_size x out_features.
+ """
+ if activation is None:
+ activation_model = None
+ elif activation == nn.Softmax:
+ activation_model = activation(dim=1)
+ else:
+ activation_model = activation()
+
+ if pool is None:
+ pool_model = None
+ elif pool == nn.AdaptiveAvgPool3d:
+ pool_model = pool(output_size)
+ else:
+ pool_model = pool(
+ kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
+ )
+
+ if output_with_global_average:
+ output_pool = nn.AdaptiveAvgPool3d(1)
+ else:
+ output_pool = None
+
+ return ResNetRoIHead(
+ proj=nn.Linear(in_features, out_features),
+ activation=activation_model,
+ pool=pool_model,
+ pool_spatial=pool_spatial(resolution, stride=1) if pool_spatial else None,
+ roi_layer=roi(
+ output_size=resolution,
+ spatial_scale=spatial_scale,
+ sampling_ratio=sampling_ratio,
+ ),
+ dropout=nn.Dropout(dropout_rate) if dropout_rate > 0 else None,
+ output_pool=output_pool,
+ )
+
+
+class ResNetBasicHead(nn.Module):
+ """
+ ResNet basic head. This layer performs an optional pooling operation followed by an
+ optional dropout, a fully-connected projection, an optional activation layer and a
+ global spatiotemporal averaging.
+
+ ::
+
+ Pool3d
+ ↓
+ Dropout
+ ↓
+ Projection
+ ↓
+ Activation
+ ↓
+ Averaging
+
+ The builder can be found in `create_res_basic_head`.
+ """
+
+ def __init__(
+ self,
+ pool: nn.Module = None,
+ dropout: nn.Module = None,
+ proj: nn.Module = None,
+ activation: nn.Module = None,
+ output_pool: nn.Module = None,
+ ) -> None:
+ """
+ Args:
+ pool (torch.nn.modules): pooling module.
+ dropout(torch.nn.modules): dropout module.
+ proj (torch.nn.modules): project module.
+ activation (torch.nn.modules): activation module.
+ output_pool (torch.nn.Module): pooling module for output.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.proj is not None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Performs pooling.
+ if self.pool is not None:
+ x = self.pool(x)
+ # Performs dropout.
+ if self.dropout is not None:
+ x = self.dropout(x)
+ # Performs projection.
+ if self.proj is not None:
+ x = x.permute((0, 2, 3, 4, 1))
+ x = self.proj(x)
+ x = x.permute((0, 4, 1, 2, 3))
+ # Performs activation.
+ if self.activation is not None:
+ x = self.activation(x)
+
+ if self.output_pool is not None:
+ # Performs global averaging.
+ x = self.output_pool(x)
+ x = x.view(x.shape[0], -1)
+ return x
+
+
+class ResNetRoIHead(nn.Module):
+ """
+ ResNet RoI head. This layer performs an optional pooling operation
+ followed by an RoI projection, an optional 2D spatial pool, an optional dropout,
+ a fully-connected projection, an activation layer
+ and a global spatiotemporal averaging.
+
+ ::
+
+ Pool3d
+ ↓
+ RoI Align
+ ↓
+ Pool2d
+ ↓
+ Dropout
+ ↓
+ Projection
+ ↓
+ Activation
+ ↓
+ Averaging
+
+ The builder can be found in `create_res_roi_pooling_head`.
+ """
+
+ def __init__(
+ self,
+ pool: nn.Module = None,
+ pool_spatial: nn.Module = None,
+ roi_layer: nn.Module = None,
+ dropout: nn.Module = None,
+ proj: nn.Module = None,
+ activation: nn.Module = None,
+ output_pool: nn.Module = None,
+ ) -> None:
+ """
+ Args:
+ pool (torch.nn.modules): pooling module.
+ pool_spatial (torch.nn.modules): pooling module.
+ roi_layer (torch.nn.modules): RoI layer (e.g. RoIAlign) module.
+ dropout(torch.nn.modules): dropout module.
+ proj (torch.nn.modules): project module.
+ activation (torch.nn.modules): activation module.
+ output_pool (torch.nn.Module): pooling module for output.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.proj is not None
+
+ def forward(self, x: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.tensor): input tensor
+ bboxes (torch.tensor): Associated bounding boxes.
+ The format is N*5 (Index, X_1,Y_1,X_2,Y_2) if using RoIAlign
+ and N*6 (Index, x_ctr, y_ctr, width, height, angle_degrees) if
+ using RoIAlignRotated.
+ """
+ # Performs 3d pooling.
+ if self.pool is not None:
+ x = self.pool(x)
+ # Performs roi layer using bboxes
+ if self.roi_layer is not None:
+ temporal_dim = x.shape[-3]
+ if temporal_dim != 1:
+ raise Exception(
+ "Temporal dimension should be 1. Consider modifying the pool layer."
+ )
+ x = torch.squeeze(x, -3)
+ x = self.roi_layer(x, bboxes)
+ # Performs spatial 2d pooling.
+ if self.pool_spatial is not None:
+ x = self.pool_spatial(x)
+ x = x.unsqueeze(-3)
+ # Performs dropout.
+ if self.dropout is not None:
+ x = self.dropout(x)
+ # Performs projection.
+ if self.proj is not None:
+ x = x.permute((0, 2, 3, 4, 1))
+ x = self.proj(x)
+ x = x.permute((0, 4, 1, 2, 3))
+ # Performs activation.
+ if self.activation is not None:
+ x = self.activation(x)
+
+ if self.output_pool is not None:
+ # Performs global averaging.
+ x = self.output_pool(x)
+ x = x.view(x.shape[0], -1)
+ return x
+
+
+class VisionTransformerBasicHead(nn.Module):
+ """
+ Vision transformer basic head.
+
+ ::
+
+ SequencePool
+ ↓
+ Dropout
+ ↓
+ Projection
+ ↓
+ Activation
+
+
+ The builder can be found in `create_vit_basic_head`.
+ """
+
+ def __init__(
+ self,
+ sequence_pool: nn.Module = None,
+ dropout: nn.Module = None,
+ proj: nn.Module = None,
+ activation: nn.Module = None,
+ ) -> None:
+ """
+ Args:
+ sequence_pool (torch.nn.modules): pooling module.
+ dropout(torch.nn.modules): dropout module.
+ proj (torch.nn.modules): project module.
+ activation (torch.nn.modules): activation module.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.proj is not None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Performs pooling.
+ if self.sequence_pool is not None:
+ x = self.sequence_pool(x)
+
+ # Performs dropout.
+ if self.dropout is not None:
+ x = self.dropout(x)
+ # Performs projection.
+ if self.proj is not None:
+ x = self.proj(x)
+ # Performs activation.
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/README.md b/code/pytorchvideo/pytorchvideo/models/hub/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..80dc2050a26867692993867fe73d64f1b5f27c5e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/README.md
@@ -0,0 +1,48 @@
+## TorchHub Models
+
+PyTorchVideo provides a large set of state-of-the-art models with pre-trained weights through [TorchHub](https://pytorch.org/hub/). Check the tables below for the torchhub names and the corresponding models.
+
+
+### Kinetics-400
+
+Models are trained on Kinetics-400. For more benchmarks and model details, please check the [PyTorchVideo Model Zoo](https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md).
+
+torchhub name | arch | depth | frame length x sample rate | top 1 | top 5 |
+------------------------ | -------- | ----- | -------------------------- | ----- | ----- |
+c2d_r50 | C2D | R50 | 8x8 | 71.46 | 89.68 |
+i3d_r50 | I3D | R50 | 8x8 | 73.27 | 90.70 |
+slow_r50 | Slow | R50 | 8x8 | 74.58 | 91.63 |
+slowfast_r50 | SlowFast | R50 | 8x8 | 76.94 | 92.69 |
+slowfast_r101 | SlowFast | R101 | 8x8 | 77.90 | 93.27 |
+slowfast_16x8_r101_50_50 | SlowFast | R101 | 16x8 | 78.70 | 93.61 |
+csn_r101 | CSN | R101 | 32x2 | 77.00 | 92.90 |
+r2plus1d_r50 | R(2+1)D | R50 | 16x4 | 76.01 | 92.23 |
+x3d_xs | X3D | XS | 4x12 | 69.12 | 88.63 |
+x3d_s | X3D | S | 13x6 | 73.33 | 91.27 |
+x3d_m | X3D | M | 16x5 | 75.94 | 92.72 |
+x3d_l | X3D | L | 16x5 | 77.44 | 93.31 |
+
+### PyTorchVideo Accelerator Models
+
+**Efficient models for mobile CPU**
+
+Models are trained on Kinetics-400. Latency is benchmarked on a Samsung S8 phone with a 1s input clip length.
+
+torchhub name | model | top 1 | top 5 | latency (ms) |
+---------------- |--------|-------|-------|--------------|
+efficient_x3d_xs | X3D_XS | 68.5 | 88.0 | 233 |
+efficient_x3d_s | X3D_S | 73.0 | 90.6 | 764 |
+
+
+
+### Using PyTorchVideo torchhub models
+The models are integrated into TorchHub and can be loaded with or without pre-trained weights. Specify the torchhub name of the model to construct it with pre-trained weights:
+
+```Python
+# Pick a pretrained model
+model_name = "slowfast_r50"
+model = torch.hub.load("facebookresearch/pytorchvideo:main", model=model_name, pretrained=True)
+```
+
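+As an additional sketch (the entrypoint names are the torchhub names listed in the tables above; an internet connection is assumed for the first call), `torch.hub.list` can be used to enumerate everything the repo registers:
+
+```Python
+import torch
+
+# Enumerate every entrypoint (torchhub name) registered by the repo.
+print(torch.hub.list("facebookresearch/pytorchvideo:main"))
+
+# Load one of the efficient mobile-CPU models, here without pre-trained weights.
+model = torch.hub.load(
+    "facebookresearch/pytorchvideo:main", model="efficient_x3d_xs", pretrained=False
+)
+```
+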
+Notes:
+* Please check [torchhub inference tutorial](https://pytorchvideo.org/docs/tutorial_torchhub_inference) for more details about how to load models from TorchHub and perform inference.
+* Check the [Model Zoo](https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md) for the full set of supported PyTorchVideo models and more details about how the model zoo is prepared.
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/__init__.py b/code/pytorchvideo/pytorchvideo/models/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..218ead33b720986939ec75e1d433455df95ecff0
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .csn import csn_r101
+from .efficient_x3d_mobile_cpu import efficient_x3d_s, efficient_x3d_xs
+from .r2plus1d import r2plus1d_r50
+from .resnet import c2d_r50, i3d_r50, slow_r50, slow_r50_detection
+from .slowfast import (
+ slowfast_16x8_r101_50_50,
+ slowfast_r101,
+ slowfast_r50,
+ slowfast_r50_detection,
+)
+from .vision_transformers import mvit_base_16, mvit_base_16x4, mvit_base_32x3
+from .x3d import x3d_l, x3d_m, x3d_s, x3d_xs
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/csn.py b/code/pytorchvideo/pytorchvideo/models/hub/csn.py
new file mode 100644
index 0000000000000000000000000000000000000000..49666c44bef0c26b1371eced264b12a7e1ba4fae
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/csn.py
@@ -0,0 +1,58 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any
+
+import torch.nn as nn
+from pytorchvideo.models.csn import create_csn
+from torch.hub import load_state_dict_from_url
+
+
+"""
+Channel-Separated Convolutional Network models for video recognition.
+"""
+
+root_dir = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics"
+checkpoint_paths = {
+ "csn_r101": f"{root_dir}/CSN_32x2_R101.pyth",
+}
+
+
+def csn_r101(
+ pretrained: bool = False, progress: bool = True, **kwargs: Any
+) -> nn.Module:
+ r"""
+ Channel-Separated Convolutional Networks (CSN) R101 model architecture [1]
+ with pretrained weights based on 32x2 setting on the Kinetics dataset.
+ Model with pretrained weights has top1 accuracy of 77.0 (trained on 16x8 GPUs).
+
+ [1] "Video classification with channel-separated convolutional networks"
+ Du Tran, Heng Wang, Lorenzo Torresani, Matt Feiszli. ICCV 2019.
+ https://arxiv.org/abs/1904.02811
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/resnet.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ model = create_csn(
+ model_depth=101,
+ stem_pool=nn.MaxPool3d,
+ head_pool_kernel_size=(4, 7, 7),
+ **kwargs,
+ )
+
+ if pretrained:
+ path = checkpoint_paths["csn_r101"]
+ # All models are loaded onto CPU by default
+ checkpoint = load_state_dict_from_url(
+ path, progress=progress, map_location="cpu"
+ )
+ state_dict = checkpoint["model_state"]
+ model.load_state_dict(state_dict)
+
+ return model
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/efficient_x3d_mobile_cpu.py b/code/pytorchvideo/pytorchvideo/models/hub/efficient_x3d_mobile_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..792ede41ba72d46424538eae0ac934ceb9969b6f
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/efficient_x3d_mobile_cpu.py
@@ -0,0 +1,84 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Optional
+
+import torch.nn as nn
+from pytorchvideo.models.accelerator.mobile_cpu.efficient_x3d import create_x3d
+from torch.hub import load_state_dict_from_url
+
+
+_root_dir = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics"
+_checkpoint_paths = {
+ "efficient_x3d_xs": f"{_root_dir}/efficient_x3d_xs_original_form.pyth",
+ "efficient_x3d_s": f"{_root_dir}/efficient_x3d_s_original_form.pyth",
+}
+
+
+def _efficient_x3d(
+ pretrained: bool = False,
+ progress: bool = True,
+ checkpoint_path: Optional[str] = None,
+ # Model params
+ expansion: str = "XS",
+ **kwargs: Any,
+) -> nn.Module:
+
+ model = create_x3d(
+ expansion=expansion,
+ **kwargs,
+ )
+
+ if pretrained and checkpoint_path is not None:
+ # All models are loaded onto CPU by default
+ state_dict = load_state_dict_from_url(
+ checkpoint_path, progress=progress, map_location="cpu"
+ )
+ model.load_state_dict(state_dict)
+
+ return model
+
+
+def efficient_x3d_xs(pretrained: bool = False, progress: bool = True, **kwargs):
+ r"""
+ X3D-XS model architecture [1] with pretrained weights trained
+ on the Kinetics dataset, with an efficient implementation for mobile CPU.
+
+ [1] Christoph Feichtenhofer, "X3D: Expanding Architectures for
+ Efficient Video Recognition." https://arxiv.org/abs/2004.04730
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics-400 dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ To modify any other model settings, specify them in the kwargs.
+ All the args are defined in pytorchvideo/models/x3d.py
+ """
+ return _efficient_x3d(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=_checkpoint_paths["efficient_x3d_xs"],
+ expansion="XS",
+ **kwargs,
+ )
+
+
+def efficient_x3d_s(pretrained: bool = False, progress: bool = True, **kwargs):
+ r"""
+ X3D-S model architecture [1] with pretrained weights trained
+ on the Kinetics dataset, with an efficient implementation for mobile CPU.
+
+ [1] Christoph Feichtenhofer, "X3D: Expanding Architectures for
+ Efficient Video Recognition." https://arxiv.org/abs/2004.04730
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics-400 dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ To modify any other model settings, specify them in the kwargs.
+ All the args are defined in pytorchvideo/models/x3d.py
+ """
+ return _efficient_x3d(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=_checkpoint_paths["efficient_x3d_s"],
+ expansion="S",
+ **kwargs,
+ )
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/r2plus1d.py b/code/pytorchvideo/pytorchvideo/models/hub/r2plus1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..4828db67810cc3e314a8f96c41b8bac23efb7699
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/r2plus1d.py
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any
+
+import torch.nn as nn
+from pytorchvideo.models.r2plus1d import create_r2plus1d
+from torch.hub import load_state_dict_from_url
+
+
+"""
+R(2+1)D style models for video recognition.
+"""
+
+root_dir = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics"
+checkpoint_paths = {
+ "r2plus1d_r50": f"{root_dir}/R2PLUS1D_16x4_R50.pyth",
+}
+
+
+def r2plus1d_r50(
+ pretrained: bool = False, progress: bool = True, **kwargs: Any
+) -> nn.Module:
+ r"""
+
+ R(2+1)D model architecture from [1] with pretrained weights based on the 16x4 setting
+ on the Kinetics dataset. Model with pretrained weights has top1 accuracy of 76.01
+ (trained on 8x8 GPUs).
+
+ [1] "A closer look at spatiotemporal convolutions for action recognition"
+ Du Tran, Heng Wang, Lorenzo Torresani, Jamie Ray, Yann LeCun, Manohar Paluri. CVPR 2018.
+ https://arxiv.org/abs/1711.11248
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/resnet.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ model = create_r2plus1d(dropout_rate=0.5, **kwargs)
+
+ if pretrained:
+ path = checkpoint_paths["r2plus1d_r50"]
+ # All models are loaded onto CPU by default
+ checkpoint = load_state_dict_from_url(
+ path, progress=progress, map_location="cpu"
+ )
+ state_dict = checkpoint["model_state"]
+ model.load_state_dict(state_dict)
+
+ return model
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/resnet.py b/code/pytorchvideo/pytorchvideo/models/hub/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d33adfb6c640ffe16d99e5eb22d8180208ec40b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/resnet.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Callable
+
+import torch.nn as nn
+from pytorchvideo.models.resnet import create_resnet, create_resnet_with_roi_head
+from torch.hub import load_state_dict_from_url
+
+
+"""
+ResNet style models for video recognition.
+"""
+
+root_dir = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo"
+checkpoint_paths = {
+ "slow_r50": f"{root_dir}/kinetics/SLOW_8x8_R50.pyth",
+ "slow_r50_detection": f"{root_dir}/ava/SLOW_4x16_R50_DETECTION.pyth",
+ "c2d_r50": f"{root_dir}/kinetics/C2D_8x8_R50.pyth",
+ "i3d_r50": f"{root_dir}/kinetics/I3D_8x8_R50.pyth",
+}
+
+
+def _resnet(
+ pretrained: bool = False,
+ progress: bool = True,
+ checkpoint_path: str = "",
+ model_builder: Callable = create_resnet,
+ **kwargs: Any,
+) -> nn.Module:
+ model = model_builder(**kwargs)
+ if pretrained:
+ # All models are loaded onto CPU by default
+ checkpoint = load_state_dict_from_url(
+ checkpoint_path, progress=progress, map_location="cpu"
+ )
+ state_dict = checkpoint["model_state"]
+ model.load_state_dict(state_dict)
+ return model
+
+
+def slow_r50(
+ pretrained: bool = False, progress: bool = True, **kwargs: Any
+) -> nn.Module:
+ r"""
+ Slow R50 model architecture [1] with pretrained weights based on 8x8 setting
+ on the Kinetics dataset. Model with pretrained weights has top1 accuracy of 74.58.
+
+ [1] "SlowFast Networks for Video Recognition"
+ Christoph Feichtenhofer et al
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/resnet.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _resnet(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["slow_r50"],
+ stem_conv_kernel_size=(1, 7, 7),
+ head_pool_kernel_size=(8, 7, 7),
+ model_depth=50,
+ **kwargs,
+ )
+
+
+def slow_r50_detection(
+ pretrained: bool = False, progress: bool = True, **kwargs: Any
+) -> nn.Module:
+ r"""
+ Slow R50 model architecture [1] with pretrained weights based on 4x16 setting.
+ The model is initially trained on Kinetics dataset for classification and later
+ finetuned on AVA dataset for detection.
+
+ [1] Christoph Feichtenhofer et al, "SlowFast Networks for Video Recognition"
+ https://arxiv.org/pdf/1812.03982.pdf
+ """
+ return _resnet(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["slow_r50_detection"],
+ model_builder=create_resnet_with_roi_head,
+ **kwargs,
+ )
+
+
+def c2d_r50(
+ pretrained: bool = False, progress: bool = True, **kwargs: Any
+) -> nn.Module:
+ r"""
+ C2D R50 model architecture with pretrained weights based on 8x8 setting
+ on the Kinetics dataset. Model with pretrained weights has top1 accuracy of 71.46.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/resnet.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _resnet(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["c2d_r50"],
+ stem_conv_kernel_size=(1, 7, 7),
+ stage1_pool=nn.MaxPool3d,
+ stage_conv_a_kernel_size=(
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ ),
+ **kwargs,
+ )
+
+
+def i3d_r50(
+ pretrained: bool = False, progress: bool = True, **kwargs: Any
+) -> nn.Module:
+ r"""
+ I3D R50 model architecture from [1] with pretrained weights based on 8x8 setting
+ on the Kinetics dataset. Model with pretrained weights has top1 accuracy of 73.27.
+
+ [1] "Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
+ Joao Carreira, Andrew Zisserman
+ https://arxiv.org/abs/1705.07750
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/resnet.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _resnet(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["i3d_r50"],
+ stem_conv_kernel_size=(5, 7, 7),
+ stage1_pool=nn.MaxPool3d,
+ stage_conv_a_kernel_size=(
+ (3, 1, 1),
+ [(3, 1, 1), (1, 1, 1)],
+ [(3, 1, 1), (1, 1, 1)],
+ [(1, 1, 1), (3, 1, 1)],
+ ),
+ **kwargs,
+ )
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/slowfast.py b/code/pytorchvideo/pytorchvideo/models/hub/slowfast.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb987559b7593267b4f72ae187a2bfd1e9c1c28c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/slowfast.py
@@ -0,0 +1,179 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Callable
+
+import torch.nn as nn
+from pytorchvideo.models.slowfast import create_slowfast, create_slowfast_with_roi_head
+from torch.hub import load_state_dict_from_url
+
+
+root_dir = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo"
+checkpoint_paths = {
+ "slowfast_r50": f"{root_dir}/kinetics/SLOWFAST_8x8_R50.pyth",
+ "slowfast_r50_detection": f"{root_dir}/ava/SLOWFAST_8x8_R50_DETECTION.pyth",
+ "slowfast_r101": f"{root_dir}/kinetics/SLOWFAST_8x8_R101.pyth",
+ "slowfast_16x8_r101_50_50": f"{root_dir}/kinetics/SLOWFAST_16x8_R101_50_50.pyth",
+}
+
+
+def _slowfast(
+ pretrained: bool = False,
+ progress: bool = True,
+ checkpoint_path: str = "",
+ model_builder: Callable = create_slowfast,
+ **kwargs: Any,
+) -> nn.Module:
+ model = model_builder(**kwargs)
+ if pretrained:
+ # All models are loaded onto CPU by default
+ checkpoint = load_state_dict_from_url(
+ checkpoint_path, progress=progress, map_location="cpu"
+ )
+ state_dict = checkpoint["model_state"]
+ model.load_state_dict(state_dict)
+ return model
+
+
+def slowfast_r50(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs: Any,
+) -> nn.Module:
+ r"""
+ SlowFast R50 model architecture [1] trained with an 8x8 setting on the
+ Kinetics dataset. Model with pretrained weights has top1 accuracy of 76.4.
+
+ [1] Christoph Feichtenhofer et al, "SlowFast Networks for Video Recognition"
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/slowfast.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _slowfast(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["slowfast_r50"],
+ model_depth=50,
+ slowfast_fusion_conv_kernel_size=(7, 1, 1),
+ **kwargs,
+ )
+
+
+def slowfast_r101(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs: Any,
+) -> nn.Module:
+ r"""
+ SlowFast R101 model architecture [1] trained with an 8x8 setting on the
+ Kinetics dataset. Model with pretrained weights has top1 accuracy of 77.9.
+
+ [1] Christoph Feichtenhofer et al, "SlowFast Networks for Video Recognition"
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/slowfast.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _slowfast(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["slowfast_r101"],
+ model_depth=101,
+ slowfast_fusion_conv_kernel_size=(5, 1, 1),
+ **kwargs,
+ )
+
+
+def slowfast_16x8_r101_50_50(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs: Any,
+) -> nn.Module:
+ r"""
+ SlowFast R101_50_50 model architecture [1] trained with a 16x8 setting on the
+ Kinetics dataset. Model with pretrained weights has top1 accuracy of 78.7.
+
+ [1] Christoph Feichtenhofer et al, "SlowFast Networks for Video Recognition"
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/slowfast.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ # slowfast_r101_50_50 has 6 conv blocks with kernel=(3, 1, 1) in stage 4.
+ stage_conv_a_kernel_sizes = (
+ (
+ (1, 1, 1),
+ (1, 1, 1),
+ ((3, 1, 1),) * 6 + ((1, 1, 1),) * (23 - 6),
+ (3, 1, 1),
+ ),
+ (
+ (3, 1, 1),
+ (3, 1, 1),
+ ((3, 1, 1),) * 6 + ((1, 1, 1),) * (23 - 6),
+ (3, 1, 1),
+ ),
+ )
+ return _slowfast(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["slowfast_16x8_r101_50_50"],
+ model_depth=101,
+ slowfast_fusion_conv_kernel_size=(5, 1, 1),
+ stage_conv_a_kernel_sizes=stage_conv_a_kernel_sizes,
+ head_pool_kernel_sizes=((16, 7, 7), (64, 7, 7)),
+ **kwargs,
+ )
+
+
+def slowfast_r50_detection(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs: Any,
+) -> nn.Module:
+ r"""
+ SlowFast R50 model architecture [1] with pretrained weights based on 8x8 setting.
+ The model is initially trained on Kinetics dataset for classification and later
+ finetuned on AVA dataset for detection.
+
+ [1] Christoph Feichtenhofer et al, "SlowFast Networks for Video Recognition"
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/slowfast.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _slowfast(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["slowfast_r50_detection"],
+ model_builder=create_slowfast_with_roi_head,
+ **kwargs,
+ )
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/utils.py b/code/pytorchvideo/pytorchvideo/models/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de81f8fc1347617fff328c164a8bdad1e5c3417a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/utils.py
@@ -0,0 +1,45 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Callable, Dict, Optional
+
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+
+
+MODEL_ZOO_ROOT_DIR = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo"
+
+
+def hub_model_builder(
+ model_builder_func: Callable,
+ pretrained: bool = False,
+ progress: bool = True,
+ checkpoint_path: str = "",
+ default_config: Optional[Dict[Any, Any]] = None,
+ **kwargs: Any,
+) -> nn.Module:
+ """
+ model_builder_func (Callable): Model builder function.
+ pretrained (bool): Whether to load a pretrained model or not. Default: False.
+ progress (bool): Whether or not to display a progress bar to stderr. Default: True.
+ checkpoint_path (str): URL of the model weight to download.
+ default_config (Dict): Default model configs that are passed to the model builder.
+ **kwargs (Any): Additional model configs. Do not modify the model configuration
+ via the kwargs for pretrained model.
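+
+ Example (an illustrative sketch only; the builder, checkpoint path, and config
+ below mirror the `slow_r50` entrypoint in `resnet.py` of this package, which
+ otherwise wraps its own `_resnet` helper):
+
+ from pytorchvideo.models.resnet import create_resnet
+
+ model = hub_model_builder(
+ model_builder_func=create_resnet,
+ pretrained=True,
+ checkpoint_path=f"{MODEL_ZOO_ROOT_DIR}/kinetics/SLOW_8x8_R50.pyth",
+ default_config={
+ "model_depth": 50,
+ "stem_conv_kernel_size": (1, 7, 7),
+ "head_pool_kernel_size": (8, 7, 7),
+ },
+ )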
+ """
+ if pretrained:
+ assert len(kwargs) == 0, "Do not change kwargs for pretrained model."
+
+ if default_config is not None:
+ for argument, value in default_config.items():
+ if kwargs.get(argument) is None:
+ kwargs[argument] = value
+
+ model = model_builder_func(**kwargs)
+ if pretrained:
+ # All models are loaded onto CPU by default
+ checkpoint = load_state_dict_from_url(
+ checkpoint_path, progress=progress, map_location="cpu"
+ )
+ state_dict = checkpoint["model_state"]
+ model.load_state_dict(state_dict)
+ return model
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/vision_transformers.py b/code/pytorchvideo/pytorchvideo/models/hub/vision_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ae623e9cd15054d70f9ae4ba882107d1b097ed
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/vision_transformers.py
@@ -0,0 +1,158 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any
+
+import torch.nn as nn
+from pytorchvideo.models.vision_transformers import (
+ create_multiscale_vision_transformers,
+)
+
+from .utils import hub_model_builder, MODEL_ZOO_ROOT_DIR
+
+
+checkpoint_paths = {
+ "mvit_base_16x4": "{}/kinetics/MVIT_B_16x4.pyth".format(MODEL_ZOO_ROOT_DIR),
+ "mvit_base_32x3": "{}/kinetics/MVIT_B_32x3_f294077834.pyth".format(
+ MODEL_ZOO_ROOT_DIR
+ ),
+ "mvit_base_16": "{}/imagenet/MVIT_B_16_f292487636.pyth".format(MODEL_ZOO_ROOT_DIR),
+}
+
+mvit_video_base_config = {
+ "spatial_size": 224,
+ "temporal_size": 16,
+ "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+ "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+ "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+ "pool_kv_stride_adaptive": [1, 8, 8],
+ "pool_kvq_kernel": [3, 3, 3],
+}
+
+mvit_video_base_32x3_config = {
+ "spatial_size": 224,
+ "temporal_size": 32,
+ "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+ "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+ "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+ "pool_kv_stride_adaptive": [1, 8, 8],
+ "pool_kvq_kernel": [3, 3, 3],
+}
+
+mvit_image_base_16_config = {
+ "spatial_size": 224,
+ "temporal_size": 1,
+ "depth": 16,
+ "conv_patch_embed_kernel": [7, 7],
+ "conv_patch_embed_stride": [4, 4],
+ "conv_patch_embed_padding": [3, 3],
+ "use_2d_patch": True,
+ "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+ "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+ "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+ "pool_kv_stride_adaptive": [1, 4, 4],
+ "pool_kvq_kernel": [1, 3, 3],
+}
+
+
+def mvit_base_16x4(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs: Any,
+) -> nn.Module:
+ """
+ Multiscale Vision Transformers model architecture [1] trained with a 16x4
+ setting on the Kinetics400 dataset. Model with pretrained weights has top1
+ accuracy of 78.9%.
+
+ [1] Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra
+ Malik, Christoph Feichtenhofer, "Multiscale Vision Transformers"
+ https://arxiv.org/pdf/2104.11227.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Kinetics400 dataset.
+ progress (bool): If True, displays a progress bar of the download to stderr.
+ kwargs: Use these to modify any of the other model settings. All the
+ options are defined in create_multiscale_vision_transformers.
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+
+ return hub_model_builder(
+ model_builder_func=create_multiscale_vision_transformers,
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["mvit_base_16x4"],
+ default_config=mvit_video_base_config,
+ **kwargs,
+ )
+
+
+def mvit_base_32x3(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs: Any,
+) -> nn.Module:
+ """
+ Multiscale Vision Transformers model architecture [1] trained with a 32x3
+ setting on the Kinetics400 dataset. Model with pretrained weights has top1
+ accuracy of 80.3%.
+
+ [1] Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra
+ Malik, Christoph Feichtenhofer, "Multiscale Vision Transformers"
+ https://arxiv.org/pdf/2104.11227.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Kinetics400 dataset.
+ progress (bool): If True, displays a progress bar of the download to stderr.
+ kwargs: Use these to modify any of the other model settings. All the
+ options are defined in create_multiscale_vision_transformers.
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+
+ return hub_model_builder(
+ model_builder_func=create_multiscale_vision_transformers,
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["mvit_base_32x3"],
+ default_config=mvit_video_base_32x3_config,
+ **kwargs,
+ )
+
+
+def mvit_base_16(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs: Any,
+) -> nn.Module:
+ """
+ Multiscale Vision Transformers model architecture [1] with a depth of 16, trained on
+ the ImageNet-1k dataset. Model with pretrained weights has top1 accuracy of 83%.
+
+ [1] Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra
+ Malik, Christoph Feichtenhofer, "Multiscale Vision Transformers"
+ https://arxiv.org/pdf/2104.11227.pdf
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the ImageNet-1k dataset.
+ progress (bool): If True, displays a progress bar of the download to stderr.
+ kwargs: Use these to modify any of the other model settings. All the
+ options are defined in create_multiscale_vision_transformers.
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+
+ return hub_model_builder(
+ model_builder_func=create_multiscale_vision_transformers,
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["mvit_base_16"],
+ default_config=mvit_image_base_16_config,
+ **kwargs,
+ )
diff --git a/code/pytorchvideo/pytorchvideo/models/hub/x3d.py b/code/pytorchvideo/pytorchvideo/models/hub/x3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..2be4f8301abfbbfe85af6245730d54ba3ffe9d45
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/hub/x3d.py
@@ -0,0 +1,162 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Optional
+
+import torch.nn as nn
+from pytorchvideo.models.x3d import create_x3d
+from torch.hub import load_state_dict_from_url
+
+
+root_dir = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics"
+checkpoint_paths = {
+ "x3d_xs": f"{root_dir}/X3D_XS.pyth",
+ "x3d_s": f"{root_dir}/X3D_S.pyth",
+ "x3d_m": f"{root_dir}/X3D_M.pyth",
+ "x3d_l": f"{root_dir}/X3D_L.pyth",
+}
+
+
+def _x3d(
+ pretrained: bool = False,
+ progress: bool = True,
+ checkpoint_path: Optional[str] = None,
+ **kwargs: Any,
+) -> nn.Module:
+ model = create_x3d(**kwargs)
+ if pretrained and checkpoint_path is not None:
+ # All models are loaded onto CPU by default
+ checkpoint = load_state_dict_from_url(
+ checkpoint_path, progress=progress, map_location="cpu"
+ )
+ state_dict = checkpoint["model_state"]
+ model.load_state_dict(state_dict)
+ return model
+
+
+def x3d_xs(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs,
+):
+ r"""
+ X3D-XS model architecture [1] trained on the Kinetics dataset.
+ Model with pretrained weights has top1 accuracy of 69.12.
+
+ [1] Christoph Feichtenhofer, "X3D: Expanding Architectures for
+ Efficient Video Recognition." https://arxiv.org/abs/2004.04730
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/x3d.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _x3d(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["x3d_xs"],
+ input_clip_length=4,
+ input_crop_size=160,
+ **kwargs,
+ )
+
+
+def x3d_s(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs,
+):
+ """
+ X3D-S model architecture [1] trained on the Kinetics dataset.
+ Model with pretrained weights has top1 accuracy of 73.33.
+
+ [1] Christoph Feichtenhofer, "X3D: Expanding Architectures for
+ Efficient Video Recognition." https://arxiv.org/abs/2004.04730
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/x3d.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _x3d(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["x3d_s"],
+ input_clip_length=13,
+ input_crop_size=160,
+ **kwargs,
+ )
+
+
+def x3d_m(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs,
+):
+ """
+ X3D-M model architecture [1] trained on the Kinetics dataset.
+ Model with pretrained weights has top1 accuracy of 75.94.
+
+ [1] Christoph Feichtenhofer, "X3D: Expanding Architectures for
+ Efficient Video Recognition." https://arxiv.org/abs/2004.04730
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/x3d.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _x3d(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["x3d_m"],
+ input_clip_length=16,
+ input_crop_size=224,
+ **kwargs,
+ )
+
+
+def x3d_l(
+ pretrained: bool = False,
+ progress: bool = True,
+ **kwargs,
+):
+ """
+ X3D-L model architecture [1] trained on the Kinetics dataset.
+ Model with pretrained weights has top1 accuracy of 77.44.
+
+ [1] Christoph Feichtenhofer, "X3D: Expanding Architectures for
+ Efficient Video Recognition." https://arxiv.org/abs/2004.04730
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on the Kinetics dataset
+ progress (bool): If True, displays a progress bar of the download to stderr
+ kwargs: use these to modify any of the other model settings. All the
+ options are defined in pytorchvideo/models/x3d.py
+
+ NOTE: to use the pretrained model, do not modify the model configuration
+ via the kwargs. Only modify settings via kwargs to initialize a new model
+ without pretrained weights.
+ """
+ return _x3d(
+ pretrained=pretrained,
+ progress=progress,
+ checkpoint_path=checkpoint_paths["x3d_l"],
+ input_clip_length=16,
+ input_crop_size=312,
+ depth_factor=5.0,
+ **kwargs,
+ )
diff --git a/code/pytorchvideo/pytorchvideo/models/masked_multistream.py b/code/pytorchvideo/pytorchvideo/models/masked_multistream.py
new file mode 100644
index 0000000000000000000000000000000000000000..096d4b0518a05f24954cc7c0763ecc3286263087
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/masked_multistream.py
@@ -0,0 +1,384 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import List, Optional, Tuple
+
+import torch
+from pytorchvideo.layers.utils import set_attributes
+from torch import nn
+from torch.nn.utils.rnn import pack_padded_sequence
+
+
+"""
+This file contains nn.Modules that take a tensor and mask in their forward function.
+These masks can be used to represent invalid values (e.g. for tensors with varying
+temporal dimension size). To easily compose these modules together, a
+MaskedSequential module is provided.
+
+Example usage:
+
+ feature_dim = 64
+ input_stream = MaskedSequential(
+ PositionalEncoding(feature_dim),
+ Dropout(p=0.1),
+ TransposeMultiheadAttention(feature_dim),
+ MaskedTemporalPooling(feature_dim, method="avg"),
+ LayerNorm(feature_dim),
+ LearnMaskedDefault(feature_dim),
+ )
+
+ input_tensor = ... # tensor with shape (batch_size, seq_len, feature_dim)
+ mask_tensor = ... # bool tensor with shape (batch_size, seq_len)
+ result = input_stream(input=input_tensor, mask=mask_tensor)
+"""
+
+
+class MaskedTemporalPooling(torch.nn.Module):
+ """
+ Applies temporal pooling operations on masked inputs. For each pooling operation
+ all masked values are ignored.
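+
+ Example (a sketch; the shapes follow the forward contract documented below):
+
+ pool = MaskedTemporalPooling(method="avg")
+ x = torch.randn(2, 3, 4) # (batch_size, seq_len, feature_dim)
+ mask = torch.tensor([[True, True, False], [True, False, False]])
+ out = pool(x, mask) # (2, 4); masked time steps are excluded from the average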
+ """
+
+ def __init__(self, method: str):
+ """
+ method (str): the method of pooling to use. Options:
+ 'max': reduces temporal dimension to each valid max value.
+ 'avg': averages valid values in the temporal dimension.
+ 'sum': sums valid values in the temporal dimension.
+ Note if all batch row elements are invalid, the temporal dimension is
+ pooled to 0 values.
+ """
+ super().__init__()
+ assert method in ("max", "avg", "sum")
+ self._method = method
+
+ def forward(
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): tensor with shape (batch_size, seq_len, feature_dim)
+ mask (torch.Tensor): bool tensor with shape (batch_size, seq_len).
+ Sequence elements that are False are invalid.
+
+ Returns:
+ Tensor with shape (batch_size, feature_dim)
+ """
+ assert x.dim() == 3, "Requires x shape (batch_size x seq_len x feature_dim)"
+ b, t = x.shape[0], x.shape[1]
+ if mask is None:
+ mask = torch.ones((b, t), dtype=torch.bool, device=x.device)
+
+ if self._method == "max":
+ x[~mask, :] = float("-inf")
+
+ # Invalid batch rows are set to 0.
+ invalid_first_dim = ~mask.view(b, -1).any(dim=-1)
+ x[invalid_first_dim, :] = 0
+
+ x = torch.max(x, dim=1)[0]
+ elif self._method == "avg":
+ x = x * mask.unsqueeze(-1).float()
+ mask = mask.view(b, t, -1).any(dim=-1)
+ valid_lengths = mask.float().sum(dim=-1).int()
+ x = x.sum(dim=1)
+ x = x.div(valid_lengths.clamp(min=1).unsqueeze(-1).expand(x.size()).float())
+ elif self._method == "sum": # sum
+ x = x * mask.unsqueeze(-1).float()
+ x = x.sum(dim=1)
+ else:
+ raise NotImplementedError(
+ f"{self._method} not available options are: 'max', 'avg', 'sum'"
+ )
+
+ return x
+
+
+class TransposeMultiheadAttention(nn.Module):
+ """
+ Wrapper for nn.MultiheadAttention which first transposes the input tensor
+ from (batch_size, seq_len, feature_dim) to (seq_length, batch_size, feature_dim),
+ then applies the attention and transposes the attention outputs back to the input
+ shape.
+ """
+
+ def __init__(self, feature_dim: int, num_heads: int = 1):
+ """
+ Args:
+ feature_dim (int): attention embedding dimension
+ num_heads (int): number of attention heads
+ """
+ super().__init__()
+ self._attention = nn.MultiheadAttention(
+ embed_dim=feature_dim, num_heads=num_heads
+ )
+ self._attention_weights = None
+
+ @property
+ def attention_weights(self) -> Optional[torch.Tensor]:
+ """
+ Contains attention weights from last forward call.
+ """
+ return self._attention_weights
+
+ def forward(
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): tensor of shape (batch_size, seq_len, feature_dim)
+ mask (torch.Tensor): bool tensor with shape (batch_size, seq_len).
+ Sequence elements that are False are invalid.
+
+ Returns:
+ Tensor with shape (batch_size, seq_len, feature_dim)
+ """
+ assert x.dim() == 3, "Requires x shape (batch_size x seq_len x feature_dim)"
+
+ if mask is not None:
+ # At least the first element of each masked batch row must be valid for
+ # key_padding_mask.
+ mask[:, 0] = True
+ mask = ~mask
+
+ # Transpose x to (seq_length x batch_size x feature_dim).
+ x = x.transpose(0, 1)
+ attn_output, self._attention_weights = self._attention(
+ x, x, x, key_padding_mask=mask
+ )
+
+ # Transpose attention output to (batch_size x seq_length x feature_dim).
+ attn_output = attn_output.transpose(0, 1)
+ return attn_output
+
+
+class LearnMaskedDefault(nn.Module):
+ """
+ Learns default values to fill invalid entries within input tensors. The
+ invalid entries are represented by a mask which is passed into forward alongside
+ the input tensor. Note the default value is only used if all entries in a batch row are
+ invalid, not when only a portion of the entries within a row are invalid.
+ """
+
+ def __init__(
+ self, feature_dim: int, init_method: str = "gaussian", freeze: bool = False
+ ):
+ """
+ Args:
+ feature_dim (int): the size of the default value parameter; this must match the
+ input tensor size.
+ init_method (str): how to initialize the default value parameter. Options:
+ 'gaussian'
+ 'zeros'
+ freeze (bool): If True, the learned default parameter weights are frozen.
+ """
+ super().__init__()
+ if init_method == "zeros":
+ self._learned_defaults = nn.Parameter(
+ torch.zeros(feature_dim), requires_grad=(not freeze)
+ )
+ elif init_method == "gaussian":
+ self._learned_defaults = nn.Parameter(
+ torch.Tensor(feature_dim), requires_grad=(not freeze)
+ )
+ nn.init.normal_(self._learned_defaults)
+ else:
+ raise NotImplementedError(
+ f"{init_method} not available. Options are: 'zeros' or 'gaussian'"
+ )
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): tensor of shape (batch_size, feature_dim).
+ mask (torch.Tensor): bool tensor of shape (batch_size, seq_len). If all elements
+ in a batch row are False, the learned default parameter is used for
+ that batch element.
+
+ Returns:
+ Tensor with shape (batch_size, feature_dim)
+ """
+ # Determine which rows have no valid entries and use these for the default value mask.
+ mask = mask.view(mask.shape[0], -1).any(dim=-1)
+ for i in range(1, x.dim()):
+ mask = mask.unsqueeze(i)
+ x = x * mask.float() + self._learned_defaults * (1 - mask.float())
+ return x
+
+
+class LSTM(nn.Module):
+ """
+ Wrapper for torch.nn.LSTM that handles masked inputs.
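+
+ Example (a sketch; the tensors are dummies created only to show the shapes):
+
+ lstm = LSTM(dim_in=64, hidden_dim=32, bidirectional=True)
+ data = torch.randn(2, 5, 64) # (batch_size, seq_len, feature_dim)
+ mask = torch.ones(2, 5, dtype=torch.bool)
+ out = lstm(data, mask) # (2, 64): bidirectional doubles hidden_dim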
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ hidden_dim: int,
+ dropout: float = 0.0,
+ bidirectional: bool = False,
+ ):
+ """
+ Args:
+ dim_in (int): input feature dimension
+ hidden_dim (int): hidden dimension of the LSTM layer
+ dropout (float): dropout rate - 0.0 if no dropout
+ bidirectional (bool): bidirectional or forward only
+ """
+ super().__init__()
+ self.lstm = nn.LSTM(
+ dim_in,
+ hidden_dim,
+ batch_first=True,
+ dropout=dropout,
+ bidirectional=bidirectional,
+ )
+ self.lstm.flatten_parameters()
+ self.output_dim = 2 * hidden_dim if bidirectional else hidden_dim
+ self.bidirectional = bidirectional
+
+ def forward(
+ self, data: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ data (torch.Tensor): tensor with shape (batch_size, seq_len, feature_dim)
+ mask (torch.Tensor): bool tensor with shape (batch_size, seq_len).
+ Sequence elements that are False are invalid.
+
+ Returns:
+ Tensor with shape (batch_size, output_dim), where output_dim is determined by
+ hidden_dim and whether the LSTM is bidirectional.
+ """
+ assert data.dim() == 3
+ b, t = data.shape[0], data.shape[1]
+
+ if mask is None:
+ mask = torch.ones((b, t), dtype=torch.bool)
+
+ lengths = mask.sum(axis=1)
+ x_packed = pack_padded_sequence(
+ data,
+ lengths.clamp(1, data.size(1)),
+ batch_first=True,
+ enforce_sorted=False,
+ )
+ _, (h, _) = self.lstm(x_packed)
+
+ if self.bidirectional:
+ out = torch.cat([h[0, :, :], h[1, :, :]], dim=-1)
+ else:
+ out = h[-1, :, :]
+
+ return out
+
+
+class TransposeTransformerEncoder(nn.Module):
+ """
+ Wrapper for torch.nn.TransformerEncoder that handles masked inputs.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ num_heads: int = 1,
+ num_layers: int = 1,
+ ):
+ """
+ Args:
+ dim_in (int): input feature dimension
+ num_heads (int): number of heads in the nn.MultiHeadAttention layers
+ num_layers (int): the number of sub-encoder-layers in the encoder
+ """
+ super().__init__()
+ self.encoder = nn.TransformerEncoder(
+ nn.TransformerEncoderLayer(dim_in, num_heads), num_layers
+ )
+
+ def forward(
+ self, data: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ data (torch.Tensor): tensor with shape (batch_size, seq_len, feature_dim)
+ mask (torch.Tensor): bool tensor with shape (batch_size, seq_len).
+ Sequence elements that are False are invalid.
+
+ Returns:
+ Tensor with shape (batch_size, feature_dim)
+ """
+ if mask is not None:
+ # At least the first element of each masked batch row must be valid for
+ # key_padding_mask.
+ mask[:, 0] = True
+ mask = ~mask
+
+ out = self.encoder(
+ src=data.transpose(0, 1), src_key_padding_mask=mask
+ ).transpose(0, 1)
+
+ return out[:, 0, :]
+
+
+class MaskedSequential(nn.Sequential):
+ """
+ A sequential container that overrides forward to take a mask as well as the usual
+ input tensor. This mask is only applied to modules in _MASK_MODULES (which take
+ the mask argument).
+ """
+
+ _MASK_MODULES = [
+ MaskedTemporalPooling,
+ LearnMaskedDefault,
+ TransposeMultiheadAttention,
+ LSTM,
+ TransposeTransformerEncoder,
+ ]
+
+ def forward(self, input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ for module in self:
+ if any(isinstance(module, mask_type) for mask_type in self._MASK_MODULES):
+ input = module(input, mask=mask)
+ else:
+ input = module(input)
+
+ return input
+
+
+class MaskedMultiPathWay(nn.Module):
+ """
+ Masked multi-pathway is composed of a list of stream nn.Modules followed by a
+ fusion nn.Module that reduces these streams. Each stream module takes a mask
+ and input tensor.
+
+ ::
+
+ Pathway 1 ... Pathway N
+ ↓ ↓
+ Block 1 Block N
+ ↓⭠ --Fusion----↓
+ """
+
+ def __init__(
+ self,
+ *,
+ multipathway_blocks: nn.ModuleList,
+ multipathway_fusion: Optional[nn.Module],
+ ) -> None:
+ """
+ Args:
+ multipathway_blocks (nn.module_list): list of models from all pathways.
+ multipathway_fusion (nn.module): fusion model.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+
+ def forward(
+ self, x_and_mask: List[Tuple[torch.Tensor, torch.Tensor]]
+ ) -> torch.Tensor:
+ out = []
+ for pathway_idx in range(len(self.multipathway_blocks)):
+ out.append(self.multipathway_blocks[pathway_idx](*x_and_mask[pathway_idx]))
+
+ # Guard against an unbound `x` when no fusion module is configured.
+ x = out
+ if self.multipathway_fusion is not None:
+ x = self.multipathway_fusion(x)
+ return x
diff --git a/code/pytorchvideo/pytorchvideo/models/memory_bank.py b/code/pytorchvideo/pytorchvideo/models/memory_bank.py
new file mode 100644
index 0000000000000000000000000000000000000000..f724c2d13ac7970d0010056bcfbce749495e3f07
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/memory_bank.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorchvideo.layers.utils import set_attributes
+
+
+class MemoryBank(nn.Module):
+ """
+ Performs Non-Parametric Instance Discrimination for self-supervised learning on
+ video. A memory bank is built to keep and update historical feature embeddings
+ and use them for contrastive learning.
+
+ The original paper is:
+ Unsupervised Feature Learning via Non-Parametric Instance Discrimination
+ https://arxiv.org/pdf/1805.01978.pdf
+
+ More details can be found from the memory bank part in the following paper:
+ Momentum Contrast for Unsupervised Visual Representation Learning
+ https://arxiv.org/pdf/1911.05722.pdf
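+
+ Example (a sketch; ``backbone``, ``clips``, and ``clip_indices`` are placeholders,
+ and the backbone is assumed to output a ``dim``-dimensional embedding):
+
+ model = MemoryBank(backbone, neg_size=1024, bank_size=num_train_samples, dim=2048)
+ loss = model(clips, clip_indices) # clip_indices: shape (B,) dataset indices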
+ """
+
+ def __init__(
+ self,
+ backbone: nn.Module,
+ mlp: Optional[nn.Module] = None,
+ neg_size: int = 4096,
+ temperature: float = 0.07,
+ bank_size: int = 1280000,
+ dim: int = 2048,
+ mmt: float = 0.999,
+ ) -> None:
+ """
+ Args:
+ backbone (nn.Module): backbone used to forward the input.
+ mlp (nn.Module): multi-layer perception used in memory bank instance
+ discrimination model.
+ neg_size (int): size of negative samples per instance.
+ temperature (float): temperature to use for contrastive learning.
+ bank_size (int): size of the memory bank, expected to be the same size as
+ the training set.
+ dim (int): dimension of the channel.
+ mmt (float): momentum to use.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ self._init_mem_bank(bank_size, dim)
+
+ def _init_mem_bank(self, bank_size: int, dim: int) -> None:
+ """
+ Given the memory bank size and the channel dimension, initialize the memory
+ bank.
+ Args:
+ bank_size (int): size of the memory bank, expected to be the same size as
+ the training set.
+ dim (int): dimension of the channel.
+ """
+ stdv = 1.0 / math.sqrt(dim / 3)
+ self.register_buffer(
+ "memory",
+ torch.rand(
+ bank_size,
+ dim,
+ )
+ .mul_(2 * stdv)
+ .add_(-stdv)
+ .to(next(self.backbone.parameters()).device),
+ )
+
+ def forward(self, x: torch.Tensor, x_ind: torch.Tensor) -> torch.Tensor:
+ """
+ Perform contrastive learning with random sampled negative instance from the
+ memory bank. During training, update the memory bank with latest feature
+ embedding.
+ Args:
+ x (torch.tensor): a batch of augmented images. The input tensor
+ shape should be compatible with the backbone.
+ x_ind (torch.tensor): the indices of the images in x within the dataset. Expected
+ shape is B.
+ """
+ batch_size = x.shape[0]
+ x = self.backbone(x)
+ if self.mlp is not None:
+ x = self.mlp(x)
+ # Normalize the output embedding before multiplication.
+ x = F.normalize(x, p=2, dim=1)
+ # Random sample negative instances from the memory bank.
+ idx = torch.randint(0, self.bank_size, size=(batch_size, self.neg_size + 1)).to(
+ x.device
+ )
+ # Fill the first with positive instances.
+ idx.select(1, 0).copy_(x_ind.data)
+ weight = torch.index_select(self.memory, 0, idx.view(-1)).detach()
+ weight = weight.view(batch_size, self.neg_size + 1, self.dim)
+ # Multiplication for contrastive learning.
+ out = torch.einsum("bkc,bc->bk", weight, x)
+ out = torch.div(out, self.temperature)
+ gt = torch.zeros((batch_size,), device=x.device, dtype=torch.long)
+ loss = torch.nn.functional.cross_entropy(out, gt)
+ # Update memory during training.
+ if self.training:
+ with torch.no_grad():
+ pos = torch.index_select(self.memory, 0, x_ind.view(-1))
+ pos.mul_(self.mmt)
+ pos.add_(torch.mul(x, 1 - self.mmt))
+ norm = pos.pow(2).sum(1, keepdim=True).pow(0.5)
+ updated = pos.div(norm)
+ self.memory.index_copy_(0, x_ind, updated)
+ return loss
diff --git a/code/pytorchvideo/pytorchvideo/models/net.py b/code/pytorchvideo/pytorchvideo/models/net.py
new file mode 100644
index 0000000000000000000000000000000000000000..2411f460a89f27b37f8b74d6e128ff570e13c07a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/net.py
@@ -0,0 +1,122 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.utils import set_attributes
+from pytorchvideo.models.weight_init import init_net_weights
+
+
+class Net(nn.Module):
+ """
+ Build a general Net models with a list of blocks for video recognition.
+
+ ::
+
+ Input
+ ↓
+ Block 1
+ ↓
+ .
+ .
+ .
+ ↓
+ Block N
+ ↓
+
+ The ResNet builder can be found in `create_resnet`.
+ """
+
+ def __init__(self, *, blocks: nn.ModuleList) -> None:
+ """
+ Args:
+ blocks (torch.nn.module_list): the list of block modules.
+ """
+ super().__init__()
+ assert blocks is not None
+ self.blocks = blocks
+ init_net_weights(self)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for _, block in enumerate(self.blocks):
+ x = block(x)
+ return x
+
+
+class DetectionBBoxNetwork(nn.Module):
+ """
+ A general purpose model that handles bounding boxes as part of input.
+ """
+
+ def __init__(self, model: nn.Module, detection_head: nn.Module):
+ """
+ Args:
+ model (nn.Module): a model that precedes the head. Ex: stem + stages.
+ detection_head (nn.Module): a network head that can take in input bounding boxes
+ and the outputs from the model.
+ """
+ super().__init__()
+ self.model = model
+ self.detection_head = detection_head
+
+ def forward(self, x: torch.Tensor, bboxes: torch.Tensor):
+ """
+ Args:
+ x (torch.tensor): input tensor
+ bboxes (torch.tensor): associated bounding boxes.
+ The format is N*5 (Index, X_1,Y_1,X_2,Y_2) if using RoIAlign
+ and N*6 (Index, x_ctr, y_ctr, width, height, angle_degrees) if
+ using RoIAlignRotated.
+ """
+ features = self.model(x)
+ out = self.detection_head(features, bboxes)
+ return out.view(out.shape[0], -1)
+
+
+class MultiPathWayWithFuse(nn.Module):
+ """
+ Build a multi-pathway block with fusion for video recognition. Each pathway
+ contains its own blocks, followed by a fusion layer across the different pathways.
+
+ ::
+
+ Pathway 1 ... Pathway N
+ ↓ ↓
+ Block 1 Block N
+ ↓⭠ --Fusion----↓
+ """
+
+ def __init__(
+ self,
+ *,
+ multipathway_blocks: nn.ModuleList,
+ multipathway_fusion: Optional[nn.Module],
+ inplace: Optional[bool] = True,
+ ) -> None:
+ """
+ Args:
+ multipathway_blocks (nn.module_list): list of models from all pathways.
+ multipathway_fusion (nn.module): fusion model.
+ inplace (bool): If inplace, directly update the input list without making
+ a copy.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+
+ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+ assert isinstance(
+ x, list
+ ), "input for MultiPathWayWithFuse needs to be a list of tensors"
+ if self.inplace:
+ x_out = x
+ else:
+ x_out = [None] * len(x)
+ for pathway_idx in range(len(self.multipathway_blocks)):
+ if self.multipathway_blocks[pathway_idx] is not None:
+ x_out[pathway_idx] = self.multipathway_blocks[pathway_idx](
+ x[pathway_idx]
+ )
+ if self.multipathway_fusion is not None:
+ x_out = self.multipathway_fusion(x_out)
+ return x_out
diff --git a/code/pytorchvideo/pytorchvideo/models/r2plus1d.py b/code/pytorchvideo/pytorchvideo/models/r2plus1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..5744440696a58084769e6824b0e4b5671a67dc14
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/r2plus1d.py
@@ -0,0 +1,313 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from functools import partial
+from typing import Callable, Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.convolutions import create_conv_2plus1d
+from pytorchvideo.models.head import create_res_basic_head
+from pytorchvideo.models.net import Net
+from pytorchvideo.models.resnet import create_bottleneck_block, create_res_stage
+from pytorchvideo.models.stem import create_res_basic_stem
+
+
+def create_2plus1d_bottleneck_block(
+ *,
+ # Convolution configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ conv_a_kernel_size: Tuple[int] = (1, 1, 1),
+ conv_a_stride: Tuple[int] = (1, 1, 1),
+ conv_a_padding: Tuple[int] = (0, 0, 0),
+ conv_a: Callable = nn.Conv3d,
+ conv_b_kernel_size: Tuple[int] = (3, 3, 3),
+ conv_b_stride: Tuple[int] = (2, 2, 2),
+ conv_b_padding: Tuple[int] = (1, 1, 1),
+ conv_b_num_groups: int = 1,
+ conv_b_dilation: Tuple[int] = (1, 1, 1),
+ conv_b: Callable = create_conv_2plus1d,
+ conv_c: Callable = nn.Conv3d,
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ 2plus1d bottleneck block: a sequence of spatiotemporal Convolution, Normalization,
+ and Activations repeated in the following order:
+
+ ::
+
+ Conv3d (conv_a)
+ ↓
+ Normalization (norm_a)
+ ↓
+ Activation (act_a)
+ ↓
+ Conv(2+1)d (conv_b)
+ ↓
+ Normalization (norm_b)
+ ↓
+ Activation (act_b)
+ ↓
+ Conv3d (conv_c)
+ ↓
+ Normalization (norm_c)
+
+ Normalization examples include: BatchNorm3d and None (no normalization).
+ Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
+
+ Args:
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ conv_a_stride (tuple): convolutional stride size(s) for conv_a.
+ conv_a_padding (tuple): convolutional padding(s) for conv_a.
+ conv_a (callable): a callable that constructs the conv_a conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_b_stride (tuple): convolutional stride size(s) for conv_b.
+ conv_b_padding (tuple): convolutional padding(s) for conv_b.
+ conv_b_num_groups (int): number of groups for groupwise convolution for
+ conv_b.
+ conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ conv_b (callable): a callable that constructs the conv_b conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_c (callable): a callable that constructs the conv_c conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): 2plus1d bottleneck block.
+ """
+ return create_bottleneck_block(
+ dim_in=dim_in,
+ dim_inner=dim_inner,
+ dim_out=dim_out,
+ conv_a_kernel_size=conv_a_kernel_size,
+ conv_a_stride=conv_a_stride,
+ conv_a_padding=conv_a_padding,
+ conv_a=conv_a,
+ conv_b_kernel_size=conv_b_kernel_size,
+ conv_b_stride=conv_b_stride,
+ conv_b_padding=conv_b_padding,
+ conv_b_num_groups=conv_b_num_groups,
+ conv_b_dilation=conv_b_dilation,
+ conv_b=partial(
+ create_conv_2plus1d,
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ activation=activation,
+ ),
+ conv_c=conv_c,
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ activation=activation,
+ )
+
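+# Illustrative usage sketch (not part of the upstream module): the block maps
+# dim_in channels to dim_out channels, and the default conv_b stride of
+# (2, 2, 2) halves every spatiotemporal dimension. Shapes are assumptions for
+# demonstration only.
+#
+#   block = create_2plus1d_bottleneck_block(dim_in=64, dim_inner=16, dim_out=64)
+#   out = block(torch.randn(1, 64, 8, 16, 16))  # -> roughly (1, 64, 4, 8, 8)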
+
+def create_r2plus1d(
+ *,
+ # Input clip configs.
+ input_channel: int = 3,
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 400,
+ dropout_rate: float = 0.0,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_dim_out: int = 64,
+ stem_conv_kernel_size: Tuple[int] = (1, 7, 7),
+ stem_conv_stride: Tuple[int] = (1, 2, 2),
+ # Stage configs.
+ stage_conv_a_kernel_size: Tuple[Tuple[int]] = (
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ ),
+ stage_conv_b_kernel_size: Tuple[Tuple[int]] = (
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ ),
+ stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1),
+ stage_conv_b_dilation: Tuple[Tuple[int]] = (
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ ),
+ stage_spatial_stride: Tuple[int] = (2, 2, 2, 2),
+ stage_temporal_stride: Tuple[int] = (1, 1, 2, 2),
+ stage_bottleneck: Tuple[Callable] = (
+ create_2plus1d_bottleneck_block,
+ create_2plus1d_bottleneck_block,
+ create_2plus1d_bottleneck_block,
+ create_2plus1d_bottleneck_block,
+ ),
+ # Head configs.
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_size: Tuple[int] = (4, 7, 7),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = nn.Softmax,
+ head_output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Build the R(2+1)D network from::
+ A closer look at spatiotemporal convolutions for action recognition.
+ Du Tran, Heng Wang, Lorenzo Torresani, Jamie Ray, Yann LeCun, Manohar Paluri. CVPR 2018.
+
+ R(2+1)D follows the ResNet style architecture including three parts: Stem,
+ Stages and Head. The three parts are assembled in the following order:
+
+ ::
+
+ Input
+ ↓
+ Stem
+ ↓
+ Stage 1
+ ↓
+ .
+ .
+ .
+ ↓
+ Stage N
+ ↓
+ Head
+
+ Args:
+
+ input_channel (int): number of channels for the input video clip.
+
+ model_depth (int): the depth of the resnet.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+ norm (callable): a callable that constructs normalization layer.
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_dim_out (int): output channel size for stem.
+ stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
+ stem_conv_stride (tuple): convolutional stride size(s) of stem.
+
+ stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
+ for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
+ stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ stage_spatial_stride (tuple): the spatial stride for each stage.
+ stage_temporal_stride (tuple): the temporal stride for each stage.
+ stage_bottleneck (tuple): a tuple of callables, one per stage, that construct
+ the bottleneck block layer for that stage. Examples include:
+ create_bottleneck_block, create_2plus1d_bottleneck_block.
+
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_pool_kernel_size (tuple): the pooling kernel size.
+ head_output_size (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+
+ Returns:
+ (nn.Module): the R(2+1)D network.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_r2plus1d")
+
+ # Number of blocks for different stages given the model depth.
+ _MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
+
+ # Given a model depth, get the number of blocks for each stage.
+ assert (
+ model_depth in _MODEL_STAGE_DEPTH.keys()
+ ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
+ stage_depths = _MODEL_STAGE_DEPTH[model_depth]
+
+ blocks = []
+ # Create stem for R(2+1)D.
+ stem = create_res_basic_stem(
+ in_channels=input_channel,
+ out_channels=stem_dim_out,
+ conv_kernel_size=stem_conv_kernel_size,
+ conv_stride=stem_conv_stride,
+ conv_padding=[size // 2 for size in stem_conv_kernel_size],
+ pool=None,
+ norm=norm,
+ activation=activation,
+ )
+ blocks.append(stem)
+
+ stage_dim_in = stem_dim_out
+ stage_dim_out = stage_dim_in * 4
+
+ # Create each stage for R(2+1)D.
+ for idx in range(len(stage_depths)):
+ stage_dim_inner = stage_dim_out // 4
+ depth = stage_depths[idx]
+
+ stage_conv_b_stride = (
+ stage_temporal_stride[idx],
+ stage_spatial_stride[idx],
+ stage_spatial_stride[idx],
+ )
+
+ stage = create_res_stage(
+ depth=depth,
+ dim_in=stage_dim_in,
+ dim_inner=stage_dim_inner,
+ dim_out=stage_dim_out,
+ bottleneck=stage_bottleneck[idx],
+ conv_a_kernel_size=stage_conv_a_kernel_size[idx],
+ conv_a_stride=[1, 1, 1],
+ conv_a_padding=[size // 2 for size in stage_conv_a_kernel_size[idx]],
+ conv_b_kernel_size=stage_conv_b_kernel_size[idx],
+ conv_b_stride=stage_conv_b_stride,
+ conv_b_padding=[size // 2 for size in stage_conv_b_kernel_size[idx]],
+ conv_b_num_groups=stage_conv_b_num_groups[idx],
+ conv_b_dilation=stage_conv_b_dilation[idx],
+ norm=norm,
+ activation=activation,
+ )
+
+ blocks.append(stage)
+ stage_dim_in = stage_dim_out
+ stage_dim_out = stage_dim_out * 2
+
+ # Create head for R(2+1)D.
+ head = create_res_basic_head(
+ in_features=stage_dim_in,
+ out_features=model_num_class,
+ pool=head_pool,
+ output_size=head_output_size,
+ pool_kernel_size=head_pool_kernel_size,
+ dropout_rate=dropout_rate,
+ activation=head_activation,
+ output_with_global_average=head_output_with_global_average,
+ )
+ blocks.append(head)
+ return Net(blocks=nn.ModuleList(blocks))
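+
+
+# Illustrative usage sketch (not part of the upstream module): with the default
+# configuration the strides downsample space by 32x and time by 4x, so a
+# 16 x 224 x 224 clip matches the default head pooling kernel of (4, 7, 7).
+# The shapes are assumptions for demonstration only.
+#
+#   model = create_r2plus1d(model_num_class=400)
+#   clip = torch.randn(1, 3, 16, 224, 224)  # (batch, channel, time, height, width)
+#   logits = model(clip)                     # -> shape (1, 400)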
diff --git a/code/pytorchvideo/pytorchvideo/models/resnet.py b/code/pytorchvideo/pytorchvideo/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1d6ca56188b3d6dc921d07118e4a8f2c55ad673
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/resnet.py
@@ -0,0 +1,1394 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.utils import set_attributes
+from pytorchvideo.models.head import create_res_basic_head, create_res_roi_pooling_head
+from pytorchvideo.models.net import DetectionBBoxNetwork, Net
+from pytorchvideo.models.stem import (
+ create_acoustic_res_basic_stem,
+ create_res_basic_stem,
+)
+
+
+def create_bottleneck_block(
+ *,
+ # Convolution configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ conv_a_kernel_size: Tuple[int] = (3, 1, 1),
+ conv_a_stride: Tuple[int] = (2, 1, 1),
+ conv_a_padding: Tuple[int] = (1, 0, 0),
+ conv_a: Callable = nn.Conv3d,
+ conv_b_kernel_size: Tuple[int] = (1, 3, 3),
+ conv_b_stride: Tuple[int] = (1, 2, 2),
+ conv_b_padding: Tuple[int] = (0, 1, 1),
+ conv_b_num_groups: int = 1,
+ conv_b_dilation: Tuple[int] = (1, 1, 1),
+ conv_b: Callable = nn.Conv3d,
+ conv_c: Callable = nn.Conv3d,
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Bottleneck block: a sequence of spatiotemporal Convolution, Normalization,
+ and Activations repeated in the following order:
+
+ ::
+
+ Conv3d (conv_a)
+ ↓
+ Normalization (norm_a)
+ ↓
+ Activation (act_a)
+ ↓
+ Conv3d (conv_b)
+ ↓
+ Normalization (norm_b)
+ ↓
+ Activation (act_b)
+ ↓
+ Conv3d (conv_c)
+ ↓
+ Normalization (norm_c)
+
+ Normalization examples include: BatchNorm3d and None (no normalization).
+ Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
+
+ Args:
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ conv_a_stride (tuple): convolutional stride size(s) for conv_a.
+ conv_a_padding (tuple): convolutional padding(s) for conv_a.
+ conv_a (callable): a callable that constructs the conv_a conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_b_stride (tuple): convolutional stride size(s) for conv_b.
+ conv_b_padding (tuple): convolutional padding(s) for conv_b.
+ conv_b_num_groups (int): number of groups for groupwise convolution for
+ conv_b.
+ conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ conv_b (callable): a callable that constructs the conv_b conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_c (callable): a callable that constructs the conv_c conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): resnet bottleneck block.
+ """
+ conv_a = conv_a(
+ in_channels=dim_in,
+ out_channels=dim_inner,
+ kernel_size=conv_a_kernel_size,
+ stride=conv_a_stride,
+ padding=conv_a_padding,
+ bias=False,
+ )
+ norm_a = (
+ None
+ if norm is None
+ else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ )
+ act_a = None if activation is None else activation()
+
+ conv_b = conv_b(
+ in_channels=dim_inner,
+ out_channels=dim_inner,
+ kernel_size=conv_b_kernel_size,
+ stride=conv_b_stride,
+ padding=conv_b_padding,
+ bias=False,
+ groups=conv_b_num_groups,
+ dilation=conv_b_dilation,
+ )
+ norm_b = (
+ None
+ if norm is None
+ else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ )
+ act_b = None if activation is None else activation()
+
+ conv_c = conv_c(
+ in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
+ )
+ norm_c = (
+ None
+ if norm is None
+ else norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)
+ )
+
+ return BottleneckBlock(
+ conv_a=conv_a,
+ norm_a=norm_a,
+ act_a=act_a,
+ conv_b=conv_b,
+ norm_b=norm_b,
+ act_b=act_b,
+ conv_c=conv_c,
+ norm_c=norm_c,
+ )
+
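+# Illustrative usage sketch (not part of the upstream module): a standard 3D
+# bottleneck that maps dim_in to dim_out channels; with the default conv_a
+# stride of (2, 1, 1) and conv_b stride of (1, 2, 2), the temporal and spatial
+# dimensions are each halved. Shapes are assumptions for demonstration only.
+#
+#   block = create_bottleneck_block(dim_in=64, dim_inner=16, dim_out=256)
+#   out = block(torch.randn(1, 64, 8, 16, 16))  # -> shape (1, 256, 4, 8, 8)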
+
+def create_acoustic_bottleneck_block(
+ *,
+ # Convolution configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ conv_a_kernel_size: Tuple[int] = (3, 1, 1),
+ conv_a_stride: Tuple[int] = (2, 1, 1),
+ conv_a_padding: Tuple[int] = (1, 0, 0),
+ conv_a: Callable = nn.Conv3d,
+ # Conv b f configs.
+ conv_b_kernel_size: Tuple[int] = (1, 1, 1),
+ conv_b_stride: Tuple[int] = (1, 1, 1),
+ conv_b_padding: Tuple[int] = (0, 0, 0),
+ conv_b_num_groups: int = 1,
+ conv_b_dilation: Tuple[int] = (1, 1, 1),
+ conv_b: Callable = nn.Conv3d,
+ conv_c: Callable = nn.Conv3d,
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Acoustic Bottleneck block: a sequence of spatiotemporal Convolution, Normalization,
+ and Activations repeated in the following order:
+
+ ::
+
+ Conv3d (conv_a)
+ ↓
+ Normalization (norm_a)
+ ↓
+ Activation (act_a)
+ ↓
+ ---------------------------------
+ ↓ ↓
+ Temporal Conv3d (conv_b) Spatial Conv3d (conv_b)
+ ↓ ↓
+ Normalization (norm_b) Normalization (norm_b)
+ ↓ ↓
+ Activation (act_b) Activation (act_b)
+ ↓ ↓
+ ---------------------------------
+ ↓
+ Conv3d (conv_c)
+ ↓
+ Normalization (norm_c)
+
+ Normalization examples include: BatchNorm3d and None (no normalization).
+ Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
+
+ Args:
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ conv_a_stride (tuple): convolutional stride size(s) for conv_a.
+ conv_a_padding (tuple): convolutional padding(s) for conv_a.
+ conv_a (callable): a callable that constructs the conv_a conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_b_stride (tuple): convolutional stride size(s) for conv_b.
+ conv_b_padding (tuple): convolutional padding(s) for conv_b.
+ conv_b_num_groups (int): number of groups for groupwise convolution for
+ conv_b.
+ conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ conv_b (callable): a callable that constructs the conv_b conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_c (callable): a callable that constructs the conv_c conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): resnet acoustic bottleneck block.
+ """
+ conv_a = conv_a(
+ in_channels=dim_in,
+ out_channels=dim_inner,
+ kernel_size=conv_a_kernel_size,
+ stride=conv_a_stride,
+ padding=conv_a_padding,
+ bias=False,
+ )
+ norm_a = (
+ None
+ if norm is None
+ else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ )
+ act_a = None if activation is None else activation()
+
+ conv_b_1_kernel_size = [conv_b_kernel_size[0], 1, 1]
+ conv_b_1_stride = conv_b_stride
+ conv_b_1_padding = [conv_b_padding[0], 0, 0]
+
+ conv_b_2_kernel_size = [1, conv_b_kernel_size[1], conv_b_kernel_size[2]]
+ conv_b_2_stride = conv_b_stride
+ conv_b_2_padding = [0, conv_b_padding[1], conv_b_padding[2]]
+
+ conv_b_1_num_groups, conv_b_2_num_groups = (conv_b_num_groups,) * 2
+ conv_b_1_dilation = [conv_b_dilation[0], 1, 1]
+ conv_b_2_dilation = [1, conv_b_dilation[1], conv_b_dilation[2]]
+
+ conv_b_1 = conv_b(
+ in_channels=dim_inner,
+ out_channels=dim_inner,
+ kernel_size=conv_b_1_kernel_size,
+ stride=conv_b_1_stride,
+ padding=conv_b_1_padding,
+ bias=False,
+ groups=conv_b_1_num_groups,
+ dilation=conv_b_1_dilation,
+ )
+ norm_b_1 = (
+ None
+ if norm is None
+ else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ )
+ act_b_1 = None if activation is None else activation()
+
+ conv_b_2 = conv_b(
+ in_channels=dim_inner,
+ out_channels=dim_inner,
+ kernel_size=conv_b_2_kernel_size,
+ stride=conv_b_2_stride,
+ padding=conv_b_2_padding,
+ bias=False,
+ groups=conv_b_2_num_groups,
+ dilation=conv_b_2_dilation,
+ )
+ norm_b_2 = (
+ None
+ if norm is None
+ else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ )
+ act_b_2 = None if activation is None else activation()
+
+ conv_c = conv_c(
+ in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
+ )
+ norm_c = (
+ None
+ if norm is None
+ else norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)
+ )
+
+ return SeparableBottleneckBlock(
+ conv_a=conv_a,
+ norm_a=norm_a,
+ act_a=act_a,
+ conv_b=nn.ModuleList([conv_b_2, conv_b_1]),
+ norm_b=nn.ModuleList([norm_b_2, norm_b_1]),
+ act_b=nn.ModuleList([act_b_2, act_b_1]),
+ conv_c=conv_c,
+ norm_c=norm_c,
+ )
+
+
+def _trivial_sum(x, y):
+ """
+ Utility function used in lieu of a lambda, which is not picklable.
+ """
+ return x + y
+
+
+def create_res_block(
+ *,
+ # Bottleneck Block configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ bottleneck: Callable,
+ use_shortcut: bool = False,
+ branch_fusion: Callable = _trivial_sum,
+ # Conv configs.
+ conv_a_kernel_size: Tuple[int] = (3, 1, 1),
+ conv_a_stride: Tuple[int] = (2, 1, 1),
+ conv_a_padding: Tuple[int] = (1, 0, 0),
+ conv_a: Callable = nn.Conv3d,
+ conv_b_kernel_size: Tuple[int] = (1, 3, 3),
+ conv_b_stride: Tuple[int] = (1, 2, 2),
+ conv_b_padding: Tuple[int] = (0, 1, 1),
+ conv_b_num_groups: int = 1,
+ conv_b_dilation: Tuple[int] = (1, 1, 1),
+ conv_b: Callable = nn.Conv3d,
+ conv_c: Callable = nn.Conv3d,
+ conv_skip: Callable = nn.Conv3d,
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation_bottleneck: Callable = nn.ReLU,
+ activation_block: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Residual block. Performs a summation between an identity shortcut in branch1 and a
+ main block in branch2. When the input and output dimensions are different, a
+ convolution followed by a normalization will be performed.
+
+ ::
+
+
+ Input
+ |-------+
+ ↓ |
+ Block |
+ ↓ |
+ Summation ←-+
+ ↓
+ Activation
+
+ Normalization examples include: BatchNorm3d and None (no normalization).
+ Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
+ Transform examples include: BottleneckBlock.
+
+ Args:
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ bottleneck (callable): a callable that constructs bottleneck block layer.
+ Examples include: create_bottleneck_block.
+ use_shortcut (bool): If true, use conv and norm layers in skip connection.
+ branch_fusion (callable): a callable that constructs summation layer.
+ Examples include: lambda x, y: x + y, OctaveSum.
+
+ conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ conv_a_stride (tuple): convolutional stride size(s) for conv_a.
+ conv_a_padding (tuple): convolutional padding(s) for conv_a.
+ conv_a (callable): a callable that constructs the conv_a conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_b_stride (tuple): convolutional stride size(s) for conv_b.
+ conv_b_padding (tuple): convolutional padding(s) for conv_b.
+ conv_b_num_groups (int): number of groups for groupwise convolution for
+ conv_b.
+ conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ conv_b (callable): a callable that constructs the conv_b conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_c (callable): a callable that constructs the conv_c conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_skip (callable): a callable that constructs the conv_skip conv layer,
+ examples include nn.Conv3d, OctaveConv, etc
+
+ norm (callable): a callable that constructs normalization layer. Examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation_bottleneck (callable): a callable that constructs activation layer in
+ bottleneck. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None
+ (not performing activation).
+ activation_block (callable): a callable that constructs activation layer used
+ at the end of the block. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid,
+ and None (not performing activation).
+
+ Returns:
+ (nn.Module): resnet basic block layer.
+ """
+ branch1_conv_stride = tuple([x * y for x, y in zip(conv_a_stride, conv_b_stride)])
+ norm_model = None
+ if use_shortcut or (
+ norm is not None and (dim_in != dim_out or np.prod(branch1_conv_stride) != 1)
+ ):
+ norm_model = norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)
+
+ return ResBlock(
+ branch1_conv=conv_skip(
+ dim_in,
+ dim_out,
+ kernel_size=(1, 1, 1),
+ stride=branch1_conv_stride,
+ bias=False,
+ )
+ if (dim_in != dim_out or np.prod(branch1_conv_stride) != 1) or use_shortcut
+ else None,
+ branch1_norm=norm_model,
+ branch2=bottleneck(
+ dim_in=dim_in,
+ dim_inner=dim_inner,
+ dim_out=dim_out,
+ conv_a_kernel_size=conv_a_kernel_size,
+ conv_a_stride=conv_a_stride,
+ conv_a_padding=conv_a_padding,
+ conv_a=conv_a,
+ conv_b_kernel_size=conv_b_kernel_size,
+ conv_b_stride=conv_b_stride,
+ conv_b_padding=conv_b_padding,
+ conv_b_num_groups=conv_b_num_groups,
+ conv_b_dilation=conv_b_dilation,
+ conv_b=conv_b,
+ conv_c=conv_c,
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ activation=activation_bottleneck,
+ ),
+ activation=None if activation_block is None else activation_block(),
+ branch_fusion=branch_fusion,
+ )
+
+
+def create_res_stage(
+ *,
+ # Stage configs.
+ depth: int,
+ # Bottleneck Block configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ bottleneck: Callable,
+ # Conv configs.
+ conv_a_kernel_size: Union[Tuple[int], List[Tuple[int]]] = (3, 1, 1),
+ conv_a_stride: Tuple[int] = (2, 1, 1),
+ conv_a_padding: Union[Tuple[int], List[Tuple[int]]] = (1, 0, 0),
+ conv_a: Callable = nn.Conv3d,
+ conv_b_kernel_size: Tuple[int] = (1, 3, 3),
+ conv_b_stride: Tuple[int] = (1, 2, 2),
+ conv_b_padding: Tuple[int] = (0, 1, 1),
+ conv_b_num_groups: int = 1,
+ conv_b_dilation: Tuple[int] = (1, 1, 1),
+ conv_b: Callable = nn.Conv3d,
+ conv_c: Callable = nn.Conv3d,
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Create Residual Stage, which composes sequential blocks that make up a ResNet. These
+ blocks could be, for example, Residual blocks, Non-Local layers, or
+ Squeeze-Excitation layers.
+
+ ::
+
+
+ Input
+ ↓
+ ResBlock
+ ↓
+ .
+ .
+ .
+ ↓
+ ResBlock
+
+ Normalization examples include: BatchNorm3d and None (no normalization).
+ Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
+ Bottleneck examples include: create_bottleneck_block.
+
+ Args:
+ depth (int): number of blocks to create.
+
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ bottleneck (callable): a callable that constructs bottleneck block layer.
+ Examples include: create_bottleneck_block.
+
+ conv_a_kernel_size (tuple or list of tuple): convolutional kernel size(s)
+ for conv_a. If conv_a_kernel_size is a tuple, use it for all blocks in
+ the stage. If conv_a_kernel_size is a list of tuples, the kernel sizes
+ are repeated until the list matches the depth of the stage. For
+ example, for conv_a_kernel_size = [(3, 1, 1), (1, 1, 1)], the kernel
+ sizes for the first 6 blocks would be [(3, 1, 1), (1, 1, 1), (3, 1, 1),
+ (1, 1, 1), (3, 1, 1), (1, 1, 1)].
+ conv_a_stride (tuple): convolutional stride size(s) for conv_a.
+ conv_a_padding (tuple or list of tuple): convolutional padding(s) for
+ conv_a. If conv_a_padding is a tuple, use it for all blocks in
+ the stage. If conv_a_padding is a list of tuples, the padding sizes
+ are repeated until the list matches the depth of the stage.
+ conv_a (callable): a callable that constructs the conv_a conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_b_stride (tuple): convolutional stride size(s) for conv_b.
+ conv_b_padding (tuple): convolutional padding(s) for conv_b.
+ conv_b_num_groups (int): number of groups for groupwise convolution for
+ conv_b.
+ conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ conv_b (callable): a callable that constructs the conv_b conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+ conv_c (callable): a callable that constructs the conv_c conv layer, examples
+ include nn.Conv3d, OctaveConv, etc
+
+ norm (callable): a callable that constructs normalization layer. Examples
+ include nn.BatchNorm3d, and None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer. Examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): resnet basic stage layer.
+ """
+ res_blocks = []
+ if isinstance(conv_a_kernel_size[0], int):
+ conv_a_kernel_size = [conv_a_kernel_size]
+ if isinstance(conv_a_padding[0], int):
+ conv_a_padding = [conv_a_padding]
+ # Repeat conv_a kernels until having same length of depth in the stage.
+ conv_a_kernel_size = (conv_a_kernel_size * depth)[:depth]
+ conv_a_padding = (conv_a_padding * depth)[:depth]
+
+ for ind in range(depth):
+ block = create_res_block(
+ dim_in=dim_in if ind == 0 else dim_out,
+ dim_inner=dim_inner,
+ dim_out=dim_out,
+ bottleneck=bottleneck,
+ conv_a_kernel_size=conv_a_kernel_size[ind],
+ conv_a_stride=conv_a_stride if ind == 0 else (1, 1, 1),
+ conv_a_padding=conv_a_padding[ind],
+ conv_a=conv_a,
+ conv_b_kernel_size=conv_b_kernel_size,
+ conv_b_stride=conv_b_stride if ind == 0 else (1, 1, 1),
+ conv_b_padding=conv_b_padding,
+ conv_b_num_groups=conv_b_num_groups,
+ conv_b_dilation=conv_b_dilation,
+ conv_b=conv_b,
+ conv_c=conv_c,
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ activation_bottleneck=activation,
+ activation_block=activation,
+ )
+ res_blocks.append(block)
+ return ResStage(res_blocks=nn.ModuleList(res_blocks))
+
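+# Illustrative usage sketch (not part of the upstream module): a short stage of
+# bottleneck blocks where only the first block applies the conv strides, so the
+# output resolution is set by the first block. Shapes are assumptions for
+# demonstration only.
+#
+#   stage = create_res_stage(
+#       depth=3,
+#       dim_in=64,
+#       dim_inner=16,
+#       dim_out=256,
+#       bottleneck=create_bottleneck_block,
+#   )
+#   out = stage(torch.randn(1, 64, 8, 16, 16))  # -> shape (1, 256, 4, 8, 8)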
+
+# Number of blocks for different stages given the model depth.
+_MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
+
+
+def create_resnet(
+ *,
+ # Input clip configs.
+ input_channel: int = 3,
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 400,
+ dropout_rate: float = 0.5,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_dim_out: int = 64,
+ stem_conv_kernel_size: Tuple[int] = (3, 7, 7),
+ stem_conv_stride: Tuple[int] = (1, 2, 2),
+ stem_pool: Callable = nn.MaxPool3d,
+ stem_pool_kernel_size: Tuple[int] = (1, 3, 3),
+ stem_pool_stride: Tuple[int] = (1, 2, 2),
+ stem: Callable = create_res_basic_stem,
+ # Stage configs.
+ stage1_pool: Callable = None,
+ stage1_pool_kernel_size: Tuple[int] = (2, 1, 1),
+ stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
+ (1, 1, 1),
+ (1, 1, 1),
+ (3, 1, 1),
+ (3, 1, 1),
+ ),
+ stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
+ (1, 3, 3),
+ (1, 3, 3),
+ (1, 3, 3),
+ (1, 3, 3),
+ ),
+ stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1),
+ stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = (
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ ),
+ stage_spatial_h_stride: Tuple[int] = (1, 2, 2, 2),
+ stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 2),
+ stage_temporal_stride: Tuple[int] = (1, 1, 1, 1),
+ bottleneck: Union[Tuple[Callable], Callable] = create_bottleneck_block,
+ # Head configs.
+ head: Callable = create_res_basic_head,
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_size: Tuple[int] = (4, 7, 7),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = None,
+ head_output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Build ResNet style models for video recognition. ResNet has three parts:
+ Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an
+ optional pooling layer. Stages are grouped residual blocks. There are usually
+ multiple stages and each stage may include multiple residual blocks. Head
+ may include pooling, dropout, a fully-connected layer and global spatial
+ temporal averaging. The three parts are assembled in the following order:
+
+ ::
+
+ Input
+ ↓
+ Stem
+ ↓
+ Stage 1
+ ↓
+ .
+ .
+ .
+ ↓
+ Stage N
+ ↓
+ Head
+
+ Args:
+
+ input_channel (int): number of channels for the input video clip.
+
+ model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+
+ norm (callable): a callable that constructs normalization layer.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_dim_out (int): output channel size to stem.
+ stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
+ stem_conv_stride (tuple): convolutional stride size(s) of stem.
+ stem_pool (callable): a callable that constructs resnet head pooling layer.
+ stem_pool_kernel_size (tuple): pooling kernel size(s).
+ stem_pool_stride (tuple): pooling stride size(s).
+ stem (callable): a callable that constructs stem layer.
+ Examples include: create_res_video_stem.
+
+ stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
+ for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
+ stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ stage_spatial_h_stride (tuple): the spatial height stride for each stage.
+ stage_spatial_w_stride (tuple): the spatial width stride for each stage.
+ stage_temporal_stride (tuple): the temporal stride for each stage.
+ bottleneck (callable): a callable that constructs bottleneck block layer.
+ Examples include: create_bottleneck_block.
+
+ head (callable): a callable that constructs the resnet-style head.
+ Ex: create_res_basic_head
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_pool_kernel_size (tuple): the pooling kernel size.
+ head_output_size (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+
+ Returns:
+ (nn.Module): basic resnet.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_resnet")
+
+ # Given a model depth, get the number of blocks for each stage.
+ assert (
+ model_depth in _MODEL_STAGE_DEPTH.keys()
+ ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
+ stage_depths = _MODEL_STAGE_DEPTH[model_depth]
+
+ # Broadcast single element to tuple if given.
+ if isinstance(stage_conv_a_kernel_size[0], int):
+ stage_conv_a_kernel_size = (stage_conv_a_kernel_size,) * len(stage_depths)
+
+ if isinstance(stage_conv_b_kernel_size[0], int):
+ stage_conv_b_kernel_size = (stage_conv_b_kernel_size,) * len(stage_depths)
+
+ if isinstance(stage_conv_b_dilation[0], int):
+ stage_conv_b_dilation = (stage_conv_b_dilation,) * len(stage_depths)
+
+ if isinstance(bottleneck, Callable):
+ bottleneck = [
+ bottleneck,
+ ] * len(stage_depths)
+
+ blocks = []
+ # Create stem for resnet.
+ stem = stem(
+ in_channels=input_channel,
+ out_channels=stem_dim_out,
+ conv_kernel_size=stem_conv_kernel_size,
+ conv_stride=stem_conv_stride,
+ conv_padding=[size // 2 for size in stem_conv_kernel_size],
+ pool=stem_pool,
+ pool_kernel_size=stem_pool_kernel_size,
+ pool_stride=stem_pool_stride,
+ pool_padding=[size // 2 for size in stem_pool_kernel_size],
+ norm=norm,
+ activation=activation,
+ )
+ blocks.append(stem)
+
+ stage_dim_in = stem_dim_out
+ stage_dim_out = stage_dim_in * 4
+
+ # Create each stage for resnet.
+ for idx in range(len(stage_depths)):
+ stage_dim_inner = stage_dim_out // 4
+ depth = stage_depths[idx]
+
+ stage_conv_a_kernel = stage_conv_a_kernel_size[idx]
+ stage_conv_a_stride = (stage_temporal_stride[idx], 1, 1)
+ stage_conv_a_padding = (
+ [size // 2 for size in stage_conv_a_kernel]
+ if isinstance(stage_conv_a_kernel[0], int)
+ else [[size // 2 for size in sizes] for sizes in stage_conv_a_kernel]
+ )
+
+ stage_conv_b_stride = (
+ 1,
+ stage_spatial_h_stride[idx],
+ stage_spatial_w_stride[idx],
+ )
+
+ stage = create_res_stage(
+ depth=depth,
+ dim_in=stage_dim_in,
+ dim_inner=stage_dim_inner,
+ dim_out=stage_dim_out,
+ bottleneck=bottleneck[idx],
+ conv_a_kernel_size=stage_conv_a_kernel,
+ conv_a_stride=stage_conv_a_stride,
+ conv_a_padding=stage_conv_a_padding,
+ conv_b_kernel_size=stage_conv_b_kernel_size[idx],
+ conv_b_stride=stage_conv_b_stride,
+ conv_b_padding=(
+ stage_conv_b_kernel_size[idx][0] // 2,
+ stage_conv_b_dilation[idx][1]
+ if stage_conv_b_dilation[idx][1] > 1
+ else stage_conv_b_kernel_size[idx][1] // 2,
+ stage_conv_b_dilation[idx][2]
+ if stage_conv_b_dilation[idx][2] > 1
+ else stage_conv_b_kernel_size[idx][2] // 2,
+ ),
+ conv_b_num_groups=stage_conv_b_num_groups[idx],
+ conv_b_dilation=stage_conv_b_dilation[idx],
+ norm=norm,
+ activation=activation,
+ )
+
+ blocks.append(stage)
+ stage_dim_in = stage_dim_out
+ stage_dim_out = stage_dim_out * 2
+
+ if idx == 0 and stage1_pool is not None:
+ blocks.append(
+ stage1_pool(
+ kernel_size=stage1_pool_kernel_size,
+ stride=stage1_pool_kernel_size,
+ padding=(0, 0, 0),
+ )
+ )
+ if head is not None:
+ head = head(
+ in_features=stage_dim_in,
+ out_features=model_num_class,
+ pool=head_pool,
+ output_size=head_output_size,
+ pool_kernel_size=head_pool_kernel_size,
+ dropout_rate=dropout_rate,
+ activation=head_activation,
+ output_with_global_average=head_output_with_global_average,
+ )
+ blocks.append(head)
+ return Net(blocks=nn.ModuleList(blocks))
+
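+# Illustrative usage sketch (not part of the upstream module): with the default
+# configuration the stem and stages downsample space by 32x and leave time
+# untouched, so a 4 x 224 x 224 clip matches the default head pooling kernel of
+# (4, 7, 7). The shapes are assumptions for demonstration only.
+#
+#   model = create_resnet(model_depth=50, model_num_class=400)
+#   clip = torch.randn(1, 3, 4, 224, 224)  # (batch, channel, time, height, width)
+#   logits = model(clip)                    # -> shape (1, 400)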
+
+def create_resnet_with_roi_head(
+ *,
+ # Input clip configs.
+ input_channel: int = 3,
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 80,
+ dropout_rate: float = 0.5,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_dim_out: int = 64,
+ stem_conv_kernel_size: Tuple[int] = (1, 7, 7),
+ stem_conv_stride: Tuple[int] = (1, 2, 2),
+ stem_pool: Callable = nn.MaxPool3d,
+ stem_pool_kernel_size: Tuple[int] = (1, 3, 3),
+ stem_pool_stride: Tuple[int] = (1, 2, 2),
+ stem: Callable = create_res_basic_stem,
+ # Stage configs.
+ stage1_pool: Callable = None,
+ stage1_pool_kernel_size: Tuple[int] = (2, 1, 1),
+ stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
+ (1, 1, 1),
+ (1, 1, 1),
+ (3, 1, 1),
+ (3, 1, 1),
+ ),
+ stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
+ (1, 3, 3),
+ (1, 3, 3),
+ (1, 3, 3),
+ (1, 3, 3),
+ ),
+ stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1),
+ stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = (
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 1, 1),
+ (1, 2, 2),
+ ),
+ stage_spatial_h_stride: Tuple[int] = (1, 2, 2, 1),
+ stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 1),
+ stage_temporal_stride: Tuple[int] = (1, 1, 1, 1),
+ bottleneck: Union[Tuple[Callable], Callable] = create_bottleneck_block,
+ # Head configs.
+ head: Callable = create_res_roi_pooling_head,
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_size: Tuple[int] = (4, 1, 1),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = nn.Sigmoid,
+ head_output_with_global_average: bool = False,
+ head_spatial_resolution: Tuple[int] = (7, 7),
+ head_spatial_scale: float = 1.0 / 16.0,
+ head_sampling_ratio: int = 0,
+) -> nn.Module:
+ """
+ Build ResNet style models for video detection. ResNet has three parts:
+ Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an
+ optional pooling layer. Stages are grouped residual blocks. There are usually
+ multiple stages and each stage may include multiple residual blocks. Head
+ may include pooling, dropout, a fully-connected layer and global spatial
+ temporal averaging. The three parts are assembled in the following order:
+
+ ::
+
+ Input Clip Input Bounding Boxes
+ ↓ ↓
+ Stem ↓
+ ↓ ↓
+ Stage 1 ↓
+ ↓ ↓
+ . ↓
+ . ↓
+ . ↓
+ ↓ ↓
+ Stage N ↓
+ ↓--------> Head <-------↓
+
+ Args:
+
+ input_channel (int): number of channels for the input video clip.
+
+ model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+
+ norm (callable): a callable that constructs normalization layer.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_dim_out (int): output channel size to stem.
+ stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
+ stem_conv_stride (tuple): convolutional stride size(s) of stem.
+ stem_pool (callable): a callable that constructs resnet head pooling layer.
+ stem_pool_kernel_size (tuple): pooling kernel size(s).
+ stem_pool_stride (tuple): pooling stride size(s).
+ stem (callable): a callable that constructs stem layer.
+ Examples include: create_res_video_stem.
+
+ stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
+ for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
+ stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ stage_spatial_h_stride (tuple): the spatial height stride for each stage.
+ stage_spatial_w_stride (tuple): the spatial width stride for each stage.
+ stage_temporal_stride (tuple): the temporal stride for each stage.
+ bottleneck (callable): a callable that constructs bottleneck block layer.
+ Examples include: create_bottleneck_block.
+
+ head (callable): a callable that constructs the detection head which can
+ take in the additional input of bounding boxes.
+ Ex: create_res_roi_pooling_head
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_pool_kernel_size (tuple): the pooling kernel size.
+ head_output_size (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+ head_spatial_resolution (tuple): h, w sizes of the RoI interpolation.
+ head_spatial_scale (float): scale the input boxes by this number.
+ head_sampling_ratio (int): number of inputs samples to take for each output
+ sample interpolation. 0 to take samples densely.
+
+ Returns:
+ (nn.Module): basic resnet.
+ """
+
+ model = create_resnet(
+ # Input clip configs.
+ input_channel=input_channel,
+ # Model configs.
+ model_depth=model_depth,
+ model_num_class=model_num_class,
+ dropout_rate=dropout_rate,
+ # Normalization configs.
+ norm=norm,
+ # Activation configs.
+ activation=activation,
+ # Stem configs.
+ stem_dim_out=stem_dim_out,
+ stem_conv_kernel_size=stem_conv_kernel_size,
+ stem_conv_stride=stem_conv_stride,
+ stem_pool=stem_pool,
+ stem_pool_kernel_size=stem_pool_kernel_size,
+ stem_pool_stride=stem_pool_stride,
+ # Stage configs.
+ stage1_pool=stage1_pool,
+ stage_conv_a_kernel_size=stage_conv_a_kernel_size,
+ stage_conv_b_kernel_size=stage_conv_b_kernel_size,
+ stage_conv_b_num_groups=stage_conv_b_num_groups,
+ stage_conv_b_dilation=stage_conv_b_dilation,
+ stage_spatial_h_stride=stage_spatial_h_stride,
+ stage_spatial_w_stride=stage_spatial_w_stride,
+ stage_temporal_stride=stage_temporal_stride,
+ bottleneck=bottleneck,
+ # Head configs.
+ head=None,
+ )
+ head = head(
+ in_features=stem_dim_out * 2 ** (len(_MODEL_STAGE_DEPTH[model_depth]) + 1),
+ out_features=model_num_class,
+ pool=head_pool,
+ output_size=head_output_size,
+ pool_kernel_size=head_pool_kernel_size,
+ dropout_rate=dropout_rate,
+ activation=head_activation,
+ output_with_global_average=head_output_with_global_average,
+ resolution=head_spatial_resolution,
+ spatial_scale=head_spatial_scale,
+ sampling_ratio=head_sampling_ratio,
+ )
+ return DetectionBBoxNetwork(model, head)
+
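+# Illustrative usage sketch (not part of the upstream module): the detection
+# model takes a clip plus an (index, x1, y1, x2, y2) box tensor, where the
+# index selects the clip in the batch. The shapes and coordinates below are
+# assumptions for demonstration only.
+#
+#   model = create_resnet_with_roi_head(model_num_class=80)
+#   clip = torch.randn(1, 3, 4, 224, 224)
+#   bboxes = torch.tensor([[0.0, 32.0, 32.0, 160.0, 160.0]])  # boxes in input pixels
+#   scores = model(clip, bboxes)  # one row of class scores per box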
+
+def create_acoustic_resnet(
+ *,
+ # Input clip configs.
+ input_channel: int = 1,
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 400,
+ dropout_rate: float = 0.5,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_dim_out: int = 64,
+ stem_conv_kernel_size: Tuple[int] = (9, 1, 9),
+ stem_conv_stride: Tuple[int] = (1, 1, 3),
+ stem_pool: Callable = None,
+ stem_pool_kernel_size: Tuple[int] = (3, 1, 3),
+ stem_pool_stride: Tuple[int] = (2, 1, 2),
+ stem: Callable = create_acoustic_res_basic_stem,
+ # Stage configs.
+ stage1_pool: Callable = None,
+ stage1_pool_kernel_size: Tuple[int] = (2, 1, 1),
+ stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (3, 1, 1),
+ stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (3, 1, 3),
+ stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1),
+ stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = (1, 1, 1),
+ stage_spatial_h_stride: Tuple[int] = (1, 1, 1, 1),
+ stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 2),
+ stage_temporal_stride: Tuple[int] = (1, 2, 2, 2),
+ bottleneck: Union[Tuple[Callable], Callable] = (
+ create_acoustic_bottleneck_block,
+ create_acoustic_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ # Head configs.
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_size: Tuple[int] = (4, 1, 2),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = None,
+ head_output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Build ResNet style models for acoustic recognition. ResNet has three parts:
+ Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an
+ optional pooling layer. Stages are grouped residual blocks. There are usually
+ multiple stages and each stage may include multiple residual blocks. Head
+ may include pooling, dropout, a fully-connected layer and global spatial
+ temporal averaging. The three parts are assembled in the following order:
+
+ ::
+
+ Input
+ ↓
+ Stem
+ ↓
+ Stage 1
+ ↓
+ .
+ .
+ .
+ ↓
+ Stage N
+ ↓
+ Head
+
+ Args:
+
+ input_channel (int): number of channels for the input video clip.
+
+ model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+
+ norm (callable): a callable that constructs normalization layer.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_dim_out (int): output channel size to stem.
+ stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
+ stem_conv_stride (tuple): convolutional stride size(s) of stem.
+ stem_pool (callable): a callable that constructs resnet head pooling layer.
+ stem_pool_kernel_size (tuple): pooling kernel size(s).
+ stem_pool_stride (tuple): pooling stride size(s).
+ stem (callable): a callable that constructs stem layer.
+ Examples include: create_res_video_stem.
+
+ stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
+ for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
+ stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
+ stage_spatial_h_stride (tuple): the spatial height stride for each stage.
+ stage_spatial_w_stride (tuple): the spatial width stride for each stage.
+ stage_temporal_stride (tuple): the temporal stride for each stage.
+ bottleneck (callable): a callable that constructs bottleneck block layer.
+ Examples include: create_bottleneck_block.
+
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_pool_kernel_size (tuple): the pooling kernel size.
+ head_output_size (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+
+ Returns:
+ (nn.Module): audio resnet that takes spectrogram input with
+ shape (B, C, T, 1, F), where T is the time dimension and F is the
+ frequency dimension.
+ """
+ return create_resnet(**locals())
+
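+# Illustrative usage sketch (not part of the upstream module): the acoustic
+# variant consumes a spectrogram laid out as (B, C, T, 1, F). The time and
+# frequency sizes below were picked so that the default strides and head
+# pooling kernel fit; they are assumptions for demonstration only.
+#
+#   model = create_acoustic_resnet(model_num_class=400)
+#   spectrogram = torch.randn(1, 1, 128, 1, 128)  # (batch, channel, time, 1, frequency)
+#   logits = model(spectrogram)                    # -> shape (1, 400)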
+
+class ResBlock(nn.Module):
+ """
+ Residual block. Performs a summation between an identity shortcut in branch1 and a
+ main block in branch2. When the input and output dimensions are different, a
+ convolution followed by a normalization will be performed.
+
+ ::
+
+
+ Input
+ |-------+
+ ↓ |
+ Block |
+ ↓ |
+ Summation ←-+
+ ↓
+ Activation
+
+ The builder can be found in `create_res_block`.
+ """
+
+ def __init__(
+ self,
+ branch1_conv: nn.Module = None,
+ branch1_norm: nn.Module = None,
+ branch2: nn.Module = None,
+ activation: nn.Module = None,
+ branch_fusion: Callable = None,
+ ) -> nn.Module:
+ """
+ Args:
+ branch1_conv (torch.nn.modules): convolutional module in branch1.
+ branch1_norm (torch.nn.modules): normalization module in branch1.
+ branch2 (torch.nn.modules): bottleneck block module in branch2.
+ activation (torch.nn.modules): activation module.
+ branch_fusion: (Callable): A callable or layer that combines branch1
+ and branch2.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.branch2 is not None
+
+ def forward(self, x) -> torch.Tensor:
+ if self.branch1_conv is None:
+ x = self.branch_fusion(x, self.branch2(x))
+ else:
+ shortcut = self.branch1_conv(x)
+ if self.branch1_norm is not None:
+ shortcut = self.branch1_norm(shortcut)
+ x = self.branch_fusion(shortcut, self.branch2(x))
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
+
+class SeparableBottleneckBlock(nn.Module):
+ """
+ Separable Bottleneck block: a sequence of spatiotemporal Convolution, Normalization,
+ and Activations repeated in the following order. Requires an nn.ModuleList to be
+ provided for each of conv_b, norm_b, and act_b so that the Convolutions,
+ Normalizations, and Activations are performed separately on parallel branches.
+
+ ::
+
+
+ Conv3d (conv_a)
+ ↓
+ Normalization (norm_a)
+ ↓
+ Activation (act_a)
+ ↓
+ Conv3d(s) (conv_b), ...
+ ↓ (↓)
+ Normalization(s) (norm_b), ...
+ ↓ (↓)
+ Activation(s) (act_b), ...
+ ↓ (↓)
+ Reduce (sum or cat)
+ ↓
+ Conv3d (conv_c)
+ ↓
+ Normalization (norm_c)
+ """
+
+ def __init__(
+ self,
+ *,
+ conv_a: nn.Module,
+ norm_a: nn.Module,
+ act_a: nn.Module,
+ conv_b: nn.ModuleList,
+ norm_b: nn.ModuleList,
+ act_b: nn.ModuleList,
+ conv_c: nn.Module,
+ norm_c: nn.Module,
+ reduce_method: str = "sum",
+ ) -> None:
+ """
+ Args:
+ conv_a (torch.nn.modules): convolutional module.
+ norm_a (torch.nn.modules): normalization module.
+ act_a (torch.nn.modules): activation module.
+ conv_b (torch.nn.modules_list): convolutional module(s).
+ norm_b (torch.nn.modules_list): normalization module(s).
+ act_b (torch.nn.modules_list): activation module(s).
+ conv_c (torch.nn.modules): convolutional module.
+ norm_c (torch.nn.modules): normalization module.
+ reduce_method (str): if multiple conv_b branches are used, reduce their
+ outputs with `sum` or `cat`.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert all(
+ op is not None for op in (self.conv_b, self.conv_c)
+ ), f"{self.conv_a}, {self.conv_b}, {self.conv_c} has None"
+ assert reduce_method in ["sum", "cat"]
+ if self.norm_c is not None:
+ # This flag is used for weight initialization.
+ self.norm_c.block_final_bn = True
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Explicitly forward every layer.
+ # Branch2a, for example Tx1x1, BN, ReLU.
+ if self.conv_a is not None:
+ x = self.conv_a(x)
+ if self.norm_a is not None:
+ x = self.norm_a(x)
+ if self.act_a is not None:
+ x = self.act_a(x)
+
+ # Branch2b, for example 1xHxW, BN, ReLU.
+ output = []
+ for ind in range(len(self.conv_b)):
+ x_ = self.conv_b[ind](x)
+ if self.norm_b[ind] is not None:
+ x_ = self.norm_b[ind](x_)
+ if self.act_b[ind] is not None:
+ x_ = self.act_b[ind](x_)
+ output.append(x_)
+ if self.reduce_method == "sum":
+ x = torch.stack(output, dim=0).sum(dim=0, keepdim=False)
+ elif self.reduce_method == "cat":
+ x = torch.cat(output, dim=1)
+
+ # Branch2c, for example 1x1x1, BN.
+ x = self.conv_c(x)
+ if self.norm_c is not None:
+ x = self.norm_c(x)
+ return x
+
+
+class BottleneckBlock(nn.Module):
+ """
+ Bottleneck block: a sequence of spatiotemporal Convolution, Normalization,
+ and Activations repeated in the following order:
+
+ ::
+
+
+ Conv3d (conv_a)
+ ↓
+ Normalization (norm_a)
+ ↓
+ Activation (act_a)
+ ↓
+ Conv3d (conv_b)
+ ↓
+ Normalization (norm_b)
+ ↓
+ Activation (act_b)
+ ↓
+ Conv3d (conv_c)
+ ↓
+ Normalization (norm_c)
+
+ The builder can be found in `create_bottleneck_block`.
+ """
+
+ def __init__(
+ self,
+ *,
+ conv_a: nn.Module = None,
+ norm_a: nn.Module = None,
+ act_a: nn.Module = None,
+ conv_b: nn.Module = None,
+ norm_b: nn.Module = None,
+ act_b: nn.Module = None,
+ conv_c: nn.Module = None,
+ norm_c: nn.Module = None,
+ ) -> None:
+ """
+ Args:
+ conv_a (torch.nn.modules): convolutional module.
+ norm_a (torch.nn.modules): normalization module.
+ act_a (torch.nn.modules): activation module.
+ conv_b (torch.nn.modules): convolutional module.
+ norm_b (torch.nn.modules): normalization module.
+ act_b (torch.nn.modules): activation module.
+ conv_c (torch.nn.modules): convolutional module.
+ norm_c (torch.nn.modules): normalization module.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert all(op is not None for op in (self.conv_a, self.conv_b, self.conv_c))
+ if self.norm_c is not None:
+ # This flag is used for weight initialization.
+ self.norm_c.block_final_bn = True
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Explicitly forward every layer.
+ # Branch2a, for example Tx1x1, BN, ReLU.
+ x = self.conv_a(x)
+ if self.norm_a is not None:
+ x = self.norm_a(x)
+ if self.act_a is not None:
+ x = self.act_a(x)
+
+ # Branch2b, for example 1xHxW, BN, ReLU.
+ x = self.conv_b(x)
+ if self.norm_b is not None:
+ x = self.norm_b(x)
+ if self.act_b is not None:
+ x = self.act_b(x)
+
+ # Branch2c, for example 1x1x1, BN.
+ x = self.conv_c(x)
+ if self.norm_c is not None:
+ x = self.norm_c(x)
+ return x
+
+
+class ResStage(nn.Module):
+ """
+ ResStage composes sequential blocks that make up a ResNet. These blocks could be,
+ for example, Residual blocks, Non-Local layers, or Squeeze-Excitation layers.
+
+ ::
+
+
+ Input
+ ↓
+ ResBlock
+ ↓
+ .
+ .
+ .
+ ↓
+ ResBlock
+
+ The builder can be found in `create_res_stage`.
+ """
+
+ def __init__(self, res_blocks: nn.ModuleList) -> nn.Module:
+ """
+ Args:
+ res_blocks (torch.nn.module_list): ResBlock module(s).
+ """
+ super().__init__()
+ self.res_blocks = res_blocks
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for _, res_block in enumerate(self.res_blocks):
+ x = res_block(x)
+ return x
diff --git a/code/pytorchvideo/pytorchvideo/models/simclr.py b/code/pytorchvideo/pytorchvideo/models/simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..8619ab01f809457d644138040af300d666efb5f9
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/simclr.py
@@ -0,0 +1,66 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from fvcore.nn.distributed import differentiable_all_gather
+from pytorchvideo.layers.utils import set_attributes
+
+
+class SimCLR(nn.Module):
+ """
+ A Simple Framework for Contrastive Learning of Visual Representations
+ Details can be found at:
+ https://arxiv.org/abs/2002.05709
+ """
+
+ def __init__(
+ self,
+ mlp: nn.Module,
+ backbone: Optional[nn.Module] = None,
+ temperature: float = 0.07,
+ ) -> None:
+ super().__init__()
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.SimCLR.__init__")
+
+ set_attributes(self, locals())
+
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x1 (torch.tensor): a batch of images with one augmentation. The input tensor
+ shape should be able to be fed into the backbone.
+ x2 (torch.tensor): the same batch of images with a different augmentation. The
+ input tensor shape should be able to be fed into the backbone.
+ """
+ if self.backbone is not None:
+ x1 = self.backbone(x1)
+ x1 = self.mlp(x1)
+ x1 = F.normalize(x1, p=2, dim=1)
+
+ if self.backbone is not None:
+ x2 = self.backbone(x2)
+ x2 = self.mlp(x2)
+ x2 = F.normalize(x2, p=2, dim=1)
+ x2 = torch.cat(differentiable_all_gather(x2), dim=0)
+
+ prod = torch.einsum("nc,kc->nk", [x1, x2])
+ prod = prod.div(self.temperature)
+ batch_size = x1.size(0)
+ if dist.is_available() and dist.is_initialized():
+ device_ind = dist.get_rank()
+ else:
+ device_ind = 0
+ gt = (
+ torch.tensor(
+ list(range(device_ind * batch_size, (device_ind + 1) * batch_size))
+ )
+ .long()
+ .to(x1.device)
+ )
+ loss = torch.nn.functional.cross_entropy(prod, gt)
+ return loss
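+
+
+# Illustrative usage sketch (not part of the upstream module): in a
+# single-process (non-distributed) setting the all-gather reduces to the local
+# batch, so the loss can be computed directly on two augmented views. The
+# projection MLP and feature sizes are assumptions for demonstration only.
+#
+#   mlp = nn.Sequential(nn.Linear(2048, 128), nn.ReLU(), nn.Linear(128, 128))
+#   simclr = SimCLR(mlp=mlp, backbone=None, temperature=0.07)
+#   view1 = torch.randn(8, 2048)  # backbone features of augmentation 1
+#   view2 = torch.randn(8, 2048)  # backbone features of augmentation 2
+#   loss = simclr(view1, view2)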
diff --git a/code/pytorchvideo/pytorchvideo/models/slowfast.py b/code/pytorchvideo/pytorchvideo/models/slowfast.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a89bcb10fe1e5e564f077308d25db9f04fd830
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/slowfast.py
@@ -0,0 +1,725 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.utils import set_attributes
+from pytorchvideo.models.head import create_res_basic_head, create_res_roi_pooling_head
+from pytorchvideo.models.net import DetectionBBoxNetwork, MultiPathWayWithFuse, Net
+from pytorchvideo.models.resnet import create_bottleneck_block, create_res_stage
+from pytorchvideo.models.stem import create_res_basic_stem
+
+
+_MODEL_STAGE_DEPTH = {
+ 18: (1, 1, 1, 1),
+ 50: (3, 4, 6, 3),
+ 101: (3, 4, 23, 3),
+ 152: (3, 8, 36, 3),
+}
+
+
+def create_slowfast(
+ *,
+ # SlowFast configs.
+ slowfast_channel_reduction_ratio: Union[Tuple[int], int] = (8,),
+ slowfast_conv_channel_fusion_ratio: int = 2,
+ slowfast_fusion_conv_kernel_size: Tuple[int] = (
+ 7,
+ 1,
+ 1,
+ ), # deprecated, use fusion_builder
+ slowfast_fusion_conv_stride: Tuple[int] = (
+ 4,
+ 1,
+ 1,
+ ), # deprecated, use fusion_builder
+ fusion_builder: Callable[
+ [int, int], nn.Module
+ ] = None, # Args: fusion_dim_in, stage_idx
+ # Input clip configs.
+ input_channels: Tuple[int] = (3, 3),
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 400,
+ dropout_rate: float = 0.5,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_function: Tuple[Callable] = (
+ create_res_basic_stem,
+ create_res_basic_stem,
+ ),
+ stem_dim_outs: Tuple[int] = (64, 8),
+ stem_conv_kernel_sizes: Tuple[Tuple[int]] = ((1, 7, 7), (5, 7, 7)),
+ stem_conv_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2)),
+ stem_pool: Union[Callable, Tuple[Callable]] = (nn.MaxPool3d, nn.MaxPool3d),
+ stem_pool_kernel_sizes: Tuple[Tuple[int]] = ((1, 3, 3), (1, 3, 3)),
+ stem_pool_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2)),
+ # Stage configs.
+ stage_conv_a_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 1, 1), (1, 1, 1), (3, 1, 1), (3, 1, 1)),
+ ((3, 1, 1), (3, 1, 1), (3, 1, 1), (3, 1, 1)),
+ ),
+ stage_conv_b_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
+ ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
+ ),
+ stage_conv_b_num_groups: Tuple[Tuple[int]] = ((1, 1, 1, 1), (1, 1, 1, 1)),
+ stage_conv_b_dilations: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)),
+ ),
+ stage_spatial_strides: Tuple[Tuple[int]] = ((1, 2, 2, 2), (1, 2, 2, 2)),
+ stage_temporal_strides: Tuple[Tuple[int]] = ((1, 1, 1, 1), (1, 1, 1, 1)),
+ bottleneck: Union[Callable, Tuple[Tuple[Callable]]] = (
+ (
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ (
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ ),
+ # Head configs.
+ head: Callable = create_res_basic_head,
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_sizes: Tuple[Tuple[int]] = ((8, 7, 7), (32, 7, 7)),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = None,
+ head_output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Build the SlowFast model for video recognition. The SlowFast model involves a Slow
+ pathway, operating at a low frame rate, to capture spatial semantics, and a Fast
+ pathway, operating at a high frame rate, to capture motion at fine temporal
+ resolution. The Fast pathway can be made very lightweight by reducing its channel
+ capacity, yet it can still learn useful temporal information for video recognition.
+ Details can be found in the paper:
+
+ Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
+ "SlowFast networks for video recognition."
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ ::
+
+ Slow Input Fast Input
+ ↓ ↓
+ Stem Stem
+ ↓ ⭠ Fusion- ↓
+ Stage 1 Stage 1
+ ↓ ⭠ Fusion- ↓
+ . .
+ ↓ ↓
+ Stage N Stage N
+ ↓ ⭠ Fusion- ↓
+ ↓
+ Head
+
+ Args:
+ slowfast_channel_reduction_ratio (int): Corresponds to the inverse of the channel
+ reduction ratio, $\beta$ between the Slow and Fast pathways.
+ slowfast_conv_channel_fusion_ratio (int): Ratio of channel dimensions
+ between the Slow and Fast pathways.
+ DEPRECATED slowfast_fusion_conv_kernel_size (tuple): the convolutional kernel
+ size used for fusion.
+ DEPRECATED slowfast_fusion_conv_stride (tuple): the convolutional stride size
+ used for fusion.
+ fusion_builder (Callable[[int, int], nn.Module]): Builder function for generating
+ the fusion modules based on stage dimension and index
+
+ input_channels (tuple): number of channels for the input video clip.
+
+ model_depth (int): the depth of the resnet.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+ norm (callable): a callable that constructs normalization layer.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_function (Tuple[Callable]): a callable that constructs stem layer.
+ Examples include create_res_basic_stem. Indexed by pathway
+ stem_dim_outs (tuple): output channel size to stem.
+ stem_conv_kernel_sizes (tuple): convolutional kernel size(s) of stem.
+ stem_conv_strides (tuple): convolutional stride size(s) of stem.
+ stem_pool (Tuple[Callable]): a callable that constructs resnet head pooling layer.
+ Indexed by pathway
+ stem_pool_kernel_sizes (tuple): pooling kernel size(s).
+ stem_pool_strides (tuple): pooling stride size(s).
+
+ stage_conv_a_kernel_sizes (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_sizes (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
+ for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
+ stage_conv_b_dilations (tuple): dilation for 3D convolution for conv_b.
+ stage_spatial_strides (tuple): the spatial stride for each stage.
+ stage_temporal_strides (tuple): the temporal stride for each stage.
+ bottleneck (Tuple[Tuple[Callable]]): a callable that constructs bottleneck
+ block layer. Examples include: create_bottleneck_block.
+ Indexed by pathway and stage index
+
+ head (callable): a callable that constructs the resnet-style head.
+ Ex: create_res_basic_head
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_output_size (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+ Returns:
+ (nn.Module): SlowFast model.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_slowfast")
+
+ # Number of blocks for different stages given the model depth.
+ _num_pathway = len(input_channels)
+ assert (
+ model_depth in _MODEL_STAGE_DEPTH.keys()
+ ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
+ stage_depths = _MODEL_STAGE_DEPTH[model_depth]
+
+ # Fix up inputs
+ if isinstance(slowfast_channel_reduction_ratio, int):
+ slowfast_channel_reduction_ratio = (slowfast_channel_reduction_ratio,)
+ if isinstance(stem_pool, Callable):
+ stem_pool = (stem_pool,) * _num_pathway
+ if isinstance(bottleneck, Callable):
+ bottleneck = (bottleneck,) * len(stage_depths)
+ bottleneck = (bottleneck,) * _num_pathway
+ if fusion_builder is None:
+ fusion_builder = FastToSlowFusionBuilder(
+ slowfast_channel_reduction_ratio=slowfast_channel_reduction_ratio[0],
+ conv_fusion_channel_ratio=slowfast_conv_channel_fusion_ratio,
+ conv_kernel_size=slowfast_fusion_conv_kernel_size,
+ conv_stride=slowfast_fusion_conv_stride,
+ norm=norm,
+ activation=activation,
+ max_stage_idx=len(stage_depths) - 1,
+ ).create_module
+
+ # Build stem blocks.
+ stems = []
+ for pathway_idx in range(_num_pathway):
+ stems.append(
+ stem_function[pathway_idx](
+ in_channels=input_channels[pathway_idx],
+ out_channels=stem_dim_outs[pathway_idx],
+ conv_kernel_size=stem_conv_kernel_sizes[pathway_idx],
+ conv_stride=stem_conv_strides[pathway_idx],
+ conv_padding=[
+ size // 2 for size in stem_conv_kernel_sizes[pathway_idx]
+ ],
+ pool=stem_pool[pathway_idx],
+ pool_kernel_size=stem_pool_kernel_sizes[pathway_idx],
+ pool_stride=stem_pool_strides[pathway_idx],
+ pool_padding=[
+ size // 2 for size in stem_pool_kernel_sizes[pathway_idx]
+ ],
+ norm=norm,
+ activation=activation,
+ )
+ )
+
+ stages = []
+ stages.append(
+ MultiPathWayWithFuse(
+ multipathway_blocks=nn.ModuleList(stems),
+ multipathway_fusion=fusion_builder(
+ fusion_dim_in=stem_dim_outs[0],
+ stage_idx=0,
+ ),
+ )
+ )
+
+ # Build stages blocks.
+ stage_dim_in = stem_dim_outs[0]
+ stage_dim_out = stage_dim_in * 4
+ for idx in range(len(stage_depths)):
+ pathway_stage_dim_in = [
+ stage_dim_in
+ + stage_dim_in
+ * slowfast_conv_channel_fusion_ratio
+ // slowfast_channel_reduction_ratio[0],
+ ]
+ pathway_stage_dim_inner = [
+ stage_dim_out // 4,
+ ]
+ pathway_stage_dim_out = [
+ stage_dim_out,
+ ]
+ for reduction_ratio in slowfast_channel_reduction_ratio:
+ pathway_stage_dim_in = pathway_stage_dim_in + [
+ stage_dim_in // reduction_ratio
+ ]
+ pathway_stage_dim_inner = pathway_stage_dim_inner + [
+ stage_dim_out // 4 // reduction_ratio
+ ]
+ pathway_stage_dim_out = pathway_stage_dim_out + [
+ stage_dim_out // reduction_ratio
+ ]
+
+ stage = []
+ for pathway_idx in range(_num_pathway):
+ depth = stage_depths[idx]
+
+ stage_conv_a_kernel = stage_conv_a_kernel_sizes[pathway_idx][idx]
+ stage_conv_a_stride = (stage_temporal_strides[pathway_idx][idx], 1, 1)
+ stage_conv_a_padding = (
+ [size // 2 for size in stage_conv_a_kernel]
+ if isinstance(stage_conv_a_kernel[0], int)
+ else [[size // 2 for size in sizes] for sizes in stage_conv_a_kernel]
+ )
+
+ stage_conv_b_stride = (
+ 1,
+ stage_spatial_strides[pathway_idx][idx],
+ stage_spatial_strides[pathway_idx][idx],
+ )
+ stage.append(
+ create_res_stage(
+ depth=depth,
+ dim_in=pathway_stage_dim_in[pathway_idx],
+ dim_inner=pathway_stage_dim_inner[pathway_idx],
+ dim_out=pathway_stage_dim_out[pathway_idx],
+ bottleneck=bottleneck[pathway_idx][idx],
+ conv_a_kernel_size=stage_conv_a_kernel,
+ conv_a_stride=stage_conv_a_stride,
+ conv_a_padding=stage_conv_a_padding,
+ conv_b_kernel_size=stage_conv_b_kernel_sizes[pathway_idx][idx],
+ conv_b_stride=stage_conv_b_stride,
+ conv_b_padding=(
+ stage_conv_b_kernel_sizes[pathway_idx][idx][0] // 2,
+ stage_conv_b_dilations[pathway_idx][idx][1]
+ if stage_conv_b_dilations[pathway_idx][idx][1] > 1
+ else stage_conv_b_kernel_sizes[pathway_idx][idx][1] // 2,
+ stage_conv_b_dilations[pathway_idx][idx][2]
+ if stage_conv_b_dilations[pathway_idx][idx][2] > 1
+ else stage_conv_b_kernel_sizes[pathway_idx][idx][2] // 2,
+ ),
+ conv_b_num_groups=stage_conv_b_num_groups[pathway_idx][idx],
+ conv_b_dilation=stage_conv_b_dilations[pathway_idx][idx],
+ norm=norm,
+ activation=activation,
+ )
+ )
+ stages.append(
+ MultiPathWayWithFuse(
+ multipathway_blocks=nn.ModuleList(stage),
+ multipathway_fusion=fusion_builder(
+ fusion_dim_in=stage_dim_out,
+ stage_idx=idx + 1,
+ ),
+ )
+ )
+ stage_dim_in = stage_dim_out
+ stage_dim_out = stage_dim_out * 2
+
+ if head_pool is None:
+ pool_model = None
+ elif head_pool == nn.AdaptiveAvgPool3d:
+ pool_model = [head_pool(head_output_size[idx]) for idx in range(_num_pathway)]
+ elif head_pool == nn.AvgPool3d:
+ pool_model = [
+ head_pool(
+ kernel_size=head_pool_kernel_sizes[idx],
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ )
+ for idx in range(_num_pathway)
+ ]
+ else:
+ raise NotImplementedError(f"Unsupported pool_model type {pool_model}")
+
+ stages.append(PoolConcatPathway(retain_list=False, pool=nn.ModuleList(pool_model)))
+ head_in_features = stage_dim_in
+ for reduction_ratio in slowfast_channel_reduction_ratio:
+ head_in_features = head_in_features + stage_dim_in // reduction_ratio
+ if head is not None:
+ stages.append(
+ head(
+ in_features=head_in_features,
+ out_features=model_num_class,
+ pool=None,
+ output_size=head_output_size,
+ dropout_rate=dropout_rate,
+ activation=head_activation,
+ output_with_global_average=head_output_with_global_average,
+ )
+ )
+ return Net(blocks=nn.ModuleList(stages))
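+
+# A usage sketch for the default SlowFast-R50 configuration. The 8-frame slow /
+# 32-frame fast clip shapes below are assumptions chosen to match the default
+# head_pool_kernel_sizes of (8, 7, 7) and (32, 7, 7) at 224x224 input:
+#
+#   model = create_slowfast(model_depth=50, model_num_class=400)
+#   slow = torch.randn(1, 3, 8, 224, 224)
+#   fast = torch.randn(1, 3, 32, 224, 224)
+#   logits = model([slow, fast])  # -> (1, 400)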
+
+
+def create_slowfast_with_roi_head(
+ *,
+ # SlowFast configs.
+ slowfast_channel_reduction_ratio: Union[Tuple[int], int] = (8,),
+ slowfast_conv_channel_fusion_ratio: int = 2,
+ slowfast_fusion_conv_kernel_size: Tuple[int] = (
+ 7,
+ 1,
+ 1,
+ ), # deprecated, use fusion_builder
+ slowfast_fusion_conv_stride: Tuple[int] = (
+ 4,
+ 1,
+ 1,
+ ), # deprecated, use fusion_builder
+ fusion_builder: Callable[
+ [int, int], nn.Module
+ ] = None, # Args: fusion_dim_in, stage_idx
+ # Input clip configs.
+ input_channels: Tuple[int] = (3, 3),
+ # Model configs.
+ model_depth: int = 50,
+ model_num_class: int = 80,
+ dropout_rate: float = 0.5,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_function: Tuple[Callable] = (
+ create_res_basic_stem,
+ create_res_basic_stem,
+ ),
+ stem_dim_outs: Tuple[int] = (64, 8),
+ stem_conv_kernel_sizes: Tuple[Tuple[int]] = ((1, 7, 7), (5, 7, 7)),
+ stem_conv_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2)),
+ stem_pool: Union[Callable, Tuple[Callable]] = (nn.MaxPool3d, nn.MaxPool3d),
+ stem_pool_kernel_sizes: Tuple[Tuple[int]] = ((1, 3, 3), (1, 3, 3)),
+ stem_pool_strides: Tuple[Tuple[int]] = ((1, 2, 2), (1, 2, 2)),
+ # Stage configs.
+ stage_conv_a_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 1, 1), (1, 1, 1), (3, 1, 1), (3, 1, 1)),
+ ((3, 1, 1), (3, 1, 1), (3, 1, 1), (3, 1, 1)),
+ ),
+ stage_conv_b_kernel_sizes: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
+ ((1, 3, 3), (1, 3, 3), (1, 3, 3), (1, 3, 3)),
+ ),
+ stage_conv_b_num_groups: Tuple[Tuple[int]] = ((1, 1, 1, 1), (1, 1, 1, 1)),
+ stage_conv_b_dilations: Tuple[Tuple[Tuple[int]]] = (
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 2, 2)),
+ ((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 2, 2)),
+ ),
+ stage_spatial_strides: Tuple[Tuple[int]] = ((1, 2, 2, 1), (1, 2, 2, 1)),
+ stage_temporal_strides: Tuple[Tuple[int]] = ((1, 1, 1, 1), (1, 1, 1, 1)),
+ bottleneck: Union[Callable, Tuple[Tuple[Callable]]] = (
+ (
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ (
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ create_bottleneck_block,
+ ),
+ ),
+ # Head configs.
+ head: Callable = create_res_roi_pooling_head,
+ head_pool: Callable = nn.AvgPool3d,
+ head_pool_kernel_sizes: Tuple[Tuple[int]] = ((8, 1, 1), (32, 1, 1)),
+ head_output_size: Tuple[int] = (1, 1, 1),
+ head_activation: Callable = nn.Sigmoid,
+ head_output_with_global_average: bool = False,
+ head_spatial_resolution: Tuple[int] = (7, 7),
+ head_spatial_scale: float = 1.0 / 16.0,
+ head_sampling_ratio: int = 0,
+) -> nn.Module:
+ """
+ Build the SlowFast model for video detection. The SlowFast model involves a Slow
+ pathway, operating at a low frame rate, to capture spatial semantics, and a Fast
+ pathway, operating at a high frame rate, to capture motion at fine temporal
+ resolution. The Fast pathway can be made very lightweight by reducing its channel
+ capacity, yet it can still learn useful temporal information for video recognition.
+ Details can be found in the paper:
+
+ Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
+ "SlowFast networks for video recognition."
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ ::
+
+ Slow Input Fast Input Bounding Box Input
+ ↓ ↓ ↓
+ Stem Stem ↓
+ ↓ ⭠ Fusion- ↓ ↓
+ Stage 1 Stage 1 ↓
+ ↓ ⭠ Fusion- ↓ ↓
+ . . ↓
+ ↓ ↓ ↓
+ Stage N Stage N ↓
+ ↓ ⭠ Fusion- ↓ ↓
+ ↓ ↓
+ ↓----------> Head <--------↓
+
+ Args:
+ slowfast_channel_reduction_ratio (int): Corresponds to the inverse of the channel
+ reduction ratio, $\beta$ between the Slow and Fast pathways.
+ slowfast_conv_channel_fusion_ratio (int): Ratio of channel dimensions
+ between the Slow and Fast pathways.
+ DEPRECATED slowfast_fusion_conv_kernel_size (tuple): the convolutional kernel
+ size used for fusion.
+ DEPRECATED slowfast_fusion_conv_stride (tuple): the convolutional stride size
+ used for fusion.
+ fusion_builder (Callable[[int, int], nn.Module]): Builder function for generating
+ the fusion modules based on stage dimension and index
+
+ input_channels (tuple): number of channels for the input video clip.
+
+ model_depth (int): the depth of the resnet.
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+
+ norm (callable): a callable that constructs normalization layer.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_function (Tuple[Callable]): a callable that constructs stem layer.
+ Examples include create_res_basic_stem. Indexed by pathway
+ stem_dim_outs (tuple): output channel size to stem.
+ stem_conv_kernel_sizes (tuple): convolutional kernel size(s) of stem.
+ stem_conv_strides (tuple): convolutional stride size(s) of stem.
+ stem_pool (Tuple[Callable]): a callable that constructs resnet head pooling layer.
+ Indexed by pathway
+ stem_pool_kernel_sizes (tuple): pooling kernel size(s).
+ stem_pool_strides (tuple): pooling stride size(s).
+
+ stage_conv_a_kernel_sizes (tuple): convolutional kernel size(s) for conv_a.
+ stage_conv_b_kernel_sizes (tuple): convolutional kernel size(s) for conv_b.
+ stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
+ for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
+ stage_conv_b_dilations (tuple): dilation for 3D convolution for conv_b.
+ stage_spatial_strides (tuple): the spatial stride for each stage.
+ stage_temporal_strides (tuple): the temporal stride for each stage.
+ bottleneck (Tuple[Tuple[Callable]]): a callable that constructs bottleneck
+ block layer. Examples include: create_bottleneck_block.
+ Indexed by pathway and stage index
+
+ head (callable): a callable that constructs the detection head which can
+ take in the additional input of bounding boxes.
+ Ex: create_res_roi_pooling_head
+ head_pool (callable): a callable that constructs resnet head pooling layer.
+ head_output_size (tuple): the size of output tensor for head.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+ head_spatial_resolution (tuple): h, w sizes of the RoI interpolation.
+ head_spatial_scale (float): scale the input boxes by this number.
+ head_sampling_ratio (int): number of inputs samples to take for each output
+ sample interpolation. 0 to take samples densely.
+ Returns:
+ (nn.Module): SlowFast model.
+ """
+
+ model = create_slowfast(
+ # SlowFast configs.
+ slowfast_channel_reduction_ratio=slowfast_channel_reduction_ratio,
+ slowfast_conv_channel_fusion_ratio=slowfast_conv_channel_fusion_ratio,
+ slowfast_fusion_conv_kernel_size=slowfast_fusion_conv_kernel_size,
+ slowfast_fusion_conv_stride=slowfast_fusion_conv_stride,
+ # Input clip configs.
+ input_channels=input_channels,
+ # Model configs.
+ model_depth=model_depth,
+ model_num_class=model_num_class,
+ dropout_rate=dropout_rate,
+ # Normalization configs.
+ norm=norm,
+ # Activation configs.
+ activation=activation,
+ # Stem configs.
+ stem_dim_outs=stem_dim_outs,
+ stem_conv_kernel_sizes=stem_conv_kernel_sizes,
+ stem_conv_strides=stem_conv_strides,
+ stem_pool=stem_pool,
+ stem_pool_kernel_sizes=stem_pool_kernel_sizes,
+ stem_pool_strides=stem_pool_strides,
+ # Stage configs.
+ stage_conv_a_kernel_sizes=stage_conv_a_kernel_sizes,
+ stage_conv_b_kernel_sizes=stage_conv_b_kernel_sizes,
+ stage_conv_b_num_groups=stage_conv_b_num_groups,
+ stage_conv_b_dilations=stage_conv_b_dilations,
+ stage_spatial_strides=stage_spatial_strides,
+ stage_temporal_strides=stage_temporal_strides,
+ bottleneck=create_bottleneck_block,
+ # Head configs.
+ head=None,
+ head_pool=head_pool,
+ head_pool_kernel_sizes=head_pool_kernel_sizes,
+ )
+
+ stage_dim_out = stem_dim_outs[0] * 2 ** (len(_MODEL_STAGE_DEPTH[model_depth]) + 1)
+ slow_fast_beta = stem_dim_outs[0] // stem_dim_outs[1]
+ head_in_features = stage_dim_out + stage_dim_out // slow_fast_beta
+ head = create_res_roi_pooling_head(
+ in_features=head_in_features,
+ out_features=model_num_class,
+ pool=None,
+ output_size=head_output_size,
+ dropout_rate=dropout_rate,
+ activation=head_activation,
+ output_with_global_average=head_output_with_global_average,
+ resolution=head_spatial_resolution,
+ spatial_scale=head_spatial_scale,
+ sampling_ratio=head_sampling_ratio,
+ )
+ return DetectionBBoxNetwork(model, head)
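+
+# A usage sketch for the detection variant. The clip shapes and the
+# (batch_index, x1, y1, x2, y2) box format are assumptions that mirror the
+# AVA-style setup the defaults target:
+#
+#   model = create_slowfast_with_roi_head(model_num_class=80)
+#   slow = torch.randn(1, 3, 8, 256, 256)
+#   fast = torch.randn(1, 3, 32, 256, 256)
+#   bboxes = torch.tensor([[0.0, 32.0, 32.0, 200.0, 200.0]])
+#   preds = model([slow, fast], bboxes)  # per-box class predictions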
+
+
+# TODO: move to pytorchvideo/layer once we have a common.py
+class PoolConcatPathway(nn.Module):
+ """
+ Given a list of tensors, perform an optional spatio-temporal pooling on each tensor
+ and concatenate the results along the channel dimension.
+ """
+
+ def __init__(
+ self,
+ retain_list: bool = False,
+ pool: Optional[nn.ModuleList] = None,
+ dim: int = 1,
+ ) -> None:
+ """
+ Args:
+ retain_list (bool): if True, return the concatenated tensor in a list.
+ pool (nn.ModuleList): if not None, a list of pooling modules, one per
+ pathway, applied before concatenation.
+ dim (int): dimension along which to perform concatenation.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+
+ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+ if self.pool is not None:
+ assert len(x) == len(self.pool)
+ output = []
+ for ind in range(len(x)):
+ if x[ind] is not None:
+ if self.pool is not None and self.pool[ind] is not None:
+ x[ind] = self.pool[ind](x[ind])
+ output.append(x[ind])
+ if self.retain_list:
+ return [torch.cat(output, 1)]
+ else:
+ return torch.cat(output, 1)
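+
+# A shape-level sketch (arbitrary channel sizes; with pool=None the pathway
+# tensors must already agree on all non-channel dimensions):
+#
+#   m = PoolConcatPathway(retain_list=False, pool=None)
+#   out = m([torch.randn(2, 64, 1, 7, 7), torch.randn(2, 8, 1, 7, 7)])
+#   # out: (2, 72, 1, 7, 7)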
+
+
+class FastToSlowFusionBuilder:
+ def __init__(
+ self,
+ slowfast_channel_reduction_ratio: int,
+ conv_fusion_channel_ratio: float,
+ conv_kernel_size: Tuple[int],
+ conv_stride: Tuple[int],
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ activation: Callable = nn.ReLU,
+ max_stage_idx: int = 3,
+ ) -> None:
+ """
+ Given a list of two tensors, one from the Slow pathway and one from the Fast
+ pathway, fuse information from the Fast pathway into the Slow pathway through a
+ convolution followed by a concatenation, then return the fused list of Slow and
+ Fast tensors in order.
+ Args:
+ slowfast_channel_reduction_ratio (int): Reduction ratio from the stage dimension.
+ Used to compute conv_dim_in = fusion_dim_in // slowfast_channel_reduction_ratio
+ conv_fusion_channel_ratio (int): channel ratio for the convolution used to fuse
+ from Fast pathway to Slow pathway.
+ conv_kernel_size (tuple): kernel size of the convolution used to fuse from the
+ Fast pathway to the Slow pathway.
+ conv_stride (tuple): stride size of the convolution used to fuse from the Fast
+ pathway to the Slow pathway.
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+ max_stage_idx (int): an identity module is returned for stage indices above this value.
+ """
+ set_attributes(self, locals())
+
+ def create_module(self, fusion_dim_in: int, stage_idx: int) -> nn.Module:
+ """
+ Creates the module for the given stage
+ Args:
+ fusion_dim_in (int): input stage dimension
+ stage_idx (int): which stage this is
+ """
+ if stage_idx > self.max_stage_idx:
+ return nn.Identity()
+
+ conv_dim_in = fusion_dim_in // self.slowfast_channel_reduction_ratio
+ conv_fast_to_slow = nn.Conv3d(
+ conv_dim_in,
+ int(conv_dim_in * self.conv_fusion_channel_ratio),
+ kernel_size=self.conv_kernel_size,
+ stride=self.conv_stride,
+ padding=[k_size // 2 for k_size in self.conv_kernel_size],
+ bias=False,
+ )
+ norm_module = (
+ None
+ if self.norm is None
+ else self.norm(
+ num_features=conv_dim_in * self.conv_fusion_channel_ratio,
+ eps=self.norm_eps,
+ momentum=self.norm_momentum,
+ )
+ )
+ activation_module = None if self.activation is None else self.activation()
+ return FuseFastToSlow(
+ conv_fast_to_slow=conv_fast_to_slow,
+ norm=norm_module,
+ activation=activation_module,
+ )
+
+
+class FuseFastToSlow(nn.Module):
+ """
+ Given a list of two tensors, one from the Slow pathway and one from the Fast
+ pathway, fuse information from the Fast pathway into the Slow pathway through a
+ convolution followed by a concatenation, then return the fused list of Slow and
+ Fast tensors in order.
+ """
+
+ def __init__(
+ self,
+ conv_fast_to_slow: nn.Module,
+ norm: Optional[nn.Module] = None,
+ activation: Optional[nn.Module] = None,
+ ) -> None:
+ """
+ Args:
+ conv_fast_to_slow (nn.module): convolution to perform fusion.
+ norm (nn.module): normalization module.
+ activation (torch.nn.modules): activation module.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+
+ def forward(self, x):
+ x_s = x[0]
+ x_f = x[1]
+ fuse = self.conv_fast_to_slow(x_f)
+ if self.norm is not None:
+ fuse = self.norm(fuse)
+ if self.activation is not None:
+ fuse = self.activation(fuse)
+ x_s_fuse = torch.cat([x_s, fuse], 1)
+ return [x_s_fuse, x_f]
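+
+# A shape-level sketch of the fusion step. The channel sizes mirror the default
+# beta = 1/8, fusion ratio 2 configuration and are picked only for illustration:
+#
+#   builder = FastToSlowFusionBuilder(
+#       slowfast_channel_reduction_ratio=8,
+#       conv_fusion_channel_ratio=2,
+#       conv_kernel_size=(7, 1, 1),
+#       conv_stride=(4, 1, 1),
+#   )
+#   fuse = builder.create_module(fusion_dim_in=64, stage_idx=0)
+#   slow = torch.randn(1, 64, 8, 56, 56)
+#   fast = torch.randn(1, 8, 32, 56, 56)
+#   fused_slow, fast = fuse([slow, fast])  # fused_slow: (1, 80, 8, 56, 56)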
diff --git a/code/pytorchvideo/pytorchvideo/models/stem.py b/code/pytorchvideo/pytorchvideo/models/stem.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eab852421200f16a8b31829848e594d1baeef06
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/stem.py
@@ -0,0 +1,338 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.convolutions import ConvReduce3D
+from pytorchvideo.layers.utils import set_attributes
+
+
+def create_res_basic_stem(
+ *,
+ # Conv configs.
+ in_channels: int,
+ out_channels: int,
+ conv_kernel_size: Tuple[int] = (3, 7, 7),
+ conv_stride: Tuple[int] = (1, 2, 2),
+ conv_padding: Tuple[int] = (1, 3, 3),
+ conv_bias: bool = False,
+ conv: Callable = nn.Conv3d,
+ # Pool configs.
+ pool: Callable = nn.MaxPool3d,
+ pool_kernel_size: Tuple[int] = (1, 3, 3),
+ pool_stride: Tuple[int] = (1, 2, 2),
+ pool_padding: Tuple[int] = (0, 1, 1),
+ # BN configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Creates the basic resnet stem layer. It performs a spatiotemporal convolution, BN,
+ and ReLU, followed by a spatiotemporal pooling.
+
+ ::
+
+ Conv3d
+ ↓
+ Normalization
+ ↓
+ Activation
+ ↓
+ Pool3d
+
+ Normalization options include: BatchNorm3d and None (no normalization).
+ Activation options include: ReLU, Softmax, Sigmoid, and None (no activation).
+ Pool3d options include: AvgPool3d, MaxPool3d, and None (no pooling).
+
+ Args:
+
+ in_channels (int): input channel size of the convolution.
+ out_channels (int): output channel size of the convolution.
+ conv_kernel_size (tuple): convolutional kernel size(s).
+ conv_stride (tuple): convolutional stride size(s).
+ conv_padding (tuple): convolutional padding size(s).
+ conv_bias (bool): convolutional bias. If true, adds a learnable bias to the
+ output.
+ conv (callable): Callable used to build the convolution layer.
+
+ pool (callable): a callable that constructs pooling layer, options include:
+ nn.AvgPool3d, nn.MaxPool3d, and None (not performing pooling).
+ pool_kernel_size (tuple): pooling kernel size(s).
+ pool_stride (tuple): pooling stride size(s).
+ pool_padding (tuple): pooling padding size(s).
+
+ norm (callable): a callable that constructs normalization layer, options
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer, options
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): resnet basic stem layer.
+ """
+ conv_module = conv(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=conv_kernel_size,
+ stride=conv_stride,
+ padding=conv_padding,
+ bias=conv_bias,
+ )
+ norm_module = (
+ None
+ if norm is None
+ else norm(num_features=out_channels, eps=norm_eps, momentum=norm_momentum)
+ )
+ activation_module = None if activation is None else activation()
+ pool_module = (
+ None
+ if pool is None
+ else pool(
+ kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
+ )
+ )
+
+ return ResNetBasicStem(
+ conv=conv_module,
+ norm=norm_module,
+ activation=activation_module,
+ pool=pool_module,
+ )
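+
+# A usage sketch with the default kernel/stride/padding (shapes are
+# illustrative; the stem reduces spatial resolution 4x via conv + pooling):
+#
+#   stem = create_res_basic_stem(in_channels=3, out_channels=64)
+#   y = stem(torch.randn(2, 3, 8, 224, 224))  # -> (2, 64, 8, 56, 56)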
+
+
+def create_acoustic_res_basic_stem(
+ *,
+ # Conv configs.
+ in_channels: int,
+ out_channels: int,
+ conv_kernel_size: Tuple[int] = (3, 7, 7),
+ conv_stride: Tuple[int] = (1, 1, 1),
+ conv_padding: Tuple[int] = (1, 3, 3),
+ conv_bias: bool = False,
+ # Pool configs.
+ pool: Callable = nn.MaxPool3d,
+ pool_kernel_size: Tuple[int] = (1, 3, 3),
+ pool_stride: Tuple[int] = (1, 2, 2),
+ pool_padding: Tuple[int] = (0, 1, 1),
+ # BN configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Creates the acoustic resnet stem layer. It performs a spatial and a temporal
+ convolution in parallel, then applies BN and ReLU, followed by a spatiotemporal
+ pooling.
+
+ ::
+
+ Conv3d Conv3d
+ ↓
+ Normalization
+ ↓
+ Activation
+ ↓
+ Pool3d
+
+ Normalization options include: BatchNorm3d and None (no normalization).
+ Activation options include: ReLU, Softmax, Sigmoid, and None (no activation).
+ Pool3d options include: AvgPool3d, MaxPool3d, and None (no pooling).
+
+ Args:
+ in_channels (int): input channel size of the convolution.
+ out_channels (int): output channel size of the convolution.
+ conv_kernel_size (tuple): convolutional kernel size(s).
+ conv_stride (tuple): convolutional stride size(s), it will be performed as
+ temporal and spatial convolution in parallel.
+ conv_padding (tuple): convolutional padding size(s), it will be performed
+ as temporal and spatial convolution in parallel.
+ conv_bias (bool): convolutional bias. If true, adds a learnable bias to the
+ output.
+
+ pool (callable): a callable that constructs pooling layer, options include:
+ nn.AvgPool3d, nn.MaxPool3d, and None (not performing pooling).
+ pool_kernel_size (tuple): pooling kernel size(s).
+ pool_stride (tuple): pooling stride size(s).
+ pool_padding (tuple): pooling padding size(s).
+
+ norm (callable): a callable that constructs normalization layer, options
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer, options
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): resnet basic stem layer.
+ """
+ conv_module = ConvReduce3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(
+ # Temporal conv kernel size.
+ (conv_kernel_size[0], 1, 1),
+ # Spatial conv kernel size.
+ (1, conv_kernel_size[1], conv_kernel_size[2]),
+ ),
+ stride=(conv_stride, conv_stride),
+ padding=((conv_padding[0], 0, 0), (0, conv_padding[1], conv_padding[2])),
+ bias=(conv_bias, conv_bias),
+ reduction_method="sum",
+ )
+ norm_module = (
+ None
+ if norm is None
+ else norm(num_features=out_channels, eps=norm_eps, momentum=norm_momentum)
+ )
+ activation_module = None if activation is None else activation()
+ pool_module = (
+ None
+ if pool is None
+ else pool(
+ kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding
+ )
+ )
+
+ return ResNetBasicStem(
+ conv=conv_module,
+ norm=norm_module,
+ activation=activation_module,
+ pool=pool_module,
+ )
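+
+# A usage sketch (the single-channel 8 x 100 x 80 input stands in for a
+# spectrogram-like volume and is assumed only for illustration):
+#
+#   stem = create_acoustic_res_basic_stem(in_channels=1, out_channels=64)
+#   y = stem(torch.randn(2, 1, 8, 100, 80))  # -> (2, 64, 8, 50, 40)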
+
+
+class ResNetBasicStem(nn.Module):
+ """
+ ResNet basic 3D stem module. Performs a spatiotemporal convolution, BN, and
+ activation, followed by a spatiotemporal pooling.
+
+ ::
+
+ Conv3d
+ ↓
+ Normalization
+ ↓
+ Activation
+ ↓
+ Pool3d
+
+ The builder can be found in `create_res_basic_stem`.
+ """
+
+ def __init__(
+ self,
+ *,
+ conv: nn.Module = None,
+ norm: nn.Module = None,
+ activation: nn.Module = None,
+ pool: nn.Module = None,
+ ) -> None:
+ """
+ Args:
+ conv (torch.nn.modules): convolutional module.
+ norm (torch.nn.modules): normalization module.
+ activation (torch.nn.modules): activation module.
+ pool (torch.nn.modules): pooling module.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.conv is not None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ if self.pool is not None:
+ x = self.pool(x)
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """
+ Transformer basic patch embedding module. Patchifies the input, then flattens and
+ transposes it.
+
+ ::
+
+ PatchModel
+ ↓
+ flatten
+ ↓
+ transpose
+
+ The builder can be found in `create_patch_embed`.
+
+ """
+
+ def __init__(
+ self,
+ *,
+ patch_model: nn.Module = None,
+ ) -> None:
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.patch_model is not None
+
+ def forward(self, x) -> torch.Tensor:
+ x = self.patch_model(x)
+ # B C (T) H W -> B (T)HW C
+ return x.flatten(2).transpose(1, 2)
+
+
+def create_conv_patch_embed(
+ *,
+ in_channels: int,
+ out_channels: int,
+ conv_kernel_size: Tuple[int] = (1, 16, 16),
+ conv_stride: Tuple[int] = (1, 4, 4),
+ conv_padding: Tuple[int] = (1, 7, 7),
+ conv_bias: bool = True,
+ conv: Callable = nn.Conv3d,
+) -> nn.Module:
+ """
+ Creates the transformer basic patch embedding. It performs a convolution, then a
+ flatten and a transpose.
+
+ ::
+
+ Conv3d
+ ↓
+ flatten
+ ↓
+ transpose
+
+ Args:
+ in_channels (int): input channel size of the convolution.
+ out_channels (int): output channel size of the convolution.
+ conv_kernel_size (tuple): convolutional kernel size(s).
+ conv_stride (tuple): convolutional stride size(s).
+ conv_padding (tuple): convolutional padding size(s).
+ conv_bias (bool): convolutional bias. If true, adds a learnable bias to the
+ output.
+ conv (callable): Callable used to build the convolution layer.
+
+ Returns:
+ (nn.Module): transformer patch embedding layer.
+ """
+ conv_module = conv(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=conv_kernel_size,
+ stride=conv_stride,
+ padding=conv_padding,
+ bias=conv_bias,
+ )
+ return PatchEmbed(patch_model=conv_module)
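+
+# A usage sketch with the default (1, 16, 16) kernel and (1, 4, 4) stride
+# (shapes are illustrative; the token count follows the conv output grid):
+#
+#   embed = create_conv_patch_embed(in_channels=3, out_channels=96)
+#   tokens = embed(torch.randn(1, 3, 8, 224, 224))
+#   # tokens: (1, T' * H' * W', 96), flattened from the Conv3d output grid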
diff --git a/code/pytorchvideo/pytorchvideo/models/vision_transformers.py b/code/pytorchvideo/pytorchvideo/models/vision_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..84e94d6a5f9f066ea7d47407476f2a0fd9c45d4d
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/vision_transformers.py
@@ -0,0 +1,506 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import warnings
+from functools import partial
+from typing import Callable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers import MultiScaleBlock, SpatioTemporalClsPositionalEncoding
+from pytorchvideo.layers.utils import round_width, set_attributes
+from pytorchvideo.models.head import create_vit_basic_head
+from pytorchvideo.models.weight_init import init_net_weights
+from torch.nn.common_types import _size_2_t, _size_3_t
+
+from .stem import create_conv_patch_embed
+
+
+class MultiscaleVisionTransformers(nn.Module):
+ """
+ Multiscale Vision Transformers
+ Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik,
+ Christoph Feichtenhofer
+ https://arxiv.org/abs/2104.11227
+
+ ::
+
+ PatchEmbed
+ ↓
+ PositionalEncoding
+ ↓
+ Dropout
+ ↓
+ Normalization
+ ↓
+ Block 1
+ ↓
+ .
+ .
+ .
+ ↓
+ Block N
+ ↓
+ Normalization
+ ↓
+ Head
+
+
+ The builder can be found in `create_mvit`.
+ """
+
+ def __init__(
+ self,
+ *,
+ patch_embed: Optional[nn.Module],
+ cls_positional_encoding: nn.Module,
+ pos_drop: Optional[nn.Module],
+ blocks: nn.ModuleList,
+ norm_embed: Optional[nn.Module],
+ head: Optional[nn.Module],
+ ) -> None:
+ """
+ Args:
+ patch_embed (nn.Module): Patch embed module.
+ cls_positional_encoding (nn.Module): Positional encoding module.
+ pos_drop (Optional[nn.Module]): Dropout module after patch embed.
+ blocks (nn.ModuleList): Stack of multi-scale transformer blocks.
+ norm_embed (Optional[nn.Module]): Normalization layer before the head.
+ head (Optional[nn.Module]): Head module.
+ """
+ super().__init__()
+
+ assert hasattr(
+ cls_positional_encoding, "patch_embed_shape"
+ ), "cls_positional_encoding should have method patch_embed_shape."
+
+ self.patch_embed = patch_embed or torch.nn.Identity()
+ self.cls_positional_encoding = cls_positional_encoding
+ self.pos_drop = pos_drop or torch.nn.Identity()
+ self.blocks = blocks
+ self.norm_embed = norm_embed or torch.nn.Identity()
+ self.head = head or torch.nn.Identity()
+
+ init_net_weights(self, init_std=0.02, style="vit")
+
+ def _get_bn_w_b(self, bn, repeat=1):
+ w_bn = torch.diag(
+ bn.weight.div(torch.sqrt(bn.eps + bn.running_var)).repeat(repeat)
+ )
+
+ b_bn = (
+ bn.bias
+ - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+ ).repeat(repeat)
+ return w_bn, b_bn
+
+ def fuse_norm_before_linear(self, bn, linear):
+ if bn is None:
+ return linear
+ w_bn, b_bn = self._get_bn_w_b(bn)
+ fused_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
+ fused_linear.weight.data[:] = torch.mm(linear.weight, w_bn)
+ fused_linear.bias.data[:] = (
+ torch.matmul(linear.weight, b_bn) + linear.bias
+ if linear.bias is not None
+ else torch.matmul(linear.weight, b_bn)
+ )
+ return fused_linear
+
+ def fuse_norm_after_linear(self, linear, bn):
+ if bn is None:
+ return linear
+ assert linear.in_features % bn.bias.shape[0] == 0
+ num_heads = linear.in_features // bn.bias.shape[0]
+ w_bn, b_bn = self._get_bn_w_b(bn, repeat=num_heads)
+
+ fused_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
+ fused_linear.weight.data[:] = torch.mm(w_bn, linear.weight)
+ fused_linear.bias.data[:] = (
+ torch.matmul(w_bn, linear.bias) + b_bn if linear.bias is not None else b_bn
+ )
+ return fused_linear
+
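+ # Folding algebra used by the two helpers above (eval-mode BatchNorm only):
+ # BN(x) = W_bn x + b_bn, with W_bn = diag(gamma / sqrt(var + eps)) and
+ # b_bn = beta - gamma * mean / sqrt(var + eps). Hence
+ #   Linear(BN(x)) = (W W_bn) x + (W b_bn + b)       -> fuse_norm_before_linear
+ #   BN(Linear(x)) = (W_bn W) x + (W_bn b + b_bn)    -> fuse_norm_after_linear
+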
+ def fuse_bn(self):
+ assert not self.training
+ for blk in self.blocks:
+ # fuse self.norm1
+ if blk.attn.separate_qkv:
+ blk.attn.q = self.fuse_norm_before_linear(blk.norm1, blk.attn.q)
+ blk.attn.k = self.fuse_norm_before_linear(blk.norm1, blk.attn.k)
+ blk.attn.v = self.fuse_norm_before_linear(blk.norm1, blk.attn.v)
+ else:
+ blk.attn.qkv = self.fuse_norm_before_linear(blk.norm1, blk.attn.qkv)
+ blk.norm1 = nn.Identity()
+
+ # fuse the bn in attention
+ if blk.attn.separate_qkv:
+ blk.attn.q = self.fuse_norm_after_linear(blk.attn.q, blk.attn.norm_q)
+ blk.attn.k = self.fuse_norm_after_linear(blk.attn.k, blk.attn.norm_k)
+ blk.attn.v = self.fuse_norm_after_linear(blk.attn.v, blk.attn.norm_v)
+ else:
+ w_q, w_k, w_v = blk.attn.qkv.weight.chunk(3)
+ b_q, b_k, b_v = blk.attn.qkv.bias.chunk(3)
+ tmp_q = nn.Linear(w_q.shape[1], w_q.shape[0], bias=True)
+ tmp_k = nn.Linear(w_k.shape[1], w_k.shape[0], bias=True)
+ tmp_v = nn.Linear(w_v.shape[1], w_v.shape[0], bias=True)
+ tmp_q.weight.data[:] = w_q
+ tmp_k.weight.data[:] = w_k
+ tmp_v.weight.data[:] = w_v
+ tmp_q.bias.data[:] = b_q
+ tmp_k.bias.data[:] = b_k
+ tmp_v.bias.data[:] = b_v
+ tmp_q = self.fuse_norm_after_linear(tmp_q, blk.attn.norm_q)
+ tmp_k = self.fuse_norm_after_linear(tmp_k, blk.attn.norm_k)
+ tmp_v = self.fuse_norm_after_linear(tmp_v, blk.attn.norm_v)
+ blk.attn.qkv.weight.data[:] = torch.cat(
+ [tmp_q.weight.data, tmp_k.weight.data, tmp_v.weight.data], dim=0
+ )
+ blk.attn.qkv.bias.data[:] = torch.cat(
+ [tmp_q.bias.data, tmp_k.bias.data, tmp_v.bias.data], dim=0
+ )
+
+ blk.attn.norm_q = nn.Identity()
+ blk.attn.norm_k = nn.Identity()
+ blk.attn.norm_v = nn.Identity()
+
+ # fuse self.norm2
+ blk.mlp.fc1 = self.fuse_norm_before_linear(blk.norm2, blk.mlp.fc1)
+ if blk.dim != blk.dim_out:
+ blk.proj = self.fuse_norm_before_linear(blk.norm2, blk.proj)
+ blk.norm2 = nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ x = self.cls_positional_encoding(x)
+ x = self.pos_drop(x)
+
+ thw = self.cls_positional_encoding.patch_embed_shape()
+ for blk in self.blocks:
+ x, thw = blk(x, thw)
+ x = self.norm_embed(x)
+ x = self.head(x)
+ return x
+
+
+def create_multiscale_vision_transformers(
+ *,
+ spatial_size: _size_2_t,
+ temporal_size: int,
+ cls_embed_on: bool = True,
+ sep_pos_embed: bool = True,
+ depth: int = 16,
+ norm: str = "layernorm",
+ # Patch embed config.
+ enable_patch_embed: bool = True,
+ input_channels: int = 3,
+ patch_embed_dim: int = 96,
+ conv_patch_embed_kernel: Tuple[int] = (3, 7, 7),
+ conv_patch_embed_stride: Tuple[int] = (2, 4, 4),
+ conv_patch_embed_padding: Tuple[int] = (1, 3, 3),
+ enable_patch_embed_norm: bool = False,
+ use_2d_patch: bool = False,
+ # Attention block config.
+ num_heads: int = 1,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ dropout_rate_block: float = 0.0,
+ droppath_rate_block: float = 0.0,
+ pooling_mode: str = "conv",
+ pool_first: bool = False,
+ residual_pool: bool = False,
+ depthwise_conv: bool = True,
+ bias_on: bool = True,
+ separate_qkv: bool = True,
+ embed_dim_mul: Optional[List[List[int]]] = None,
+ atten_head_mul: Optional[List[List[int]]] = None,
+ dim_mul_in_att: bool = False,
+ pool_q_stride_size: Optional[List[List[int]]] = None,
+ pool_kv_stride_size: Optional[List[List[int]]] = None,
+ pool_kv_stride_adaptive: Optional[_size_3_t] = None,
+ pool_kvq_kernel: Optional[_size_3_t] = None,
+ # Head config.
+ head: Optional[Callable] = create_vit_basic_head,
+ head_dropout_rate: float = 0.5,
+ head_activation: Callable = None,
+ head_num_classes: int = 400,
+ # Deprecated: the default model definition is now TorchScript-friendly, so
+ # setting create_scriptable_model=True is no longer required.
+ create_scriptable_model: bool = False,
+ multiscale_vit_class: Callable = MultiscaleVisionTransformers,
+) -> nn.Module:
+ """
+ Build Multiscale Vision Transformers (MViT) for recognition. A Vision Transformer
+ (ViT) is a specific case of MViT that only uses a single scale attention block.
+
+ Args:
+ spatial_size (_size_2_t): Input video spatial resolution (H, W). If a single
+ int is given, it assumes the width and the height are the same.
+ temporal_size (int): Number of frames in the input video.
+ cls_embed_on (bool): If True, use cls embed in the model. Otherwise features
+ are average pooled before going to the final classifier.
+ sep_pos_embed (bool): If True, perform separate spatiotemporal embedding.
+ depth (int): The depth of the model.
+ norm (str): Normalization layer. It currently supports "layernorm".
+
+ enable_patch_embed (bool): If true, patchify the input video. If false, it
+ assumes the input should have the feature dimension of patch_embed_dim.
+ input_channels (int): Channel dimension of the input video.
+ patch_embed_dim (int): Embedding dimension after patchifing the video input.
+ conv_patch_embed_kernel (Tuple[int]): Kernel size of the convolution for
+ patchifing the video input.
+ conv_patch_embed_stride (Tuple[int]): Stride size of the convolution for
+ patchifing the video input.
+ conv_patch_embed_padding (Tuple[int]): Padding size of the convolution for
+ patchifing the video input.
+ enable_patch_embed_norm (bool): If True, apply normalization after patchifing
+ the video input.
+ use_2d_patch (bool): If True, use 2D convolutions to get patch embed.
+ Otherwise, use 3D convolutions.
+
+ num_heads (int): Number of heads in the first transformer block.
+ mlp_ratio (float): Mlp ratio which controls the feature dimension in the
+ hidden layer of the Mlp block.
+ qkv_bias (bool): If set to False, the qkv layer will not learn an additive
+ bias. Default: True.
+ dropout_rate_block (float): Dropout rate for the attention block.
+ droppath_rate_block (float): Droppath rate for the attention block.
+ pooling_mode (str): Pooling mode. Options include "conv" (learned pooling), "avg"
+ (average pooling), and "max" (max pooling).
+ pool_first (bool): If set to True, pool is applied before qkv projection.
+ Otherwise, pool is applied after qkv projection. Default: False.
+ residual_pool (bool): If set to True, use Improved Multiscale Vision
+ Transformer's pooling residual connection.
+ depthwise_conv (bool): Whether use depthwise or full convolution for pooling.
+ bias_on (bool): Whether use biases for linear layers.
+ separate_qkv (bool): Whether to use separate or one layer for qkv projections.
+ embed_dim_mul (Optional[List[List[int]]]): Dimension multiplication at layer i.
+ If X is used, then the next block will increase the embed dimension by X
+ times. Format: [depth_i, mul_dim_ratio].
+ atten_head_mul (Optional[List[List[int]]]): Head dimension multiplication at
+ layer i. If X is used, then the next block will increase the head by
+ X times. Format: [depth_i, mul_dim_ratio].
+ dim_mul_in_att (bool): If set to True, dimension expansion happens inside
+ the attention module, otherwise it happens in the Mlp block. Default: False.
+ pool_q_stride_size (Optional[List[List[int]]]): List of stride sizes for the
+ pool q at each layer. Format:
+ [[i, stride_t_i, stride_h_i, stride_w_i], ...,].
+ pool_kv_stride_size (Optional[List[List[int]]]): List of stride sizes for the
+ pool kv at each layer. Format:
+ [[i, stride_t_i, stride_h_i, stride_w_i], ...,].
+ pool_kv_stride_adaptive (Optional[_size_3_t]): Initial kv stride size for the
+ first block. The stride size will be further reduced at the layer where q
+ is pooled with the ratio of the stride of q pooling. If
+ pool_kv_stride_adaptive is set, then pool_kv_stride_size should be none.
+ pool_kvq_kernel (Optional[_size_3_t]): Pooling kernel size for q and kv. If None,
+ the kernel_size is [s + 1 if s > 1 else s for s in stride_size].
+
+ head (Callable): Head model.
+ head_dropout_rate (float): Dropout rate in the head.
+ head_activation (Callable): Activation in the head.
+ head_num_classes (int): Number of classes in the final classification head.
+ multiscale_vit_class (Callable): MViT transformer class. Default to
+ MultiscaleVisionTransformers.
+
+ Example usage (building a MViT_B model for Kinetics400):
+
+ spatial_size = 224
+ temporal_size = 16
+ embed_dim_mul = [[1, 2.0], [3, 2.0], [14, 2.0]]
+ atten_head_mul = [[1, 2.0], [3, 2.0], [14, 2.0]]
+ pool_q_stride_size = [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]]
+ pool_kv_stride_adaptive = [1, 8, 8]
+ pool_kvq_kernel = [3, 3, 3]
+ head_num_classes = 400
+ MViT_B = create_multiscale_vision_transformers(
+ spatial_size=spatial_size,
+ temporal_size=temporal_size,
+ embed_dim_mul=embed_dim_mul,
+ atten_head_mul=atten_head_mul,
+ pool_q_stride_size=pool_q_stride_size,
+ pool_kv_stride_adaptive=pool_kv_stride_adaptive,
+ pool_kvq_kernel=pool_kvq_kernel,
+ head_num_classes=head_num_classes,
+ )
+ """
+
+ if use_2d_patch:
+ assert temporal_size == 1, "If use_2d_patch, temporal_size needs to be 1."
+ if pool_kv_stride_adaptive is not None:
+ assert (
+ pool_kv_stride_size is None
+ ), "pool_kv_stride_size should be none if pool_kv_stride_adaptive is set."
+ if norm == "layernorm":
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ block_norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ attn_norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ elif norm == "batchnorm":
+ norm_layer = None
+ block_norm_layer = nn.BatchNorm1d
+ attn_norm_layer = nn.BatchNorm3d
+ else:
+ raise NotImplementedError("Only supports layernorm.")
+ if create_scriptable_model:
+ assert (
+ norm == "batchnorm"
+ ), "The scriptable model supports only the batchnorm-based model."
+ warnings.warn(
+ "`create_scriptable_model` is deprecated. MultiscaleVisionTransformers"
+ " now supports scripting without this flag.",
+ DeprecationWarning,
+ )
+
+ if isinstance(spatial_size, int):
+ spatial_size = (spatial_size, spatial_size)
+
+ conv_patch_op = nn.Conv2d if use_2d_patch else nn.Conv3d
+
+ patch_embed = (
+ create_conv_patch_embed(
+ in_channels=input_channels,
+ out_channels=patch_embed_dim,
+ conv_kernel_size=conv_patch_embed_kernel,
+ conv_stride=conv_patch_embed_stride,
+ conv_padding=conv_patch_embed_padding,
+ conv=conv_patch_op,
+ )
+ if enable_patch_embed
+ else None
+ )
+
+ input_dims = [temporal_size, spatial_size[0], spatial_size[1]]
+ input_stride = (
+ (1,) + tuple(conv_patch_embed_stride)
+ if use_2d_patch
+ else conv_patch_embed_stride
+ )
+
+ patch_embed_shape = (
+ [input_dims[i] // input_stride[i] for i in range(len(input_dims))]
+ if enable_patch_embed
+ else input_dims
+ )
+
+ cls_positional_encoding = SpatioTemporalClsPositionalEncoding(
+ embed_dim=patch_embed_dim,
+ patch_embed_shape=patch_embed_shape,
+ sep_pos_embed=sep_pos_embed,
+ has_cls=cls_embed_on,
+ )
+
+ dpr = [
+ x.item() for x in torch.linspace(0, droppath_rate_block, depth)
+ ] # stochastic depth decay rule
+
+ if dropout_rate_block > 0.0:
+ pos_drop = nn.Dropout(p=dropout_rate_block)
+
+ dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
+ if embed_dim_mul is not None:
+ for i in range(len(embed_dim_mul)):
+ dim_mul[embed_dim_mul[i][0]] = embed_dim_mul[i][1]
+ if atten_head_mul is not None:
+ for i in range(len(atten_head_mul)):
+ head_mul[atten_head_mul[i][0]] = atten_head_mul[i][1]
+
+ mvit_blocks = nn.ModuleList()
+
+ pool_q = [[] for i in range(depth)]
+ pool_kv = [[] for i in range(depth)]
+ stride_q = [[] for i in range(depth)]
+ stride_kv = [[] for i in range(depth)]
+
+ if pool_q_stride_size is not None:
+ for i in range(len(pool_q_stride_size)):
+ stride_q[pool_q_stride_size[i][0]] = pool_q_stride_size[i][1:]
+ if pool_kvq_kernel is not None:
+ pool_q[pool_q_stride_size[i][0]] = pool_kvq_kernel
+ else:
+ pool_q[pool_q_stride_size[i][0]] = [
+ s + 1 if s > 1 else s for s in pool_q_stride_size[i][1:]
+ ]
+
+ # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
+ if pool_kv_stride_adaptive is not None:
+ _stride_kv = pool_kv_stride_adaptive
+ pool_kv_stride_size = []
+ for i in range(depth):
+ if len(stride_q[i]) > 0:
+ _stride_kv = [
+ max(_stride_kv[d] // stride_q[i][d], 1)
+ for d in range(len(_stride_kv))
+ ]
+ pool_kv_stride_size.append([i] + _stride_kv)
+
+ if pool_kv_stride_size is not None:
+ for i in range(len(pool_kv_stride_size)):
+ stride_kv[pool_kv_stride_size[i][0]] = pool_kv_stride_size[i][1:]
+ if pool_kvq_kernel is not None:
+ pool_kv[pool_kv_stride_size[i][0]] = pool_kvq_kernel
+ else:
+ pool_kv[pool_kv_stride_size[i][0]] = [
+ s + 1 if s > 1 else s for s in pool_kv_stride_size[i][1:]
+ ]
+
+ dim_in = patch_embed_dim
+ for i in range(depth):
+ num_heads = round_width(num_heads, head_mul[i], min_width=1, divisor=1)
+ if dim_mul_in_att:
+ dim_out = round_width(
+ dim_in,
+ dim_mul[i],
+ divisor=round_width(num_heads, head_mul[i]),
+ )
+ else:
+ dim_out = round_width(
+ dim_in,
+ dim_mul[i + 1],
+ divisor=round_width(num_heads, head_mul[i + 1]),
+ )
+
+ mvit_blocks.append(
+ MultiScaleBlock(
+ dim=dim_in,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ dropout_rate=dropout_rate_block,
+ droppath_rate=dpr[i],
+ norm_layer=block_norm_layer,
+ attn_norm_layer=attn_norm_layer,
+ dim_mul_in_att=dim_mul_in_att,
+ kernel_q=pool_q[i],
+ kernel_kv=pool_kv[i],
+ stride_q=stride_q[i],
+ stride_kv=stride_kv[i],
+ pool_mode=pooling_mode,
+ has_cls_embed=cls_embed_on,
+ pool_first=pool_first,
+ residual_pool=residual_pool,
+ bias_on=bias_on,
+ depthwise_conv=depthwise_conv,
+ separate_qkv=separate_qkv,
+ )
+ )
+ dim_in = dim_out
+
+ norm_embed = None if norm_layer is None else norm_layer(dim_in)
+ if head is not None:
+ head_model = head(
+ in_features=dim_in,
+ out_features=head_num_classes,
+ seq_pool_type="cls" if cls_embed_on else "mean",
+ dropout_rate=head_dropout_rate,
+ activation=head_activation,
+ )
+ else:
+ head_model = None
+
+ return multiscale_vit_class(
+ patch_embed=patch_embed,
+ cls_positional_encoding=cls_positional_encoding,
+ pos_drop=pos_drop if dropout_rate_block > 0.0 else None,
+ blocks=mvit_blocks,
+ norm_embed=norm_embed,
+ head=head_model,
+ )
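+
+# A forward-pass sketch with a deliberately tiny configuration (depth, input
+# size, and class count are assumptions chosen to keep the example cheap; real
+# configs follow the MViT_B example in the docstring above):
+#
+#   mvit = create_multiscale_vision_transformers(
+#       spatial_size=64, temporal_size=4, depth=2, head_num_classes=10,
+#   )
+#   logits = mvit(torch.randn(1, 3, 4, 64, 64))  # -> (1, 10)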
diff --git a/code/pytorchvideo/pytorchvideo/models/weight_init.py b/code/pytorchvideo/pytorchvideo/models/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ef196b26bff83e00e9c10d76387b54889f2e7b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/weight_init.py
@@ -0,0 +1,92 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import torch.nn as nn
+from fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill
+from pytorchvideo.layers import SpatioTemporalClsPositionalEncoding
+
+
+ def _init_resnet_weights(model: nn.Module, fc_init_std: float = 0.01) -> nn.Module:
+ """
+ Performs ResNet style weight initialization. That is, recursively initialize the
+ given model in the following way for each type:
+ Conv - Follow the initialization of kaiming_normal:
+ https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
+ BatchNorm - Set weight and bias of last BatchNorm at every residual bottleneck
+ to 0.
+ Linear - Set weight to 0 mean Gaussian with std deviation fc_init_std and bias
+ to 0.
+ Args:
+ model (nn.Module): Model to be initialized.
+ fc_init_std (float): the expected standard deviation for fully-connected layer.
+ """
+ for m in model.modules():
+ if isinstance(m, (nn.Conv2d, nn.Conv3d)):
+ """
+ Follow the initialization method proposed in:
+ {He, Kaiming, et al.
+ "Delving deep into rectifiers: Surpassing human-level
+ performance on imagenet classification."
+ arXiv preprint arXiv:1502.01852 (2015)}
+ """
+ c2_msra_fill(m)
+ elif isinstance(m, nn.modules.batchnorm._NormBase):
+ if m.weight is not None:
+ if hasattr(m, "block_final_bn") and m.block_final_bn:
+ m.weight.data.fill_(0.0)
+ else:
+ m.weight.data.fill_(1.0)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ if isinstance(m, nn.Linear):
+ if hasattr(m, "xavier_init") and m.xavier_init:
+ c2_xavier_fill(m)
+ else:
+ m.weight.data.normal_(mean=0.0, std=fc_init_std)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ return model
+
+
+def _init_vit_weights(model: nn.Module, trunc_normal_std: float = 0.02) -> None:
+ """
+ Weight initialization for vision transformers.
+
+ Args:
+ model (nn.Module): Model to be initialized.
+ trunc_normal_std (float): the expected standard deviation for fully-connected
+ layer and ClsPositionalEncoding.
+ """
+ for m in model.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=trunc_normal_std)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, SpatioTemporalClsPositionalEncoding):
+ for weights in m.parameters():
+ nn.init.trunc_normal_(weights, std=trunc_normal_std)
+
+
+def init_net_weights(
+ model: nn.Module,
+ init_std: float = 0.01,
+ style: str = "resnet",
+) -> None:
+ """
+ Performs weight initialization. Options include ResNet style weight initialization
+ and transformer style weight initialization.
+
+ Args:
+ model (nn.Module): Model to be initialized.
+ init_std (float): The expected standard deviation for initialization.
+ style (str): Options include "resnet" and "vit".
+ """
+ assert style in ["resnet", "vit"]
+ if style == "resnet":
+ return _init_resnet_weights(model, init_std)
+ elif style == "vit":
+ return _init_vit_weights(model, init_std)
+ else:
+ raise NotImplementedError
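+
+# A usage sketch (the modules below are placeholders; the style string picks
+# between the two initializers above):
+#
+#   cnn = nn.Sequential(nn.Conv3d(3, 8, kernel_size=3), nn.BatchNorm3d(8))
+#   init_net_weights(cnn, init_std=0.01, style="resnet")  # MSRA fill, BN = 1/0
+#   vit_like = nn.Sequential(nn.Linear(96, 96), nn.LayerNorm(96))
+#   init_net_weights(vit_like, init_std=0.02, style="vit")  # truncated normal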
diff --git a/code/pytorchvideo/pytorchvideo/models/x3d.py b/code/pytorchvideo/pytorchvideo/models/x3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..88f0d0aeb0ee80091cf477e2cb81f0e0dcace880
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/models/x3d.py
@@ -0,0 +1,804 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import math
+from typing import Callable, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from fvcore.nn.squeeze_excitation import SqueezeExcitation
+from pytorchvideo.layers.convolutions import Conv2plus1d
+from pytorchvideo.layers.swish import Swish
+from pytorchvideo.layers.utils import round_repeats, round_width, set_attributes
+from pytorchvideo.models.head import ResNetBasicHead
+from pytorchvideo.models.net import Net
+from pytorchvideo.models.resnet import BottleneckBlock, ResBlock, ResStage
+from pytorchvideo.models.stem import ResNetBasicStem
+
+
+def create_x3d_stem(
+ *,
+ # Conv configs.
+ in_channels: int,
+ out_channels: int,
+ conv_kernel_size: Tuple[int] = (5, 3, 3),
+ conv_stride: Tuple[int] = (1, 2, 2),
+ conv_padding: Tuple[int] = (2, 1, 1),
+ # BN configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+) -> nn.Module:
+ """
+ Creates the stem layer for X3D. It performs spatial Conv, temporal Conv, BN, and Relu.
+
+ ::
+
+ Conv_xy
+ ↓
+ Conv_t
+ ↓
+ Normalization
+ ↓
+ Activation
+
+ Args:
+ in_channels (int): input channel size of the convolution.
+ out_channels (int): output channel size of the convolution.
+ conv_kernel_size (tuple): convolutional kernel size(s).
+ conv_stride (tuple): convolutional stride size(s).
+ conv_padding (tuple): convolutional padding size(s).
+
+ norm (callable): a callable that constructs normalization layer, options
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer, options
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+
+ Returns:
+ (nn.Module): X3D stem layer.
+ """
+ conv_xy_module = nn.Conv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(1, conv_kernel_size[1], conv_kernel_size[2]),
+ stride=(1, conv_stride[1], conv_stride[2]),
+ padding=(0, conv_padding[1], conv_padding[2]),
+ bias=False,
+ )
+ conv_t_module = nn.Conv3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(conv_kernel_size[0], 1, 1),
+ stride=(conv_stride[0], 1, 1),
+ padding=(conv_padding[0], 0, 0),
+ bias=False,
+ groups=out_channels,
+ )
+ stacked_conv_module = Conv2plus1d(
+ conv_t=conv_xy_module,
+ norm=None,
+ activation=None,
+ conv_xy=conv_t_module,
+ )
+
+ norm_module = (
+ None
+ if norm is None
+ else norm(num_features=out_channels, eps=norm_eps, momentum=norm_momentum)
+ )
+ activation_module = None if activation is None else activation()
+
+ return ResNetBasicStem(
+ conv=stacked_conv_module,
+ norm=norm_module,
+ activation=activation_module,
+ pool=None,
+ )
+
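A quick shape check, as a sketch: with the defaults, the spatial conv halves H and W while the depthwise temporal conv keeps T.

    import torch
    from pytorchvideo.models.x3d import create_x3d_stem

    stem = create_x3d_stem(in_channels=3, out_channels=24)
    clip = torch.randn(1, 3, 13, 160, 160)  # (B, C, T, H, W)
    print(stem(clip).shape)  # torch.Size([1, 24, 13, 80, 80])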
+
+def create_x3d_bottleneck_block(
+ *,
+ # Convolution configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ conv_kernel_size: Tuple[int] = (3, 3, 3),
+ conv_stride: Tuple[int] = (1, 2, 2),
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ se_ratio: float = 0.0625,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ inner_act: Callable = Swish,
+) -> nn.Module:
+ """
+ Bottleneck block for X3D: a sequence of Conv, Normalization with optional SE block,
+ and Activations repeated in the following order:
+
+ ::
+
+ Conv3d (conv_a)
+ ↓
+ Normalization (norm_a)
+ ↓
+ Activation (act_a)
+ ↓
+ Conv3d (conv_b)
+ ↓
+ Normalization (norm_b)
+ ↓
+ Squeeze-and-Excitation
+ ↓
+ Activation (act_b)
+ ↓
+ Conv3d (conv_c)
+ ↓
+ Normalization (norm_c)
+
+ Args:
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_stride (tuple): convolutional stride size(s) for conv_b.
+
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+ se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
+ channel dimensionality being se_ratio times the 3x3x3 conv dim.
+
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+ inner_act (callable): a callable that constructs the activation layer for act_b (Swish by default).
+
+ Returns:
+ (nn.Module): X3D bottleneck block.
+ """
+ # 1x1x1 Conv
+ conv_a = nn.Conv3d(
+ in_channels=dim_in, out_channels=dim_inner, kernel_size=(1, 1, 1), bias=False
+ )
+ norm_a = (
+ None
+ if norm is None
+ else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ )
+ act_a = None if activation is None else activation()
+
+ # 3x3x3 Conv
+ conv_b = nn.Conv3d(
+ in_channels=dim_inner,
+ out_channels=dim_inner,
+ kernel_size=conv_kernel_size,
+ stride=conv_stride,
+ padding=[size // 2 for size in conv_kernel_size],
+ bias=False,
+ groups=dim_inner,
+ dilation=(1, 1, 1),
+ )
+ se = (
+ SqueezeExcitation(
+ num_channels=dim_inner,
+ num_channels_reduced=round_width(dim_inner, se_ratio),
+ is_3d=True,
+ )
+ if se_ratio > 0.0
+ else nn.Identity()
+ )
+ norm_b = nn.Sequential(
+ (
+ nn.Identity()
+ if norm is None
+ else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ ),
+ se,
+ )
+ act_b = None if inner_act is None else inner_act()
+
+ # 1x1x1 Conv
+ conv_c = nn.Conv3d(
+ in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
+ )
+ norm_c = (
+ None
+ if norm is None
+ else norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)
+ )
+
+ return BottleneckBlock(
+ conv_a=conv_a,
+ norm_a=norm_a,
+ act_a=act_a,
+ conv_b=conv_b,
+ norm_b=norm_b,
+ act_b=act_b,
+ conv_c=conv_c,
+ norm_c=norm_c,
+ )
+
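As an illustrative sketch, the bottleneck block in isolation with stride (1, 1, 1) preserves the input shape while channels go dim_in -> dim_inner -> dim_out internally (dimensions below are arbitrary):

    import torch
    from pytorchvideo.models.x3d import create_x3d_bottleneck_block

    block = create_x3d_bottleneck_block(dim_in=24, dim_inner=54, dim_out=24, conv_stride=(1, 1, 1))
    x = torch.randn(1, 24, 4, 32, 32)
    print(block(x).shape)  # torch.Size([1, 24, 4, 32, 32])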
+
+def create_x3d_res_block(
+ *,
+ # Bottleneck Block configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ bottleneck: Callable = create_x3d_bottleneck_block,
+ use_shortcut: bool = True,
+ # Conv configs.
+ conv_kernel_size: Tuple[int] = (3, 3, 3),
+ conv_stride: Tuple[int] = (1, 2, 2),
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ se_ratio: float = 0.0625,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ inner_act: Callable = Swish,
+) -> nn.Module:
+ """
+ Residual block for X3D. Performs a summation between an identity shortcut in branch1 and a
+ main block in branch2. When the input and output dimensions are different, a
+ convolution followed by a normalization will be performed.
+
+ ::
+
+ Input
+ |-------+
+ ↓ |
+ Block |
+ ↓ |
+ Summation ←-+
+ ↓
+ Activation
+
+ Args:
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ bottleneck (callable): a callable for create_x3d_bottleneck_block.
+
+ conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_stride (tuple): convolutional stride size(s) for conv_b.
+
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+ se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
+ channel dimensionality being se_ratio times the 3x3x3 conv dim.
+
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+ inner_act (callable): a callable that constructs the activation layer for act_b (Swish by default).
+
+ Returns:
+ (nn.Module): X3D block layer.
+ """
+
+ norm_model = None
+ if norm is not None and dim_in != dim_out:
+ norm_model = norm(num_features=dim_out)
+
+ return ResBlock(
+ branch1_conv=nn.Conv3d(
+ dim_in,
+ dim_out,
+ kernel_size=(1, 1, 1),
+ stride=conv_stride,
+ bias=False,
+ )
+ if (dim_in != dim_out or np.prod(conv_stride) > 1) and use_shortcut
+ else None,
+ branch1_norm=norm_model if dim_in != dim_out and use_shortcut else None,
+ branch2=bottleneck(
+ dim_in=dim_in,
+ dim_inner=dim_inner,
+ dim_out=dim_out,
+ conv_kernel_size=conv_kernel_size,
+ conv_stride=conv_stride,
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ se_ratio=se_ratio,
+ activation=activation,
+ inner_act=inner_act,
+ ),
+ activation=None if activation is None else activation(),
+ branch_fusion=lambda x, y: x + y,
+ )
+
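A sketch with changing channels and a spatial stride, so the conv + norm shortcut in branch1 is active (arbitrary dimensions):

    import torch
    from pytorchvideo.models.x3d import create_x3d_res_block

    res_block = create_x3d_res_block(dim_in=24, dim_inner=54, dim_out=48, conv_stride=(1, 2, 2))
    x = torch.randn(1, 24, 4, 32, 32)
    print(res_block(x).shape)  # torch.Size([1, 48, 4, 16, 16]); the shortcut matches 24 -> 48 and the stride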
+
+def create_x3d_res_stage(
+ *,
+ # Stage configs.
+ depth: int,
+ # Bottleneck Block configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ bottleneck: Callable = create_x3d_bottleneck_block,
+ # Conv configs.
+ conv_kernel_size: Tuple[int] = (3, 3, 3),
+ conv_stride: Tuple[int] = (1, 2, 2),
+ # Norm configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ se_ratio: float = 0.0625,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ inner_act: Callable = Swish,
+) -> nn.Module:
+ """
+ Create Residual Stage, which composes sequential blocks that make up X3D.
+
+ ::
+
+ Input
+ ↓
+ ResBlock
+ ↓
+ .
+ .
+ .
+ ↓
+ ResBlock
+
+ Args:
+
+ depth (int): number of blocks to create.
+
+ dim_in (int): input channel size to the bottleneck block.
+ dim_inner (int): intermediate channel size of the bottleneck.
+ dim_out (int): output channel size of the bottleneck.
+ bottleneck (callable): a callable for create_x3d_bottleneck_block.
+
+ conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ conv_stride (tuple): convolutional stride size(s) for conv_b.
+
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+ se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
+ channel dimensionality being se_ratio times the 3x3x3 conv dim.
+
+ activation (callable): a callable that constructs activation layer, examples
+ include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
+ activation).
+ inner_act (callable): a callable that constructs the activation layer for act_b (Swish by default).
+
+ Returns:
+ (nn.Module): X3D stage layer.
+ """
+ res_blocks = []
+ for idx in range(depth):
+ block = create_x3d_res_block(
+ dim_in=dim_in if idx == 0 else dim_out,
+ dim_inner=dim_inner,
+ dim_out=dim_out,
+ bottleneck=bottleneck,
+ conv_kernel_size=conv_kernel_size,
+ conv_stride=conv_stride if idx == 0 else (1, 1, 1),
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ se_ratio=(se_ratio if (idx + 1) % 2 else 0.0),
+ activation=activation,
+ inner_act=inner_act,
+ )
+ res_blocks.append(block)
+
+ return ResStage(res_blocks=nn.ModuleList(res_blocks))
+
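A sketch of a small stage: only the first block applies conv_stride, the remaining blocks keep the resolution (arbitrary dimensions):

    import torch
    from pytorchvideo.models.x3d import create_x3d_res_stage

    stage = create_x3d_res_stage(depth=3, dim_in=24, dim_inner=54, dim_out=24, conv_stride=(1, 2, 2))
    x = torch.randn(1, 24, 13, 80, 80)
    print(stage(x).shape)  # torch.Size([1, 24, 13, 40, 40])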
+
+def create_x3d_head(
+ *,
+ # Projection configs.
+ dim_in: int,
+ dim_inner: int,
+ dim_out: int,
+ num_classes: int,
+ # Pooling configs.
+ pool_act: Callable = nn.ReLU,
+ pool_kernel_size: Tuple[int] = (13, 5, 5),
+ # BN configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ bn_lin5_on=False,
+ # Dropout configs.
+ dropout_rate: float = 0.5,
+ # Activation configs.
+ activation: Callable = nn.Softmax,
+ # Output configs.
+ output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ Creates X3D head. This layer performs a projected pooling operation followed
+ by dropout, a fully-connected projection, an activation layer, and global
+ spatiotemporal averaging.
+
+ ::
+
+ ProjectedPool
+ ↓
+ Dropout
+ ↓
+ Projection
+ ↓
+ Activation
+ ↓
+ Averaging
+
+ Args:
+ dim_in (int): input channel size of the X3D head.
+ dim_inner (int): intermediate channel size of the X3D head.
+ dim_out (int): output channel size of the X3D head.
+ num_classes (int): the number of classes for the video dataset.
+
+ pool_act (callable): a callable that constructs resnet pool activation
+ layer such as nn.ReLU.
+ pool_kernel_size (tuple): pooling kernel size(s) when not using adaptive
+ pooling.
+
+ norm (callable): a callable that constructs normalization layer, examples
+ include nn.BatchNorm3d, None (not performing normalization).
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+ bn_lin5_on (bool): if True, perform normalization on the features
+ before the classifier.
+
+ dropout_rate (float): dropout rate.
+
+ activation (callable): a callable that constructs resnet head activation
+ layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not
+ applying activation).
+
+ output_with_global_average (bool): if True, perform global averaging on temporal
+ and spatial dimensions and reshape output to batch_size x out_features.
+
+ Returns:
+ (nn.Module): X3D head layer.
+ """
+ pre_conv_module = nn.Conv3d(
+ in_channels=dim_in, out_channels=dim_inner, kernel_size=(1, 1, 1), bias=False
+ )
+
+ pre_norm_module = norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
+ pre_act_module = None if pool_act is None else pool_act()
+
+ if pool_kernel_size is None:
+ pool_module = nn.AdaptiveAvgPool3d((1, 1, 1))
+ else:
+ pool_module = nn.AvgPool3d(pool_kernel_size, stride=1)
+
+ post_conv_module = nn.Conv3d(
+ in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
+ )
+
+ if bn_lin5_on:
+ post_norm_module = norm(
+ num_features=dim_out, eps=norm_eps, momentum=norm_momentum
+ )
+ else:
+ post_norm_module = None
+ post_act_module = None if pool_act is None else pool_act()
+
+ projected_pool_module = ProjectedPool(
+ pre_conv=pre_conv_module,
+ pre_norm=pre_norm_module,
+ pre_act=pre_act_module,
+ pool=pool_module,
+ post_conv=post_conv_module,
+ post_norm=post_norm_module,
+ post_act=post_act_module,
+ )
+
+ if activation is None:
+ activation_module = None
+ elif activation == nn.Softmax:
+ activation_module = activation(dim=1)
+ elif activation == nn.Sigmoid:
+ activation_module = activation()
+ else:
+ raise NotImplementedError(
+ "{} is not supported as an activation" "function.".format(activation)
+ )
+
+ if output_with_global_average:
+ output_pool = nn.AdaptiveAvgPool3d(1)
+ else:
+ output_pool = None
+
+ return ResNetBasicHead(
+ proj=nn.Linear(dim_out, num_classes, bias=True),
+ activation=activation_module,
+ pool=projected_pool_module,
+ dropout=nn.Dropout(dropout_rate) if dropout_rate > 0 else None,
+ output_pool=output_pool,
+ )
+
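A sketch of the head using the feature sizes that the default create_x3d configuration below ends up with (192 channels pooled over a 13 x 5 x 5 map); these numbers come from that configuration, not from this function:

    import torch
    from pytorchvideo.models.x3d import create_x3d_head

    head = create_x3d_head(
        dim_in=192, dim_inner=432, dim_out=2048, num_classes=400, pool_kernel_size=(13, 5, 5)
    )
    feats = torch.randn(1, 192, 13, 5, 5)
    print(head(feats).shape)  # torch.Size([1, 400])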
+
+def create_x3d(
+ *,
+ # Input clip configs.
+ input_channel: int = 3,
+ input_clip_length: int = 13,
+ input_crop_size: int = 160,
+ # Model configs.
+ model_num_class: int = 400,
+ dropout_rate: float = 0.5,
+ width_factor: float = 2.0,
+ depth_factor: float = 2.2,
+ # Normalization configs.
+ norm: Callable = nn.BatchNorm3d,
+ norm_eps: float = 1e-5,
+ norm_momentum: float = 0.1,
+ # Activation configs.
+ activation: Callable = nn.ReLU,
+ # Stem configs.
+ stem_dim_in: int = 12,
+ stem_conv_kernel_size: Tuple[int] = (5, 3, 3),
+ stem_conv_stride: Tuple[int] = (1, 2, 2),
+ # Stage configs.
+ stage_conv_kernel_size: Tuple[Tuple[int]] = (
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ ),
+ stage_spatial_stride: Tuple[int] = (2, 2, 2, 2),
+ stage_temporal_stride: Tuple[int] = (1, 1, 1, 1),
+ bottleneck: Callable = create_x3d_bottleneck_block,
+ bottleneck_factor: float = 2.25,
+ se_ratio: float = 0.0625,
+ inner_act: Callable = Swish,
+ # Head configs.
+ head_dim_out: int = 2048,
+ head_pool_act: Callable = nn.ReLU,
+ head_bn_lin5_on: bool = False,
+ head_activation: Callable = None,
+ head_output_with_global_average: bool = True,
+) -> nn.Module:
+ """
+ X3D model builder. It builds an X3D network backbone, which is a ResNet.
+
+ Christoph Feichtenhofer.
+ "X3D: Expanding Architectures for Efficient Video Recognition."
+ https://arxiv.org/abs/2004.04730
+
+ ::
+
+ Input
+ ↓
+ Stem
+ ↓
+ Stage 1
+ ↓
+ .
+ .
+ .
+ ↓
+ Stage N
+ ↓
+ Head
+
+ Args:
+ input_channel (int): number of channels for the input video clip.
+ input_clip_length (int): length of the input video clip. Value for
+ different models: X3D-XS: 4; X3D-S: 13; X3D-M: 16; X3D-L: 16.
+ input_crop_size (int): spatial resolution of the input video clip.
+ Value for different models: X3D-XS: 160; X3D-S: 160; X3D-M: 224;
+ X3D-L: 312.
+
+ model_num_class (int): the number of classes for the video dataset.
+ dropout_rate (float): dropout rate.
+ width_factor (float): width expansion factor.
+ depth_factor (float): depth expansion factor. Value for different
+ models: X3D-XS: 2.2; X3D-S: 2.2; X3D-M: 2.2; X3D-L: 5.0.
+
+ norm (callable): a callable that constructs normalization layer.
+ norm_eps (float): normalization epsilon.
+ norm_momentum (float): normalization momentum.
+
+ activation (callable): a callable that constructs activation layer.
+
+ stem_dim_in (int): input channel size for stem before expansion.
+ stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
+ stem_conv_stride (tuple): convolutional stride size(s) of stem.
+
+ stage_conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
+ stage_spatial_stride (tuple): the spatial stride for each stage.
+ stage_temporal_stride (tuple): the temporal stride for each stage.
+ bottleneck_factor (float): bottleneck expansion factor for the 3x3x3 conv.
+ se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
+ channel dimensionality being se_ratio times the 3x3x3 conv dim.
+ inner_act (callable): a callable that constructs the activation layer for act_b (Swish by default).
+
+ head_dim_out (int): output channel size of the X3D head.
+ head_pool_act (callable): a callable that constructs resnet pool activation
+ layer such as nn.ReLU.
+ head_bn_lin5_on (bool): if True, perform normalization on the features
+ before the classifier.
+ head_activation (callable): a callable that constructs activation layer.
+ head_output_with_global_average (bool): if True, perform global averaging on
+ the head output.
+
+ Returns:
+ (nn.Module): the X3D network.
+ """
+
+ torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_x3d")
+
+ blocks = []
+ # Create stem for X3D.
+ stem_dim_out = round_width(stem_dim_in, width_factor)
+ stem = create_x3d_stem(
+ in_channels=input_channel,
+ out_channels=stem_dim_out,
+ conv_kernel_size=stem_conv_kernel_size,
+ conv_stride=stem_conv_stride,
+ conv_padding=[size // 2 for size in stem_conv_kernel_size],
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ activation=activation,
+ )
+ blocks.append(stem)
+
+ # Compute the depth and dimension for each stage
+ stage_depths = [1, 2, 5, 3]
+ exp_stage = 2.0
+ stage_dim1 = stem_dim_in
+ stage_dim2 = round_width(stage_dim1, exp_stage, divisor=8)
+ stage_dim3 = round_width(stage_dim2, exp_stage, divisor=8)
+ stage_dim4 = round_width(stage_dim3, exp_stage, divisor=8)
+ stage_dims = [stage_dim1, stage_dim2, stage_dim3, stage_dim4]
+
+ dim_in = stem_dim_out
+ # Create each stage for X3D.
+ for idx in range(len(stage_depths)):
+ dim_out = round_width(stage_dims[idx], width_factor)
+ dim_inner = int(bottleneck_factor * dim_out)
+ depth = round_repeats(stage_depths[idx], depth_factor)
+
+ stage_conv_stride = (
+ stage_temporal_stride[idx],
+ stage_spatial_stride[idx],
+ stage_spatial_stride[idx],
+ )
+
+ stage = create_x3d_res_stage(
+ depth=depth,
+ dim_in=dim_in,
+ dim_inner=dim_inner,
+ dim_out=dim_out,
+ bottleneck=bottleneck,
+ conv_kernel_size=stage_conv_kernel_size[idx],
+ conv_stride=stage_conv_stride,
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ se_ratio=se_ratio,
+ activation=activation,
+ inner_act=inner_act,
+ )
+ blocks.append(stage)
+ dim_in = dim_out
+
+ # Create head for X3D.
+ total_spatial_stride = stem_conv_stride[1] * np.prod(stage_spatial_stride)
+ total_temporal_stride = stem_conv_stride[0] * np.prod(stage_temporal_stride)
+
+ assert (
+ input_clip_length >= total_temporal_stride
+ ), "Clip length doesn't match temporal stride!"
+ assert (
+ input_crop_size >= total_spatial_stride
+ ), "Crop size doesn't match spatial stride!"
+
+ head_pool_kernel_size = (
+ input_clip_length // total_temporal_stride,
+ int(math.ceil(input_crop_size / total_spatial_stride)),
+ int(math.ceil(input_crop_size / total_spatial_stride)),
+ )
+
+ head = create_x3d_head(
+ dim_in=dim_out,
+ dim_inner=dim_inner,
+ dim_out=head_dim_out,
+ num_classes=model_num_class,
+ pool_act=head_pool_act,
+ pool_kernel_size=head_pool_kernel_size,
+ norm=norm,
+ norm_eps=norm_eps,
+ norm_momentum=norm_momentum,
+ bn_lin5_on=head_bn_lin5_on,
+ dropout_rate=dropout_rate,
+ activation=head_activation,
+ output_with_global_average=head_output_with_global_average,
+ )
+ blocks.append(head)
+ return Net(blocks=nn.ModuleList(blocks))
+
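Putting the pieces together, a minimal sketch that builds the default network and runs a random clip through it (weights are randomly initialized, so the outputs are meaningless):

    import torch
    from pytorchvideo.models.x3d import create_x3d

    model = create_x3d(input_clip_length=13, input_crop_size=160)
    clip = torch.randn(1, 3, 13, 160, 160)  # (B, C, T, H, W)
    with torch.no_grad():
        print(model(clip).shape)  # torch.Size([1, 400]), the default model_num_class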
+
+class ProjectedPool(nn.Module):
+ """
+ A pooling module augmented with Conv, Normalization and Activation both
+ before and after pooling for the head layer of X3D.
+
+ ::
+
+ Conv3d (pre_conv)
+ ↓
+ Normalization (pre_norm)
+ ↓
+ Activation (pre_act)
+ ↓
+ Pool3d
+ ↓
+ Conv3d (post_conv)
+ ↓
+ Normalization (post_norm)
+ ↓
+ Activation (post_act)
+ """
+
+ def __init__(
+ self,
+ *,
+ pre_conv: nn.Module = None,
+ pre_norm: nn.Module = None,
+ pre_act: nn.Module = None,
+ pool: nn.Module = None,
+ post_conv: nn.Module = None,
+ post_norm: nn.Module = None,
+ post_act: nn.Module = None,
+ ) -> None:
+ """
+ Args:
+ pre_conv (torch.nn.modules): convolutional module.
+ pre_norm (torch.nn.modules): normalization module.
+ pre_act (torch.nn.modules): activation module.
+ pool (torch.nn.modules): pooling module.
+ post_conv (torch.nn.modules): convolutional module.
+ post_norm (torch.nn.modules): normalization module.
+ post_act (torch.nn.modules): activation module.
+ """
+ super().__init__()
+ set_attributes(self, locals())
+ assert self.pre_conv is not None
+ assert self.pool is not None
+ assert self.post_conv is not None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.pre_conv(x)
+
+ if self.pre_norm is not None:
+ x = self.pre_norm(x)
+ if self.pre_act is not None:
+ x = self.pre_act(x)
+
+ x = self.pool(x)
+ x = self.post_conv(x)
+
+ if self.post_norm is not None:
+ x = self.post_norm(x)
+ if self.post_act is not None:
+ x = self.post_act(x)
+ return x
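A standalone sketch of ProjectedPool with the same head-sized numbers; post_norm and post_act may be left as None, only pre_conv, pool and post_conv are required:

    import torch
    import torch.nn as nn
    from pytorchvideo.models.x3d import ProjectedPool

    pool = ProjectedPool(
        pre_conv=nn.Conv3d(192, 432, kernel_size=1, bias=False),
        pre_norm=nn.BatchNorm3d(432),
        pre_act=nn.ReLU(),
        pool=nn.AvgPool3d((13, 5, 5), stride=1),
        post_conv=nn.Conv3d(432, 2048, kernel_size=1, bias=False),
    )
    print(pool(torch.randn(1, 192, 13, 5, 5)).shape)  # torch.Size([1, 2048, 1, 1, 1])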
diff --git a/code/pytorchvideo/pytorchvideo/neural_engine/detection_hook.py b/code/pytorchvideo/pytorchvideo/neural_engine/detection_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d7e5caa2a571faf233abdf3e46be924efea409e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/neural_engine/detection_hook.py
@@ -0,0 +1,152 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from collections import OrderedDict
+from typing import Callable
+
+import cv2
+import torch
+from hook import HookBase
+
+
+try:
+ from detectron2 import model_zoo
+ from detectron2.config import get_cfg
+ from detectron2.engine import DefaultPredictor
+except Exception as _:
+ raise ImportError(
+ "Install detectron2: https://detectron2.readthedocs.io/en/latest/tutorials/install.html"
+ )
+
+model_config = {
+ "backend": "detectron2",
+ "model": "COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml",
+ "threshold": 0.7,
+}
+
+
+def generate_predictor(model_config, *args):
+ if model_config["backend"] == "detectron2":
+ cfg = get_cfg()
+ cfg.MODEL.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ cfg.merge_from_file(model_zoo.get_config_file(model_config["model"]))
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = model_config["threshold"]
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_config["model"])
+
+ predictor = DefaultPredictor(
+ cfg,
+ )
+ else:
+ raise ValueError("Incorrect backend.")
+
+ return predictor
+
+
+def people_keypoints_executor(image, predictor):
+ return predictor(image)
+
+
+class PeopleKeypointDetectionHook(HookBase):
+ """
+ Performs keypoint detection for humans.
+
+ Args:
+ model_config (dict): configuration for the model. The dict-keys are
+ "backend", "model", and "threshold".
+ executor: function that generates predictions.
+ """
+
+ def __init__(
+ self,
+ model_config: dict = model_config,
+ executor: Callable = people_keypoints_executor,
+ ):
+ self.executor = executor
+ self.model_config = model_config
+ self.inputs = ["loaded_image", "bbox_coordinates"]
+ self.outputs = ["keypoint_coordinates"]
+
+ # generate different predictors for different backends.
+ self.predictor = generate_predictor(model_config=self.model_config)
+
+ def _run(self, status: OrderedDict):
+ inputs = status["loaded_image"]
+ outputs = self.executor(image=inputs, predictor=self.predictor)
+
+ if model_config["backend"] == "detectron2":
+ # keypoints is a tensor of shape (num_people, num_keypoint, (x, y, score))
+ keypoints = outputs["instances"][
+ outputs["instances"].pred_classes == 0
+ ].pred_keypoints
+
+ return {"keypoint_coordinates": keypoints}
+
+
+def image_load_executor(image_path):
+ # Returns a NumPy array of shape (H, W, C) and dtype uint8.
+ return cv2.imread(image_path)
+
+
+class ImageLoadHook(HookBase):
+ def __init__(self, executor: Callable = image_load_executor):
+ self.executor = executor
+ self.inputs = ["image_path"]
+ self.outputs = ["loaded_image"]
+
+ def _run(self, status: OrderedDict):
+ inputs = status["image_path"]
+ image_arr = self.executor(image_path=inputs)
+
+ return {"loaded_image": image_arr}
+
+
+def people_detection_executor(loaded_image, predictor):
+ # Returns a detectron2.structures.Boxes object
+ # that stores a list of boxes as an Nx4 torch.Tensor.
+ outputs = predictor(loaded_image)
+
+ people_bbox = outputs["instances"][
+ outputs["instances"].pred_classes == 0
+ ].pred_boxes
+
+ return people_bbox
+
+
+det_models = {
+ "faster_rcnn_R_50_C4": "COCO-Detection/faster_rcnn_R_50_C4_3x.yaml",
+ "faster_rcnn_R_50_FPN": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
+}
+
+
+class Detectron2PeopleDetectionHook(HookBase):
+ def __init__(
+ self,
+ executor: Callable = people_detection_executor,
+ model_name: str = "faster_rcnn_R_50_C4",
+ threshold=0.7,
+ ):
+ self.inputs = ["loaded_image"]
+ self.outputs = ["bbox_coordinates"]
+ self.executor = executor
+
+ # Configure detectron2
+ self.cfg = get_cfg()
+ self.model_config = det_models[model_name]
+ self.cfg.merge_from_file(model_zoo.get_config_file(self.model_config))
+ self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.model_config)
+ self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
+
+ if not torch.cuda.is_available():
+ self.cfg.MODEL.DEVICE = "cpu"
+
+ self.predictor = DefaultPredictor(self.cfg)
+
+ def _run(
+ self,
+ status,
+ ):
+ inputs = status["loaded_image"]
+ bbox_coordinates = self.executor(
+ loaded_image=inputs,
+ predictor=self.predictor,
+ )
+ return {"bbox_coordinates": bbox_coordinates}
diff --git a/code/pytorchvideo/pytorchvideo/neural_engine/engine.py b/code/pytorchvideo/pytorchvideo/neural_engine/engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..06347a7ce16d855f81bdc8d34eb2f1adb909428e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/neural_engine/engine.py
@@ -0,0 +1,77 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+from collections import OrderedDict
+from typing import List, Union
+
+import networkx as nx
+from pytorchvideo.neural_engine import HookBase
+
+
+class NeuralEngine:
+ """
+ NeuralEngine takes a list of hooks and executes them in their topological order. The
+ topological order of the hooks is determined by their required inputs and outputs.
+ """
+
+ def __init__(self, hooks: List[HookBase]) -> None:
+ self.hooks = hooks
+ self.execution_order_func = NeuralEngine.topological_sort
+
+ def get_execution_order(self, status):
+ return self.execution_order_func(status, self.hooks)
+
+ def set_execution_order_func(self, func):
+ self.execution_order_func = func
+
+ @staticmethod
+ def topological_sort(status, hooks):
+ # Get DAG
+ graph = nx.DiGraph()
+ edges = []
+ pending_outputs = []
+ output_to_hook = {}
+ for hook in hooks:
+ for pair in itertools.product(hook.get_inputs(), hook.get_outputs()):
+ edges.append(pair)
+ for output in hook.get_outputs():
+ assert output not in pending_outputs
+ output_to_hook[output] = hook
+ pending_outputs.append(output)
+ graph.add_edges_from(edges)
+ for _current in nx.topological_sort(graph):
+ if _current in pending_outputs:
+ _hook = output_to_hook[_current]
+ yield _hook
+ for _hook_out in _hook.get_outputs():
+ pending_outputs.remove(_hook_out)
+ else:
+ assert _current in status
+ assert len(pending_outputs) == 0
+
+ def run(self, status: OrderedDict):
+ for hook in self.get_execution_order(status):
+ status.update(hook.run(status))
+ return status
+
+ def __enter__(
+ self,
+ ):
+ return self
+
+ def __exit__(
+ self,
+ type,
+ value,
+ traceback,
+ ):
+ pass
+
+ def __call__(
+ self,
+ status: Union[OrderedDict, str],
+ ):
+ # If not specified, the default input should be the path to video.
+ if isinstance(status, str):
+ status = {"path": status}
+ return self.run(status)
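A small sketch of how the topological sort orders hooks purely from their declared inputs and outputs, run in this module's namespace; StubHook is a hypothetical minimal hook defined only for the illustration:

    class StubHook(HookBase):
        def __init__(self, inputs, outputs):
            self.inputs, self.outputs = inputs, outputs

    decode = StubHook(inputs=["path"], outputs=["video"])
    classify = StubHook(inputs=["video"], outputs=["action_class"])
    # "path" is already present in status, so decode must be scheduled before classify.
    order = list(NeuralEngine.topological_sort({"path": "clip.mp4"}, [classify, decode]))
    print(order == [decode, classify])  # True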
diff --git a/code/pytorchvideo/pytorchvideo/neural_engine/hook.py b/code/pytorchvideo/pytorchvideo/neural_engine/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cdc2238327c38df2d926541c09c1df367903eb8
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/neural_engine/hook.py
@@ -0,0 +1,154 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+
+from collections import OrderedDict
+from typing import Callable, List
+
+import attr
+import torch
+from pytorchvideo.data.decoder import DecoderType
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.transforms import (
+ ApplyTransformToKey,
+ ShortSideScale,
+ UniformTemporalSubsample,
+)
+from torchvision.transforms import Compose, Lambda
+from torchvision.transforms._transforms_video import CenterCropVideo, NormalizeVideo
+
+
+FAIL_STRATEGY = ("RANDOM_FILL", "ZERO_FILL", "RETURN_NONE", "RAISE_ERROR")
+HOOK_STATUS = ("PENDING", "SCHEDULED", "EXECUTING", "EXECUTED", "FAILED", "EARLY_EXIT")
+
+
+@attr.s(repr=True)
+class HookBase:
+ """
+ HookBase contains the basic attributes of a hook.
+ """
+
+ executor: Callable = attr.ib()
+ conditional_execution_func: Callable = attr.ib()
+ exit_func: Callable = attr.ib()
+ inputs: List[str] = attr.ib(default=())
+ outputs: List[str] = attr.ib(default=())
+ fail_strategy: str = attr.ib(
+ default="RAISE_ERROR",
+ validator=lambda self_, attr_, val_: (val_) in FAIL_STRATEGY,
+ )
+ priority: int = attr.ib(
+ default=1,
+ validator=lambda self_, attr_, val_: val_ >= 1,
+ )
+ status: str = "PENDING"
+
+ def run(
+ self,
+ status: OrderedDict,
+ ):
+ outputs = {}
+ if self.conditional_execution_func():
+ # Keep the hook outputs so NeuralEngine.run can merge them into its status dict.
+ outputs = self._run(status)
+ self.exit_func()
+ return outputs
+
+ def _run(
+ self,
+ status: OrderedDict,
+ ):
+ pass
+
+ def get_inputs(
+ self,
+ ):
+ return self.inputs
+
+ def get_outputs(
+ self,
+ ):
+ return self.outputs
+
+
+def full_decode(status: OrderedDict, **args):
+ decoder = args.get("decoder", DecoderType.PYAV)
+ decode_audio = args.get("decode_audio", True)
+ video = EncodedVideo.from_path(status["path"], decode_audio, decoder)
+ frames = video.get_clip(0, video.duration)
+ return frames
+
+
+class DecodeHook(HookBase):
+ def __init__(
+ self,
+ executor: Callable = full_decode,
+ decode_audio: bool = True,
+ decoder: str = DecoderType.PYAV,
+ fail_strategy="RAISE_ERROR",
+ priority=0,
+ ):
+ # Decoding params
+ self.decode_audio = decode_audio
+ self.decoder = decoder
+ # Hook params
+ self.executor = executor
+ self.inputs = ["path"]
+ self.outputs = ["video", "audio"] if decode_audio else ["video"]
+ self.fail_strategy = fail_strategy
+ self.priority = priority
+
+ def _run(
+ self,
+ status: OrderedDict,
+ ):
+ frames = self.executor(
+ status, decode_audio=self.decode_audio, decoder=self.decoder
+ )
+ return frames
+
+
+class X3DClsHook(HookBase):
+ def __init__(
+ self,
+ executor: Callable = full_decode,
+ fail_strategy="RAISE_ERROR",
+ priority=0,
+ ):
+ # Hook params
+ self.executor = executor
+ self.inputs = ["video"]
+ self.outputs = ["action_class"]
+ self.fail_strategy = fail_strategy
+ self.priority = priority
+
+ side_size = 256
+ mean = [0.45, 0.45, 0.45]
+ std = [0.225, 0.225, 0.225]
+ crop_size = 256
+ num_frames = 32
+ model = "x3d_s"
+
+ self.transform = ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ [
+ UniformTemporalSubsample(num_frames),
+ Lambda(lambda x: x / 255.0),
+ NormalizeVideo(mean, std),
+ ShortSideScale(size=side_size),
+ CenterCropVideo(crop_size),
+ ]
+ ),
+ )
+ # Init network
+ self.model = torch.hub.load(
+ "facebookresearch/pytorchvideo", model=model, pretrained=True
+ )
+ self.model = self.model.eval()
+
+ def _run(
+ self,
+ status: OrderedDict,
+ ):
+ status = self.transform(status)
+ inputs = status["video"]
+ inputs = inputs[None, ...]
+ output = self.model(inputs)
+ return {"action_class": output}
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__init__.py b/code/pytorchvideo/pytorchvideo/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e08e45b271337f44d02fe86b3bb2e7addb4da21
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .augmix import AugMix # noqa
+from .mix import CutMix, MixUp, MixVideo # noqa
+from .rand_augment import RandAugment # noqa
+from .transforms import * # noqa
+from .transforms_factory import create_video_transform # noqa
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/__init__.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f59d06bcfec426a07d0fff65e2b5d3e2e510eba
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/__init__.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/augmentations.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/augmentations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e929b127b44cc2553d0c8687ee6711d959afd868
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/augmentations.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/augmix.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/augmix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..891e949a487c7cae73ee93b514d8f81636a43288
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/augmix.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/functional.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/functional.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2524f4890222b94d2ba7e17a659efd617933ba76
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/functional.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/mix.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/mix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b608d49c5a5282f2b1e97e25c8317e1027e563db
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/mix.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/rand_augment.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/rand_augment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69e0f2552c402eb29ed52f07886d221436c31cef
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/rand_augment.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/transforms.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4d5129baf3e74c5414fb1c7784c2da517fa9a29
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/transforms.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/__pycache__/transforms_factory.cpython-310.pyc b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/transforms_factory.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..697dc76ff1e45d8b5fdae4baa4208e51e440c5b6
Binary files /dev/null and b/code/pytorchvideo/pytorchvideo/transforms/__pycache__/transforms_factory.cpython-310.pyc differ
diff --git a/code/pytorchvideo/pytorchvideo/transforms/augmentations.py b/code/pytorchvideo/pytorchvideo/transforms/augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b4a6c584984e2bddf9de1460f8e923ac53f64a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/augmentations.py
@@ -0,0 +1,482 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""Video transforms that are used for advanced augmentation methods."""
+
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import torch
+import torchvision
+import torchvision.transforms.functional_tensor as F_t
+from torchvision.transforms.functional import InterpolationMode
+
+
+# Maximum global magnitude used for video augmentation.
+_AUGMENTATION_MAX_LEVEL = 10
+
+
+def _check_fill_arg(kwargs):
+ """
+ Check if kwargs contains key ``fill``.
+ """
+ assert "fill" in kwargs, "Need to have fill in kwargs."
+
+
+def _autocontrast(video: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Maximize contrast of a video by remapping its pixels per channel so that the lowest
+ becomes black and the lightest becomes white.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ """
+ return torchvision.transforms.functional.autocontrast(video)
+
+
+def _equalize(video: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Equalize the histogram of a video by applying a non-linear mapping to the input in
+ order to create a uniform distribution of grayscale values in the output.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ """
+ if video.dtype != torch.uint8:
+ video_type = video.dtype
+ video = (video * 255).to(torch.uint8)
+ return (torchvision.transforms.functional.equalize(video) / 255).to(video_type)
+ return torchvision.transforms.functional.equalize(video)
+
+
+def _invert(video: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Invert the colors of a video.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ """
+ return torchvision.transforms.functional.invert(video)
+
+
+def _rotate(video: torch.Tensor, factor: float, **kwargs) -> torch.Tensor:
+ """
+ Rotate the image by angle.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): The rotation angle value in degrees, counter-clockwise.
+ """
+ _check_fill_arg(kwargs)
+ return torchvision.transforms.functional.rotate(
+ video, factor, fill=kwargs["fill"], interpolation=InterpolationMode.BILINEAR
+ )
+
+
+def _solarize(video: torch.Tensor, factor: float, **kwargs) -> torch.Tensor:
+ """
+ Solarize a video by inverting all pixel values above a threshold.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ """
+ if video.dtype == torch.uint8:
+ return torchvision.transforms.functional.solarize(video, int(factor * 255.0))
+ else:
+ return torchvision.transforms.functional.solarize(video, factor)
+
+
+def _adjust_contrast(video: torch.Tensor, factor: float, **kwargs) -> torch.Tensor:
+ """
+ Adjust the contrast of a video.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much to adjust the contrast. Can be any non-negative
+ number. 0 gives a solid gray video, 1 gives the original video while 2
+ increases the contrast by a factor of 2.
+ """
+ return torchvision.transforms.functional.adjust_contrast(video, factor)
+
+
+def _adjust_saturation(video: torch.Tensor, factor: float, **kwargs) -> torch.Tensor:
+ """
+ Adjust the saturation of a video.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much to adjust the saturation. 0 will give a black and
+ white video, 1 will give the original video while 2 will enhance the
+ saturation by a factor of 2.
+ """
+ return torchvision.transforms.functional.adjust_saturation(video, factor)
+
+
+def _adjust_brightness(video: torch.Tensor, factor: float, **kwargs) -> torch.Tensor:
+ """
+ Adjust brightness of a video.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much to adjust the brightness. Can be any
+ non-negative number. 0 gives a black video, 1 gives the original video
+ while 2 increases the brightness by a factor of 2.
+ """
+ return torchvision.transforms.functional.adjust_brightness(video, factor)
+
+
+def _adjust_sharpness(video: torch.Tensor, factor: float, **kwargs) -> torch.Tensor:
+ """
+ Adjust the sharpness of a video.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much to adjust the sharpness. Can be any non-negative
+ number. 0 gives a blurred video, 1 gives the original video while 2
+ increases the sharpness by a factor of 2.
+ """
+ return torchvision.transforms.functional.adjust_sharpness(video, factor)
+
+
+def _posterize(video: torch.Tensor, factor: float, **kwargs):
+ """
+ Posterize a video by reducing the number of bits for each color channel.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): The number of bits to keep for each channel (0-8).
+ """
+ if factor >= 8:
+ return video
+ if video.dtype != torch.uint8:
+ video_type = video.dtype
+ video = (video * 255).to(torch.uint8)
+ return (torchvision.transforms.functional.posterize(video, factor) / 255).to(
+ video_type
+ )
+ return torchvision.transforms.functional.posterize(video, factor)
+
+
+def _shear_x(video: torch.Tensor, factor: float, **kwargs):
+ """
+ Shear the video along the horizontal axis.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much to shear along the horizontal axis using the affine
+ matrix.
+ """
+ _check_fill_arg(kwargs)
+ translation_offset = video.size(-2) * factor / 2
+ return F_t.affine(
+ video,
+ [1, factor, translation_offset, 0, 1, 0],
+ fill=kwargs["fill"],
+ interpolation="bilinear",
+ )
+
+
+def _shear_y(video: torch.Tensor, factor: float, **kwargs):
+ """
+ Shear the video along the vertical axis.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much to shear along the vertical axis using the affine
+ matrix.
+ """
+ _check_fill_arg(kwargs)
+ translation_offset = video.size(-1) * factor / 2
+ return F_t.affine(
+ video,
+ [1, 0, 0, factor, 1, translation_offset],
+ fill=kwargs["fill"],
+ interpolation="bilinear",
+ )
+
+
+def _translate_x(video: torch.Tensor, factor: float, **kwargs):
+ """
+ Translate the video along the horizontal axis.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much (relative to the image size) to translate along the
+ horizontal axis.
+ """
+ _check_fill_arg(kwargs)
+ translation_offset = factor * video.size(-1)
+ return F_t.affine(
+ video,
+ [1, 0, translation_offset, 0, 1, 0],
+ fill=kwargs["fill"],
+ interpolation="bilinear",
+ )
+
+
+def _translate_y(video: torch.Tensor, factor: float, **kwargs):
+ """
+ Translate the video along the vertical axis.
+
+ Args:
+ video (torch.Tensor): Video tensor with shape (T, C, H, W).
+ factor (float): How much (relative to the image size) to translate along the
+ vertical axis.
+ """
+ _check_fill_arg(kwargs)
+ translation_offset = factor * video.size(-2)
+ return F_t.affine(
+ video,
+ [1, 0, 0, 0, 1, translation_offset],
+ fill=kwargs["fill"],
+ interpolation="bilinear",
+ )
+
+
+def _randomly_negate(magnitude: float) -> float:
+ """
+ Negate input value with 50% chance.
+
+ Args:
+ magnitude (float): Input value.
+ """
+ return magnitude if torch.rand(1).item() > 0.5 else -magnitude
+
+
+def _increasing_magnitude_to_arg(level: int, params: Tuple[float, float]) -> Tuple[float]:
+ """
+ Convert level to transform magnitude. This assumes transform magnitude increases
+ linearly with level.
+
+ Args:
+ level (int): Level value.
+ params (Tuple[float, float]): Params contains two values: 1) Base transform
+ magnitude when level is 0; 2) Maximum increase in transform magnitude
+ when level is at maximum.
+ """
+ magnitude = (level / _AUGMENTATION_MAX_LEVEL) * params[1]
+ return (params[0] + magnitude,)
+
+
+def _increasing_randomly_negate_to_arg(
+ level: int, params: Tuple[float, float]
+) -> Tuple[float]:
+ """
+ Convert level to transform magnitude. This assumes transform magnitude increases
+ (or decreases with 50% chance) linearly with level.
+
+ Args:
+ level (int): Level value.
+ params (Tuple[float, float]): Params contains two values: 1) Base transform
+ magnitude when level is 0; 2) Maximum increase in transform magnitude
+ when level is at maximum.
+ """
+ magnitude = (level / _AUGMENTATION_MAX_LEVEL) * params[1]
+ return (params[0] + _randomly_negate(magnitude),)
+
+
+def _decreasing_int_to_arg(level: int, params: Tuple[int, int]) -> Tuple[int]:
+ """
+ Convert level to transform magnitude. This assumes transform magnitude decreases
+ linearly with level. The return value is converted to int.
+
+ Args:
+ level (int): Level value.
+ params (Tuple[float, float]): Params contains two values: 1) Base transform
+ magnitude when level is 0; 2) Maximum decrease in transform magnitude
+ when level is at maximum.
+ """
+ magnitude = (level / _AUGMENTATION_MAX_LEVEL) * params[1]
+ return (params[0] - int(magnitude),)
+
+
+def _decreasing_to_arg(level: int, params: Tuple[float, float]) -> Tuple[float]:
+ """
+ Convert level to transform magnitude. This assumes transform magnitude decreases
+ linearly with level.
+
+ Args:
+ level (int): Level value.
+ params (Tuple[float, float]): Params contains two values: 1) Base transform
+ magnitude when level is 0; 2) Maximum decrease in transform magnitude
+ when level is at maximum.
+ """
+ magnitude = (level / _AUGMENTATION_MAX_LEVEL) * params[1]
+ return (params[0] - magnitude,)
+
+
+# A dictionary that contains transform names (key) and their corresponding transform
+# functions (value).
+_NAME_TO_TRANSFORM_FUNC = {
+ "AdjustBrightness": _adjust_brightness,
+ "AdjustContrast": _adjust_contrast,
+ "AdjustSaturation": _adjust_saturation,
+ "AdjustSharpness": _adjust_sharpness,
+ "AutoContrast": _autocontrast,
+ "Equalize": _equalize,
+ "Invert": _invert,
+ "Rotate": _rotate,
+ "Posterize": _posterize,
+ "Solarize": _solarize,
+ "ShearX": _shear_x,
+ "ShearY": _shear_y,
+ "TranslateX": _translate_x,
+ "TranslateY": _translate_y,
+}
+
+# A dictionary that contains transform names (key) and their corresponding level
+# functions (value), which converts the magnitude to the transform function arguments.
+_LEVEL_TO_ARG = {
+ "AdjustBrightness": _increasing_randomly_negate_to_arg,
+ "AdjustContrast": _increasing_randomly_negate_to_arg,
+ "AdjustSaturation": _increasing_randomly_negate_to_arg,
+ "AdjustSharpness": _increasing_randomly_negate_to_arg,
+ "AutoContrast": None,
+ "Equalize": None,
+ "Invert": None,
+ "Rotate": _increasing_randomly_negate_to_arg,
+ "Posterize": _decreasing_int_to_arg,
+ "Solarize": _decreasing_to_arg,
+ "ShearX": _increasing_randomly_negate_to_arg,
+ "ShearY": _increasing_randomly_negate_to_arg,
+ "TranslateX": _increasing_randomly_negate_to_arg,
+ "TranslateY": _increasing_randomly_negate_to_arg,
+}
+
+# A dictionary that contains transform names (key) and their corresponding maximum
+# transform (value).
+_TRANSFORM_MAX_PARAMS = {
+ "AdjustBrightness": (1, 0.9),
+ "AdjustContrast": (1, 0.9),
+ "AdjustSaturation": (1, 0.9),
+ "AdjustSharpness": (1, 0.9),
+ "AutoContrast": None,
+ "Equalize": None,
+ "Invert": None,
+ "Rotate": (0, 30),
+ "Posterize": (4, 4),
+ "Solarize": (1, 1),
+ "ShearX": (0, 0.3),
+ "ShearY": (0, 0.3),
+ "TranslateX": (0, 0.45),
+ "TranslateY": (0, 0.45),
+}
+
+# Hyperparameters for sampling magnitude.
+SAMPLING_DEFAULT_HPARAS = {"sampling_std": 0.5}
+
+# Hyperparameters for transform functions.
+TRANSFORM_DEFAULT_HPARAS = {"fill": (0.5, 0.5, 0.5)}
+
+
+class AugmentTransform:
+ def __init__(
+ self,
+ transform_name: str,
+ magnitude: int = 10,
+ prob: float = 0.5,
+ name_to_transform_func: Optional[Dict[str, Callable]] = None,
+ level_to_arg: Optional[Dict[str, Callable]] = None,
+ transform_max_paras: Optional[Dict[str, Tuple]] = None,
+ transform_hparas: Optional[Dict[str, Any]] = None,
+ sampling_type: str = "gaussian",
+ sampling_hparas: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """
+ The AugmentTransform composes a video transform that performs augmentation
+ based on a maximum magnitude. AugmentTransform also offers flexible ways to
+ generate augmentation magnitude based on different sampling strategies.
+
+ Args:
+ transform_name (str): The name of the video transform function.
+ magnitude (int): Magnitude used for transform function.
+ prob (float): The probability of applying each transform function.
+ name_to_transform_func (Optional[Dict[str, Callable]]): A Dictionary that
+ contains mapping of the transform name to the transform function.
+ level_to_arg (Optional[Dict[str, Callable]]): A Dictionary that contains
+ mapping of the transform name to its level function, which converts
+ the magnitude to the transform function arguments.
+ transform_max_paras (Optional[Dict[str, Tuple]]): A Dictionary that
+ contains mapping of the transform name to its maximum transform
+ magnitude.
+ transform_hparas (Optional[Dict[Any]]): Transform hyper parameters.
+ Needs to have key fill. By default, it uses transform_default_hparas.
+ sampling_type (str): Sampling method for magnitude of transform. It should
+ be either gaussian or uniform.
+ sampling_hparas (Optional[Dict[Any]]): Hyper parameters for sampling. If
+ gaussian sampling is used, it needs to have key sampling_std. By
+ default, it uses transform_default_hparas.
+ """
+
+ assert sampling_type in ["gaussian", "uniform"]
+ name_to_transform_func = name_to_transform_func or _NAME_TO_TRANSFORM_FUNC
+ level_to_arg = level_to_arg or _LEVEL_TO_ARG
+ transform_max_paras = transform_max_paras or _TRANSFORM_MAX_PARAMS
+ self.transform_hparas = transform_hparas or TRANSFORM_DEFAULT_HPARAS
+ self.sampling_type = sampling_type
+ self.sampling_hparas = sampling_hparas or SAMPLING_DEFAULT_HPARAS
+ assert "fill" in self.transform_hparas
+ if self.sampling_type == "gaussian":
+ assert "sampling_std" in self.sampling_hparas
+ if self.sampling_type == "uniform":
+ assert "sampling_data_type" in self.sampling_hparas
+ assert "sampling_min" in self.sampling_hparas
+ if self.sampling_hparas["sampling_data_type"] == "int":
+ assert isinstance(self.sampling_hparas["sampling_min"], int)
+ elif self.sampling_hparas["sampling_data_type"] == "float":
+ assert isinstance(self.sampling_hparas["sampling_min"], (int, float))
+ assert transform_name in name_to_transform_func
+
+ self.max_level = _AUGMENTATION_MAX_LEVEL
+ self.transform_name = transform_name
+ self.magnitude = magnitude
+ self.transform_fn = name_to_transform_func[transform_name]
+ self.level_fn = level_to_arg[transform_name]
+ self.level_paras = transform_max_paras[transform_name]
+ self.prob = prob
+ self.sampling_type = sampling_type
+
+ def _get_magnitude(self) -> float:
+ """
+ Get magnitude based on sampling type.
+ """
+ if self.sampling_type == "gaussian":
+ return max(
+ 0,
+ min(
+ self.max_level,
+ torch.normal(
+ self.magnitude, self.sampling_hparas["sampling_std"], size=(1,)
+ ).item(),
+ ),
+ )
+ elif self.sampling_type == "uniform":
+ if self.sampling_hparas["sampling_data_type"] == "int":
+ return torch.randint(
+ self.sampling_hparas["sampling_min"], self.magnitude + 1, size=(1,)
+ ).item()
+ elif self.sampling_hparas["sampling_data_type"] == "float":
+ return (
+ torch.rand(size=(1,)).item()
+ * (self.magnitude - self.sampling_hparas["sampling_min"])
+ + self.sampling_hparas["sampling_min"]
+ )
+ else:
+ raise ValueError("sampling_data_type must be either 'int' or 'float'")
+ else:
+ raise NotImplementedError
+
+ def __call__(self, video: torch.Tensor) -> torch.Tensor:
+ """
+ The input is a video tensor.
+
+ Args:
+ video (torch.Tensor): Input video tensor with shape (T, C, H, W).
+ """
+ if torch.rand(1).item() > self.prob:
+ return video
+ magnitude = self._get_magnitude()
+ level_args = (
+ self.level_fn(magnitude, self.level_paras)
+ if self.level_fn is not None
+ else ()
+ )
+ return self.transform_fn(video, *level_args, **self.transform_hparas)
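A sketch of a single magnitude-controlled transform: with prob=1.0 the op is always applied, and for "Rotate" the sampled level maps to an angle of roughly +/- (magnitude / 10) * 30 degrees:

    import torch
    from pytorchvideo.transforms.augmentations import AugmentTransform

    rotate = AugmentTransform("Rotate", magnitude=5, prob=1.0)
    video = torch.rand(8, 3, 64, 64)  # (T, C, H, W) in [0, 1]
    print(rotate(video).shape)        # torch.Size([8, 3, 64, 64])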
diff --git a/code/pytorchvideo/pytorchvideo/transforms/augmix.py b/code/pytorchvideo/pytorchvideo/transforms/augmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f235c8a2d313553517068bb5d00fc307cb284c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/augmix.py
@@ -0,0 +1,147 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Dict, Optional
+
+import torch
+from pytorchvideo.transforms.augmentations import (
+ _AUGMENTATION_MAX_LEVEL,
+ _decreasing_int_to_arg,
+ _decreasing_to_arg,
+ _increasing_magnitude_to_arg,
+ _increasing_randomly_negate_to_arg,
+ AugmentTransform,
+)
+from pytorchvideo.transforms.transforms import OpSampler
+
+
+_AUGMIX_LEVEL_TO_ARG = {
+ "AutoContrast": None,
+ "Equalize": None,
+ "Rotate": _increasing_randomly_negate_to_arg,
+ "Posterize": _decreasing_int_to_arg,
+ "Solarize": _decreasing_to_arg,
+ "ShearX": _increasing_randomly_negate_to_arg,
+ "ShearY": _increasing_randomly_negate_to_arg,
+ "TranslateX": _increasing_randomly_negate_to_arg,
+ "TranslateY": _increasing_randomly_negate_to_arg,
+ "AdjustSaturation": _increasing_magnitude_to_arg,
+ "AdjustContrast": _increasing_magnitude_to_arg,
+ "AdjustBrightness": _increasing_magnitude_to_arg,
+ "AdjustSharpness": _increasing_magnitude_to_arg,
+}
+
+_TRANSFORM_AUGMIX_MAX_PARAMS = {
+ "AutoContrast": None,
+ "Equalize": None,
+ "Rotate": (0, 30),
+ "Posterize": (4, 4),
+ "Solarize": (1, 1),
+ "ShearX": (0, 0.3),
+ "ShearY": (0, 0.3),
+ "TranslateX": (0, 1.0 / 3.0),
+ "TranslateY": (0, 1.0 / 3.0),
+ "AdjustSaturation": (0.1, 1.8),
+ "AdjustContrast": (0.1, 1.8),
+ "AdjustBrightness": (0.1, 1.8),
+ "AdjustSharpness": (0.1, 1.8),
+}
+
+# Hyperparameters for sampling magnitude.
+# sampling_data_type determines whether uniform sampling samples among ints or floats.
+# sampling_min determines the minimum possible value obtained from uniform
+# sampling among floats.
+SAMPLING_AUGMIX_DEFAULT_HPARAS = {"sampling_data_type": "float", "sampling_min": 0.1}
+
+
+class AugMix:
+ """
+ This implements AugMix for video. AugMix generates several chains of augmentations
+ on the original video, which are then mixed together with each other and with the
+ original video to create an augmented video. The input video tensor should have
+ shape (T, C, H, W).
+
+ AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty
+ (https://arxiv.org/pdf/1912.02781.pdf)
+ """
+
+ def __init__(
+ self,
+ magnitude: int = 3,
+ alpha: float = 1.0,
+ width: int = 3,
+ depth: int = -1,
+ transform_hparas: Optional[Dict[str, Any]] = None,
+ sampling_hparas: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """
+ Args:
+ magnitude (int): Magnitude used for transform function. Default is 3.
+ alpha (float): Parameter for choosing mixing weights from the beta
+ and Dirichlet distributions. Default is 1.0.
+ width (int): The number of transformation chains. Default is 3.
+ depth (int): The number of transformations in each chain. If depth is -1,
+ each chain will have a random length between 1 and 3 inclusive.
+ Default is -1.
+ transform_hparas (Optional[Dict[Any]]): Transform hyper parameters.
+ Needs to have key fill. By default, the fill value is (0.5, 0.5, 0.5).
+ sampling_hparas (Optional[Dict[Any]]): Hyper parameters for sampling. If
+ gaussian sampling is used, it needs to have key sampling_std. By
+ default, it uses SAMPLING_AUGMIX_DEFAULT_HPARAS.
+ """
+ assert isinstance(magnitude, int), "magnitude must be an int"
+ assert (
+ magnitude >= 1 and magnitude <= _AUGMENTATION_MAX_LEVEL
+ ), f"magnitude must be between 1 and {_AUGMENTATION_MAX_LEVEL} inclusive"
+ assert alpha > 0.0, "alpha must be greater than 0"
+ assert width > 0, "width must be greater than 0"
+
+ self._magnitude = magnitude
+
+ self.dirichlet = torch.distributions.dirichlet.Dirichlet(
+ torch.tensor([alpha] * width)
+ )
+ self.beta = torch.distributions.beta.Beta(alpha, alpha)
+
+ transforms_list = [
+ AugmentTransform(
+ transform_name=transform_name,
+ magnitude=self._magnitude,
+ prob=1.0,
+ level_to_arg=_AUGMIX_LEVEL_TO_ARG,
+ transform_max_paras=_TRANSFORM_AUGMIX_MAX_PARAMS,
+ transform_hparas=transform_hparas,
+ sampling_type="uniform",
+ sampling_hparas=sampling_hparas or SAMPLING_AUGMIX_DEFAULT_HPARAS,
+ )
+ for transform_name in list(_TRANSFORM_AUGMIX_MAX_PARAMS.keys())
+ ]
+ if depth > 0:
+ self.augmix_fn = OpSampler(
+ transforms_list,
+ num_sample_op=depth,
+ replacement=True,
+ )
+ else:
+ self.augmix_fn = OpSampler(
+ transforms_list,
+ num_sample_op=3,
+ randomly_sample_depth=True,
+ replacement=True,
+ )
+
+ def __call__(self, video: torch.Tensor) -> torch.Tensor:
+ """
+ Perform AugMix to the input video tensor.
+
+ Args:
+ video (torch.Tensor): Input video tensor with shape (T, C, H, W).
+ """
+ mixing_weights = self.dirichlet.sample()
+ m = self.beta.sample().item()
+ mixed = torch.zeros(video.shape, dtype=torch.float32)
+ for mw in mixing_weights:
+ mixed += mw * self.augmix_fn(video)
+ if video.dtype == torch.uint8:
+ return (m * video + (1 - m) * mixed).type(torch.uint8)
+ else:
+ return m * video + (1 - m) * mixed
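+
+
+# Editor's sketch (not part of the upstream patch): minimal AugMix usage on a
+# random uint8 clip of shape (T, C, H, W); the clip size and hyperparameters
+# below are arbitrary assumptions.
+def _example_augmix() -> None:
+ video = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)
+ augmented = AugMix(magnitude=3, alpha=1.0, width=3, depth=-1)(video)
+ # The mixed clip keeps the input shape and, for uint8 inputs, the dtype.
+ assert augmented.shape == video.shape and augmented.dtype == torch.uint8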
diff --git a/code/pytorchvideo/pytorchvideo/transforms/functional.py b/code/pytorchvideo/pytorchvideo/transforms/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..59a58d9da23e34db3229aba20f84c72a4b2af201
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/functional.py
@@ -0,0 +1,615 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import copy
+import math
+from typing import Tuple
+
+import numpy as np
+import torch
+
+
+try:
+ import cv2
+except ImportError:
+ _HAS_CV2 = False
+else:
+ _HAS_CV2 = True
+
+
+def uniform_temporal_subsample(
+ x: torch.Tensor, num_samples: int, temporal_dim: int = -3
+) -> torch.Tensor:
+ """
+ Uniformly subsamples num_samples indices from the temporal dimension of the video.
+ When num_samples is larger than the size of the temporal dimension of the video,
+ frames are repeated via nearest-neighbor interpolation.
+
+ Args:
+ x (torch.Tensor): A video tensor with more than one dimension and any torch
+ numeric dtype (int, long, float, complex, etc.).
+ num_samples (int): The number of equispaced samples to be selected
+ temporal_dim (int): dimension of temporal to perform temporal subsample.
+
+ Returns:
+ An x-like Tensor with subsampled temporal dimension.
+ """
+ t = x.shape[temporal_dim]
+ assert num_samples > 0 and t > 0
+ # Sample by nearest neighbor interpolation if num_samples > t.
+ indices = torch.linspace(0, t - 1, num_samples)
+ indices = torch.clamp(indices, 0, t - 1).long()
+ return torch.index_select(x, temporal_dim, indices)
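+
+
+# Editor's sketch (not part of the upstream patch): a minimal check of
+# uniform_temporal_subsample on a (C, T, H, W) clip; with T=10 and
+# num_samples=4, frames 0, 3, 6 and 9 are picked along the default dim (-3).
+def _example_uniform_temporal_subsample() -> None:
+ clip = torch.rand(3, 10, 8, 8)
+ subsampled = uniform_temporal_subsample(clip, num_samples=4)
+ assert subsampled.shape == (3, 4, 8, 8)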
+
+
+@torch.jit.ignore
+def _interpolate_opencv(
+ x: torch.Tensor, size: Tuple[int, int], interpolation: str
+) -> torch.Tensor:
+ """
+ Down/up samples the input torch tensor x to the given size with given interpolation
+ mode.
+ Args:
+ x (Tensor): the input tensor to be down/up sampled.
+ size (Tuple[int, int]): expected output spatial size.
+ interpolation: mode to perform interpolation, options include `nearest`,
+ `linear`, `bilinear`, `bicubic`.
+ """
+ if not _HAS_CV2:
+ raise ImportError(
+ "opencv is required to use opencv transforms. Please "
+ "install with 'pip install opencv-python'."
+ )
+
+ _opencv_pytorch_interpolation_map = {
+ "nearest": cv2.INTER_NEAREST,
+ "linear": cv2.INTER_LINEAR,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ }
+ assert interpolation in _opencv_pytorch_interpolation_map
+ new_h, new_w = size
+ img_array_list = [
+ img_tensor.squeeze(0).numpy()
+ for img_tensor in x.permute(1, 2, 3, 0).split(1, dim=0)
+ ]
+ resized_img_array_list = [
+ cv2.resize(
+ img_array,
+ (new_w, new_h), # The input order for OpenCV is w, h.
+ interpolation=_opencv_pytorch_interpolation_map[interpolation],
+ )
+ for img_array in img_array_list
+ ]
+ img_array = np.concatenate(
+ [np.expand_dims(img_array, axis=0) for img_array in resized_img_array_list],
+ axis=0,
+ )
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_array))
+ img_tensor = img_tensor.permute(3, 0, 1, 2)
+ return img_tensor
+
+
+def short_side_scale(
+ x: torch.Tensor,
+ size: int,
+ interpolation: str = "bilinear",
+ backend: str = "pytorch",
+) -> torch.Tensor:
+ """
+ Determines the shorter spatial dim of the video (i.e. width or height) and scales
+ it to the given size. To maintain aspect ratio, the longer side is then scaled
+ accordingly.
+ Args:
+ x (torch.Tensor): A video tensor of shape (C, T, H, W) and type torch.float32.
+ size (int): The size the shorter side is scaled to.
+ interpolation (str): Algorithm used for upsampling,
+ options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
+ backend (str): backend used to perform interpolation. Options include
+ `pytorch` (default) and `opencv`. Note that opencv and pytorch behave
+ differently on linear interpolation in some versions.
+ https://discuss.pytorch.org/t/pytorch-linear-interpolation-is-different-from-pil-opencv/71181
+ Returns:
+ An x-like Tensor with scaled spatial dims.
+ """ # noqa
+ assert len(x.shape) == 4
+ assert x.dtype == torch.float32
+ assert backend in ("pytorch", "opencv")
+ c, t, h, w = x.shape
+ if w < h:
+ new_h = int(math.floor((float(h) / w) * size))
+ new_w = size
+ else:
+ new_h = size
+ new_w = int(math.floor((float(w) / h) * size))
+ if backend == "pytorch":
+ return torch.nn.functional.interpolate(
+ x, size=(new_h, new_w), mode=interpolation, align_corners=False
+ )
+ elif backend == "opencv":
+ return _interpolate_opencv(x, size=(new_h, new_w), interpolation=interpolation)
+ else:
+ raise NotImplementedError(f"{backend} backend not supported.")
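+
+
+# Editor's sketch (not part of the upstream patch): short_side_scale resizes the
+# shorter spatial side to `size` and keeps the aspect ratio, so a float32 clip
+# of shape (3, 4, 128, 256) scaled with size=64 becomes (3, 4, 64, 128).
+def _example_short_side_scale() -> None:
+ video = torch.rand(3, 4, 128, 256)
+ scaled = short_side_scale(video, size=64)
+ assert scaled.shape == (3, 4, 64, 128)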
+
+
+def uniform_temporal_subsample_repeated(
+ frames: torch.Tensor, frame_ratios: Tuple[int], temporal_dim: int = -3
+) -> Tuple[torch.Tensor]:
+ """
+ Prepare output as a list of tensors subsampled from the input frames. Each tensor
+ maintains a unique copy of subsampled frames, which corresponds to a unique
+ pathway.
+
+ Args:
+ frames (tensor): frames of images sampled from the video. Expected to have
+ torch tensor (including int, long, float, complex, etc) with dimension
+ larger than one.
+ frame_ratios (tuple): ratio to perform temporal down-sampling for each pathway.
+ temporal_dim (int): dimension of temporal.
+
+ Returns:
+ frame_list (tuple): list of tensors as output.
+ """
+ temporal_length = frames.shape[temporal_dim]
+ frame_list = []
+ for ratio in frame_ratios:
+ pathway = uniform_temporal_subsample(
+ frames, temporal_length // ratio, temporal_dim
+ )
+ frame_list.append(pathway)
+
+ return frame_list
+
+
+def convert_to_one_hot(
+ targets: torch.Tensor,
+ num_class: int,
+ label_smooth: float = 0.0,
+) -> torch.Tensor:
+ """
+ This function converts target class indices to one-hot vectors,
+ given the number of classes.
+
+ Args:
+ targets (torch.Tensor): Index labels to be converted.
+ num_class (int): Total number of classes.
+ label_smooth (float): Label smooth value for non-target classes. Label smooth
+ is disabled by default (0).
+ """
+ assert (
+ torch.max(targets).item() < num_class
+ ), "Class Index must be less than number of classes"
+ assert 0 <= label_smooth < 1.0, "Label smooth value needs to be between 0 and 1."
+
+ non_target_value = label_smooth / num_class
+ target_value = 1.0 - label_smooth + non_target_value
+ one_hot_targets = torch.full(
+ (targets.shape[0], num_class),
+ non_target_value,
+ dtype=torch.long if label_smooth == 0.0 else None,
+ device=targets.device,
+ )
+ one_hot_targets.scatter_(1, targets.long().view(-1, 1), target_value)
+ return one_hot_targets
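+
+
+# Editor's sketch (not part of the upstream patch): convert_to_one_hot with and
+# without label smoothing for a 4-class problem.
+def _example_convert_to_one_hot() -> None:
+ targets = torch.tensor([0, 2])
+ hard = convert_to_one_hot(targets, num_class=4)
+ assert hard.tolist() == [[1, 0, 0, 0], [0, 0, 1, 0]]
+ # With label_smooth=0.2, non-targets get 0.2 / 4 = 0.05 and targets 0.85,
+ # so every row still sums to one.
+ smooth = convert_to_one_hot(targets, num_class=4, label_smooth=0.2)
+ assert torch.allclose(smooth.sum(dim=1), torch.ones(2))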
+
+
+def short_side_scale_with_boxes(
+ images: torch.Tensor,
+ boxes: torch.Tensor,
+ size: int,
+ interpolation: str = "bilinear",
+ backend: str = "pytorch",
+) -> Tuple[torch.Tensor, np.ndarray]:
+ """
+ Perform a spatial short scale jittering on the given images and
+ corresponding boxes.
+ Args:
+ images (tensor): images to perform scale jitter. Dimension is
+ `channel` x `num frames` x `height` x `width`.
+ boxes (tensor): Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ size (int): The size the shorter side is scaled to.
+ interpolation (str): Algorithm used for upsampling,
+ options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
+ backend (str): backend used to perform interpolation. Options include
+ `pytorch` (default) and `opencv`. Note that opencv and pytorch behave
+ differently on linear interpolation in some versions.
+ https://discuss.pytorch.org/t/pytorch-linear-interpolation-is-different-from-pil-opencv/71181
+ Returns:
+ (tensor): the scaled images with dimension of
+ `channel` x `num frames` x `height` x `width`.
+ (tensor): the scaled boxes with dimension of
+ `num boxes` x 4.
+ """
+ c, t, h, w = images.shape
+ images = short_side_scale(images, size, interpolation, backend)
+ _, _, new_h, new_w = images.shape
+ if w < h:
+ boxes *= float(new_h) / h
+ else:
+ boxes *= float(new_w) / w
+ return images, boxes
+
+
+def random_short_side_scale_with_boxes(
+ images: torch.Tensor,
+ boxes: torch.Tensor,
+ min_size: int,
+ max_size: int,
+ interpolation: str = "bilinear",
+ backend: str = "pytorch",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Perform a spatial short scale jittering on the given images and
+ corresponding boxes.
+ Args:
+ images (tensor): images to perform scale jitter. Dimension is
+ `channel` x `num frames` x `height` x `width`.
+ boxes (tensor): Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ min_size (int): the minimal size to scale the frames.
+ max_size (int): the maximal size to scale the frames.
+ interpolation (str): Algorithm used for upsampling,
+ options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
+ backend (str): backend used to perform interpolation. Options include
+ `pytorch` (default) and `opencv`. Note that opencv and pytorch behave
+ differently on linear interpolation in some versions.
+ https://discuss.pytorch.org/t/pytorch-linear-interpolation-is-different-from-pil-opencv/71181
+ Returns:
+ (tensor): the scaled images with dimension of
+ `channel` x `num frames` x `height` x `width`.
+ (tensor): the scaled boxes with dimension of
+ `num boxes` x 4.
+ """
+ size = torch.randint(min_size, max_size + 1, (1,)).item()
+ return short_side_scale_with_boxes(images, boxes, size, interpolation, backend)
+
+
+def random_crop_with_boxes(
+ images: torch.Tensor, size: int, boxes: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Perform random spatial crop on the given images and corresponding boxes.
+ Args:
+ images (tensor): images to perform random crop. The dimension is
+ `channel` x `num frames` x `height` x `width`.
+ size (int): the size of height and width to crop on the image.
+ boxes (tensor): Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ cropped (tensor): cropped images with dimension of
+ `channel` x `num frames` x `height` x `width`.
+ cropped_boxes (tensor): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ if images.shape[2] == size and images.shape[3] == size:
+ return images, boxes
+ height = images.shape[2]
+ width = images.shape[3]
+ y_offset = 0
+ if height > size:
+ y_offset = int(np.random.randint(0, height - size))
+ x_offset = 0
+ if width > size:
+ x_offset = int(np.random.randint(0, width - size))
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset)
+ return cropped, clip_boxes_to_image(
+ cropped_boxes, cropped.shape[-2], cropped.shape[-1]
+ )
+
+
+def _uniform_crop_helper(images: torch.Tensor, size: int, spatial_idx: int):
+ """
+ A helper function grouping the common components in uniform crop
+ """
+ assert spatial_idx in [0, 1, 2]
+ height = images.shape[2]
+ width = images.shape[3]
+
+ y_offset = int(math.ceil((height - size) / 2))
+ x_offset = int(math.ceil((width - size) / 2))
+
+ if height > width:
+ if spatial_idx == 0:
+ y_offset = 0
+ elif spatial_idx == 2:
+ y_offset = height - size
+ else:
+ if spatial_idx == 0:
+ x_offset = 0
+ elif spatial_idx == 2:
+ x_offset = width - size
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+
+ return cropped, x_offset, y_offset
+
+
+def uniform_crop(
+ images: torch.Tensor,
+ size: int,
+ spatial_idx: int,
+) -> torch.Tensor:
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `channel` x `num frames` x `height` x `width`.
+ size (int): size of height and width to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ Returns:
+ cropped (tensor): images with dimension of
+ `channel` x `num frames` x `height` x `width`.
+ """
+ cropped, _, _ = _uniform_crop_helper(images, size, spatial_idx)
+ return cropped
+
+
+def uniform_crop_with_boxes(
+ images: torch.Tensor,
+ size: int,
+ spatial_idx: int,
+ boxes: torch.Tensor,
+) -> Tuple[torch.Tensor, np.ndarray]:
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `channel` x `num frames` x `height` x `width`.
+ size (int): size of height and width to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ boxes (tensor): Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ cropped (tensor): images with dimension of
+ `channel` x `num frames` x `height` x `width`.
+ cropped_boxes (tensor): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ cropped, x_offset, y_offset = _uniform_crop_helper(images, size, spatial_idx)
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset)
+ return cropped, clip_boxes_to_image(
+ cropped_boxes, cropped.shape[-2], cropped.shape[-1]
+ )
+
+
+def horizontal_flip_with_boxes(
+ prob: float, images: torch.Tensor, boxes: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Perform horizontal flip on the given images and corresponding boxes.
+ Args:
+ prob (float): probability to flip the images.
+ images (tensor): images to perform horizontal flip, the dimension is
+ `channel` x `num frames` x `height` x `width`.
+ boxes (tensor): Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ images (tensor): images with dimension of
+ `channel` x `num frames` x `height` x `width`.
+ flipped_boxes (tensor): the flipped boxes with dimension of
+ `num boxes` x 4.
+ """
+ flipped_boxes = copy.deepcopy(boxes)
+
+ if np.random.uniform() < prob:
+ images = images.flip((-1))
+ width = images.shape[3]
+ flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
+
+ return images, flipped_boxes
+
+
+def clip_boxes_to_image(boxes: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ Clip an array of boxes to an image with the given height and width.
+ Args:
+ boxes (tensor): bounding boxes to perform clipping.
+ Dimension is `num boxes` x 4.
+ height (int): given image height.
+ width (int): given image width.
+ Returns:
+ clipped_boxes (tensor): the clipped boxes with dimension of
+ `num boxes` x 4.
+ """
+ clipped_boxes = copy.deepcopy(boxes)
+ clipped_boxes[:, [0, 2]] = np.minimum(
+ width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
+ )
+ clipped_boxes[:, [1, 3]] = np.minimum(
+ height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
+ )
+ return clipped_boxes
+
+
+def crop_boxes(boxes: torch.Tensor, x_offset: int, y_offset: int) -> torch.Tensor:
+ """
+ Perform crop on the bounding boxes given the offsets.
+ Args:
+ boxes (torch.Tensor): bounding boxes to perform crop. The dimension
+ is `num boxes` x 4.
+ x_offset (int): cropping offset in the x axis.
+ y_offset (int): cropping offset in the y axis.
+ Returns:
+ cropped_boxes (torch.Tensor): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ cropped_boxes = copy.deepcopy(boxes)
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+ return cropped_boxes
+
+
+def _get_param_spatial_crop(
+ scale: Tuple[float, float],
+ ratio: Tuple[float, float],
+ height: int,
+ width: int,
+ log_uniform_ratio: bool = True,
+ num_tries: int = 10,
+) -> Tuple[int, int, int, int]:
+ """
+ Given scale, ratio, height and width, return sampled coordinates of the videos.
+
+ Args:
+ scale (Tuple[float, float]): Scale range of Inception-style area based
+ random resizing.
+ ratio (Tuple[float, float]): Aspect ratio range of Inception-style
+ area based random resizing.
+ height (int): Height of the original image.
+ width (int): Width of the original image.
+ log_uniform_ratio (bool): Whether to use a log-uniform distribution to
+ sample the aspect ratio. Default is True.
+ num_tries (int): The number of times to attempt a randomly resized crop.
+ Falls back to a central crop after all attempts are exhausted.
+ Default is 10.
+
+ Returns:
+ Tuple containing i, j, h, w. (i, j) are the coordinates of the top left
+ corner of the crop. (h, w) are the height and width of the crop.
+ """
+ assert num_tries >= 1, "num_tries must be at least 1"
+
+ if scale[0] > scale[1]:
+ scale = (scale[1], scale[0])
+ if ratio[0] > ratio[1]:
+ ratio = (ratio[1], ratio[0])
+
+ for _ in range(num_tries):
+ area = height * width
+ target_area = area * (scale[0] + torch.rand(1).item() * (scale[1] - scale[0]))
+ if log_uniform_ratio:
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(
+ log_ratio[0] + torch.rand(1).item() * (log_ratio[1] - log_ratio[0])
+ )
+ else:
+ aspect_ratio = ratio[0] + torch.rand(1).item() * (ratio[1] - ratio[0])
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if 0 < w <= width and 0 < h <= height:
+ i = torch.randint(0, height - h + 1, (1,)).item()
+ j = torch.randint(0, width - w + 1, (1,)).item()
+ return i, j, h, w
+
+ # Fallback to central crop.
+ in_ratio = float(width) / float(height)
+ if in_ratio < min(ratio):
+ w = width
+ h = int(round(w / min(ratio)))
+ elif in_ratio > max(ratio):
+ h = height
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = width
+ h = height
+ i = (height - h) // 2
+ j = (width - w) // 2
+ return i, j, h, w
+
+
+def random_resized_crop(
+ frames: torch.Tensor,
+ target_height: int,
+ target_width: int,
+ scale: Tuple[float, float],
+ aspect_ratio: Tuple[float, float],
+ shift: bool = False,
+ log_uniform_ratio: bool = True,
+ interpolation: str = "bilinear",
+ num_tries: int = 10,
+) -> torch.Tensor:
+ """
+ Crop the given images to random size and aspect ratio. A crop of random
+ size relative to the original size and a random aspect ratio is made. This
+ crop is finally resized to given size. This is popularly used to train the
+ Inception networks.
+
+ Args:
+ frames (torch.Tensor): Video tensor to be resized with shape (C, T, H, W).
+ target_height (int): Desired height after cropping.
+ target_width (int): Desired width after cropping.
+ scale (Tuple[float, float]): Scale range of Inception-style area based
+ random resizing. Should be between 0.0 and 1.0.
+ aspect_ratio (Tuple[float, float]): Aspect ratio range of Inception-style
+ area based random resizing. Should be between 0.0 and +infinity.
+ shift (bool): Bool that determines whether or not to sample two different
+ boxes (for cropping) for the first and last frame. If True, it then
+ linearly interpolates the two boxes for other frames. If False, the
+ same box is cropped for every frame. Default is False.
+ log_uniform_ratio (bool): Whether to use a log-uniform distribution to
+ sample the aspect ratio. Default is True.
+ interpolation (str): Algorithm used for upsampling. Currently supports
+ 'nearest', 'bilinear', 'bicubic', 'area'. Default is 'bilinear'.
+ num_tries (int): The number of times to attempt a randomly resized crop.
+ Falls back to a central crop after all attempts are exhausted.
+ Default is 10.
+
+ Returns:
+ cropped (tensor): A cropped video tensor of shape (C, T, target_height, target_width).
+ """
+ assert (
+ scale[0] > 0 and scale[1] > 0
+ ), "min and max of scale range must be greater than 0"
+ assert (
+ aspect_ratio[0] > 0 and aspect_ratio[1] > 0
+ ), "min and max of aspect_ratio range must be greater than 0"
+
+ channels = frames.shape[0]
+ t = frames.shape[1]
+ height = frames.shape[2]
+ width = frames.shape[3]
+
+ i, j, h, w = _get_param_spatial_crop(
+ scale, aspect_ratio, height, width, log_uniform_ratio, num_tries
+ )
+
+ if not shift:
+ cropped = frames[:, :, i : i + h, j : j + w]
+ return torch.nn.functional.interpolate(
+ cropped,
+ size=(target_height, target_width),
+ mode=interpolation,
+ )
+
+ i_, j_, h_, w_ = _get_param_spatial_crop(
+ scale, aspect_ratio, height, width, log_uniform_ratio, num_tries
+ )
+ i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
+ j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
+ h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
+ w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
+ cropped = torch.zeros((channels, t, target_height, target_width))
+ for ind in range(t):
+ cropped[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
+ frames[
+ :,
+ ind : ind + 1,
+ i_s[ind] : i_s[ind] + h_s[ind],
+ j_s[ind] : j_s[ind] + w_s[ind],
+ ],
+ size=(target_height, target_width),
+ mode=interpolation,
+ )
+ return cropped
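+
+
+# Editor's sketch (not part of the upstream patch): Inception-style random
+# resized cropping of a (C, T, H, W) clip to a fixed 64x64 spatial size; the
+# scale and aspect-ratio ranges below are arbitrary assumptions.
+def _example_random_resized_crop() -> None:
+ video = torch.rand(3, 4, 128, 128)
+ cropped = random_resized_crop(
+ video,
+ target_height=64,
+ target_width=64,
+ scale=(0.5, 1.0),
+ aspect_ratio=(3.0 / 4.0, 4.0 / 3.0),
+ )
+ assert cropped.shape == (3, 4, 64, 64)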
+
+
+def div_255(x: torch.Tensor) -> torch.Tensor:
+ """
+ Divide the given tensor x by 255.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ y (torch.Tensor): The input tensor scaled by dividing it by 255.
+ """
+ y = x / 255.0
+ return y
diff --git a/code/pytorchvideo/pytorchvideo/transforms/mix.py b/code/pytorchvideo/pytorchvideo/transforms/mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..64293fe896098093a5b6a48bce068b5a96356159
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/mix.py
@@ -0,0 +1,265 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Dict, Tuple
+
+import torch
+from pytorchvideo.transforms.functional import convert_to_one_hot
+
+
+def _mix_labels(
+ labels: torch.Tensor,
+ num_classes: int,
+ lam: float = 1.0,
+ label_smoothing: float = 0.0,
+ one_hot: bool = False,
+):
+ """
+ This function converts class indices to one-hot vectors and mix labels, given the
+ number of classes.
+
+ Args:
+ labels (torch.Tensor): Class labels.
+ num_classes (int): Total number of classes.
+ lam (float): lambda value for mixing labels.
+ label_smoothing (float): Label smoothing value.
+ one_hot (bool): If True, the labels are assumed to be one-hot encoded already.
+ """
+ if one_hot:
+ labels1 = labels
+ labels2 = labels.flip(0)
+ else:
+ labels1 = convert_to_one_hot(labels, num_classes, label_smoothing)
+ labels2 = convert_to_one_hot(labels.flip(0), num_classes, label_smoothing)
+ return labels1 * lam + labels2 * (1.0 - lam)
+
+
+class MixUp(torch.nn.Module):
+ """
+ Mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+ """
+
+ def __init__(
+ self,
+ alpha: float = 1.0,
+ label_smoothing: float = 0.0,
+ num_classes: int = 400,
+ one_hot: bool = False,
+ ) -> None:
+ """
+ This implements MixUp for videos.
+
+ Args:
+ alpha (float): Mixup alpha value.
+ label_smoothing (float): Label smoothing value.
+ num_classes (int): Number of total classes.
+ """
+ super().__init__()
+ self.mixup_beta_sampler = torch.distributions.beta.Beta(alpha, alpha)
+ self.label_smoothing = label_smoothing
+ self.num_classes = num_classes
+ self.one_hot = one_hot
+
+ def forward(
+ self,
+ x_video: torch.Tensor,
+ labels: torch.Tensor,
+ **args: Any,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ The input is a batch of samples and their corresponding labels.
+
+ Args:
+ x_video (torch.Tensor): Input tensor. The input should be a batch of videos with
+ shape (B, C, T, H, W).
+ labels (torch.Tensor): Labels for input with shape (B).
+ x_audio (torch.Tensor, optional): Audio input tensor, passed as a keyword argument.
+ """
+ assert x_video.size(0) > 1, "MixUp cannot be applied to a single instance."
+ mixup_lambda = self.mixup_beta_sampler.sample()
+ x_video_flipped = x_video.flip(0).mul_(1.0 - mixup_lambda)
+ x_video.mul_(mixup_lambda).add_(x_video_flipped)
+
+ new_labels = _mix_labels(
+ labels,
+ self.num_classes,
+ mixup_lambda,
+ self.label_smoothing,
+ one_hot=self.one_hot,
+ )
+
+ if args.get("x_audio", None) is not None:
+ x_audio = args["x_audio"]
+ assert x_audio.size(0) > 1, "MixUp cannot be applied to a single instance."
+ x_audio_flipped = x_audio.flip(0).mul_(1.0 - mixup_lambda)
+ x_audio.mul_(mixup_lambda).add_(x_audio_flipped)
+ return x_video, x_audio, new_labels
+ else:
+ return x_video, new_labels
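+
+
+# Editor's sketch (not part of the upstream patch): MixUp on a batch of two
+# clips with shape (B, C, T, H, W); the number of classes is an arbitrary
+# assumption.
+def _example_mixup() -> None:
+ videos = torch.rand(2, 3, 4, 8, 8)
+ labels = torch.tensor([0, 3])
+ mixed_videos, mixed_labels = MixUp(alpha=1.0, num_classes=5)(videos, labels)
+ assert mixed_videos.shape == videos.shape
+ # Each soft label row still sums to one after mixing.
+ assert torch.allclose(mixed_labels.sum(dim=1), torch.ones(2))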
+
+
+class CutMix(torch.nn.Module):
+ """
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
+ (https://arxiv.org/abs/1905.04899)
+ """
+
+ def __init__(
+ self,
+ alpha: float = 1.0,
+ label_smoothing: float = 0.0,
+ num_classes: int = 400,
+ one_hot: bool = False,
+ ) -> None:
+ """
+ This implements CutMix for videos.
+
+ Args:
+ alpha (float): CutMix alpha value.
+ label_smoothing (float): Label smoothing value.
+ num_classes (int): Number of total classes.
+ """
+ super().__init__()
+ self.one_hot = one_hot
+ self.cutmix_beta_sampler = torch.distributions.beta.Beta(alpha, alpha)
+ self.label_smoothing = label_smoothing
+ self.num_classes = num_classes
+
+ def _clip(self, value: int, min_value: int, max_value: int) -> int:
+ """
+ Clip value based on minimum value and maximum value.
+ """
+
+ return min(max(value, min_value), max_value)
+
+ def _get_rand_box(self, input_shape: Tuple[int], cutmix_lamda: float) -> Tuple[int]:
+ """
+ Get a random square box given a lambda value.
+ """
+
+ ratio = (1 - cutmix_lamda) ** 0.5
+ input_h, input_w = input_shape[-2:]
+ cut_h, cut_w = int(input_h * ratio), int(input_w * ratio)
+ cy = torch.randint(input_h, (1,)).item()
+ cx = torch.randint(input_w, (1,)).item()
+ yl = self._clip(cy - cut_h // 2, 0, input_h)
+ yh = self._clip(cy + cut_h // 2, 0, input_h)
+ xl = self._clip(cx - cut_w // 2, 0, input_w)
+ xh = self._clip(cx + cut_w // 2, 0, input_w)
+ return yl, yh, xl, xh
+
+ def _cutmix(
+ self, x: torch.Tensor, cutmix_lamda: float
+ ) -> Tuple[torch.Tensor, float]:
+ """
+ Perform CutMix and return corrected lambda value.
+ """
+
+ yl, yh, xl, xh = self._get_rand_box(x.size(), cutmix_lamda)
+ box_area = float((yh - yl) * (xh - xl))
+ cutmix_lamda_corrected = 1.0 - box_area / (x.size(-2) * x.size(-1))
+ x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
+ return x, cutmix_lamda_corrected
+
+ def forward(
+ self,
+ x_video: torch.Tensor,
+ labels: torch.Tensor,
+ **args: Any,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ The input is a batch of samples and their corresponding labels.
+
+ Args:
+ x_video (torch.Tensor): Input tensor. The input should be a batch of videos with
+ shape (B, C, T, H, W).
+ labels (torch.Tensor): Labels for input with shape (B).
+ """
+ assert x_video.size(0) > 1, "Cutmix cannot be applied to a single instance."
+ assert x_video.dim() == 4 or x_video.dim() == 5, "Please correct input shape."
+ cutmix_lamda = self.cutmix_beta_sampler.sample()
+ x_video, cutmix_lamda_corrected = self._cutmix(x_video, cutmix_lamda)
+ new_labels = _mix_labels(
+ labels,
+ self.num_classes,
+ cutmix_lamda_corrected,
+ self.label_smoothing,
+ one_hot=self.one_hot,
+ )
+ if args.get("x_audio", None) is not None:
+ x_audio = args["x_audio"]
+ assert x_audio.size(0) > 1, "Cutmix cannot be applied to a single instance."
+ assert (
+ x_audio.dim() == 4 or x_audio.dim() == 5
+ ), "Please correct input shape."
+ x_audio, _ = self._cutmix(x_audio, cutmix_lamda)
+ return x_video, x_audio, new_labels
+ else:
+ return x_video, new_labels
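+
+
+# Editor's sketch (not part of the upstream patch): CutMix pastes a random
+# spatial box from the flipped batch into each clip and corrects the mixing
+# lambda by the actual box area; shapes below are arbitrary assumptions.
+def _example_cutmix() -> None:
+ videos = torch.rand(2, 3, 4, 16, 16)
+ labels = torch.tensor([1, 2])
+ mixed_videos, mixed_labels = CutMix(alpha=1.0, num_classes=4)(videos, labels)
+ assert mixed_videos.shape == (2, 3, 4, 16, 16)
+ assert mixed_labels.shape == (2, 4)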
+
+
+class MixVideo(torch.nn.Module):
+ """
+ Stochastically applies either MixUp or CutMix to the input video.
+ """
+
+ def __init__(
+ self,
+ cutmix_prob: float = 0.5,
+ mixup_alpha: float = 1.0,
+ cutmix_alpha: float = 1.0,
+ label_smoothing: float = 0.0,
+ num_classes: int = 400,
+ one_hot: bool = False,
+ ):
+ """
+ Args:
+ cutmix_prob (float): Probability of using CutMix. MixUp will be used with
+ probability 1 - cutmix_prob. If cutmix_prob is 0, then MixUp is always
+ used. If cutmix_prob is 1, then CutMix is always used.
+ mixup_alpha (float): MixUp alpha value.
+ cutmix_alpha (float): CutMix alpha value.
+ label_smoothing (float): Label smoothing value.
+ num_classes (int): Number of total classes.
+ """
+
+ assert 0.0 <= cutmix_prob <= 1.0, "cutmix_prob should be between 0.0 and 1.0"
+
+ super().__init__()
+ self.cutmix_prob = cutmix_prob
+ self.mixup = MixUp(
+ alpha=mixup_alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ one_hot=one_hot,
+ )
+ self.cutmix = CutMix(
+ alpha=cutmix_alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ one_hot=one_hot,
+ )
+
+ def forward(
+ self,
+ x_video: torch.Tensor,
+ labels: torch.Tensor,
+ **args: Any,
+ ) -> Tuple[torch.Tensor, ...]:
+ """
+ The input is a batch of samples and their corresponding labels.
+
+ Args:
+ x_video (torch.Tensor): Input tensor. The input should be a batch of videos with
+ shape (B, C, T, H, W).
+ labels (torch.Tensor): Labels for input with shape (B).
+ """
+ if args.get("x_audio", None) is None:
+ if torch.rand(1).item() < self.cutmix_prob:
+ x_video, new_labels = self.cutmix(x_video, labels)
+ else:
+ x_video, new_labels = self.mixup(x_video, labels)
+ return x_video, new_labels
+ else:
+ x_audio = args["x_audio"]
+ if torch.rand(1).item() < self.cutmix_prob:
+ x_video, x_audio, new_labels = self.cutmix(x_video, labels, x_audio=x_audio)
+ else:
+ x_video, x_audio, new_labels = self.mixup(x_video, labels, x_audio=x_audio)
+ return x_video, x_audio, new_labels
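+
+
+# Editor's sketch (not part of the upstream patch): MixVideo stochastically
+# picks MixUp or CutMix for each batch; batch size and class count are
+# arbitrary assumptions.
+def _example_mix_video() -> None:
+ videos = torch.rand(4, 3, 4, 16, 16)
+ labels = torch.tensor([0, 1, 2, 3])
+ mix = MixVideo(cutmix_prob=0.5, num_classes=10)
+ mixed_videos, mixed_labels = mix(videos, labels)
+ assert mixed_videos.shape == videos.shape and mixed_labels.shape == (4, 10)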
diff --git a/code/pytorchvideo/pytorchvideo/transforms/rand_augment.py b/code/pytorchvideo/pytorchvideo/transforms/rand_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebcdab2197908fea5f99faa768f033fa4cff4529
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/rand_augment.py
@@ -0,0 +1,101 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Dict, Optional
+
+import torch
+from pytorchvideo.transforms.augmentations import AugmentTransform
+from pytorchvideo.transforms.transforms import OpSampler
+
+
+# A dictionary that contains transform names (key) and their corresponding maximum
+# transform magnitude (value).
+_TRANSFORM_RANDAUG_MAX_PARAMS = {
+ "AdjustBrightness": (1, 0.9),
+ "AdjustContrast": (1, 0.9),
+ "AdjustSaturation": (1, 0.9),
+ "AdjustSharpness": (1, 0.9),
+ "AutoContrast": None,
+ "Equalize": None,
+ "Invert": None,
+ "Rotate": (0, 30),
+ "Posterize": (4, 4),
+ "Solarize": (1, 1),
+ "ShearX": (0, 0.3),
+ "ShearY": (0, 0.3),
+ "TranslateX": (0, 0.45),
+ "TranslateY": (0, 0.45),
+}
+
+# Hyperparameters for sampling magnitude.
+# sampling_data_type determines whether uniform sampling samples among ints or floats.
+# sampling_min determines the minimum possible value obtained from uniform
+# sampling among floats.
+# sampling_std determines the standard deviation for gaussian sampling.
+SAMPLING_RANDAUG_DEFAULT_HPARAS = {
+ "sampling_data_type": "int",
+ "sampling_min": 0,
+ "sampling_std": 0.5,
+}
+
+
+class RandAugment:
+ """
+ This implements RandAugment for video. The input video tensor should have shape
+ (T, C, H, W).
+
+ RandAugment: Practical automated data augmentation with a reduced search space
+ (https://arxiv.org/abs/1909.13719)
+ """
+
+ def __init__(
+ self,
+ magnitude: int = 9,
+ num_layers: int = 2,
+ prob: float = 0.5,
+ transform_hparas: Optional[Dict[str, Any]] = None,
+ sampling_type: str = "gaussian",
+ sampling_hparas: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """
+ This implements RandAugment for video.
+
+ Args:
+ magnitude (int): Magnitude used for transform function.
+ num_layers (int): How many transform functions to apply for each
+ augmentation.
+ prob (float): The probability of applying each transform function.
+ transform_hparas (Optional[Dict[str, Any]]): Transform hyperparameters.
+ Needs to have the key fill. By default, it uses transform_default_hparas.
+ sampling_type (str): Sampling method for magnitude of transform. It should
+ be either gaussian or uniform.
+ sampling_hparas (Optional[Dict[str, Any]]): Hyperparameters for sampling. If
+ gaussian sampling is used, it needs to have the key sampling_std. By
+ default, it uses SAMPLING_RANDAUG_DEFAULT_HPARAS.
+ """
+ assert sampling_type in ["gaussian", "uniform"]
+ sampling_hparas = sampling_hparas or SAMPLING_RANDAUG_DEFAULT_HPARAS
+ if sampling_type == "gaussian":
+ assert "sampling_std" in sampling_hparas
+
+ randaug_fn = [
+ AugmentTransform(
+ transform_name,
+ magnitude,
+ prob=prob,
+ transform_max_paras=_TRANSFORM_RANDAUG_MAX_PARAMS,
+ transform_hparas=transform_hparas,
+ sampling_type=sampling_type,
+ sampling_hparas=sampling_hparas,
+ )
+ for transform_name in list(_TRANSFORM_RANDAUG_MAX_PARAMS.keys())
+ ]
+ self.randaug_fn = OpSampler(randaug_fn, num_sample_op=num_layers)
+
+ def __call__(self, video: torch.Tensor) -> torch.Tensor:
+ """
+ Perform RandAugment on the input video tensor.
+
+ Args:
+ video (torch.Tensor): Input video tensor with shape (T, C, H, W).
+ """
+ return self.randaug_fn(video)
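+
+
+# Editor's sketch (not part of the upstream patch): RandAugment applied to a
+# random uint8 clip of shape (T, C, H, W); clip size and magnitude are
+# arbitrary assumptions.
+def _example_rand_augment() -> None:
+ video = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)
+ augmented = RandAugment(magnitude=9, num_layers=2, prob=0.5)(video)
+ assert augmented.shape == video.shape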
diff --git a/code/pytorchvideo/pytorchvideo/transforms/transforms.py b/code/pytorchvideo/pytorchvideo/transforms/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b981124a2cfb623a463705f1d7cff772fb1cba1
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/transforms.py
@@ -0,0 +1,431 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Callable, Dict, List, Optional, Tuple
+
+import pytorchvideo.transforms.functional
+import torch
+import torchvision.transforms
+
+
+class ApplyTransformToKey:
+ """
+ Applies transform to key of dictionary input.
+
+ Args:
+ key (str): the dictionary key the transform is applied to
+ transform (callable): the transform that is applied
+
+ Example:
+ >>> transforms.ApplyTransformToKey(
+ >>> key='video',
+ >>> transform=UniformTemporalSubsample(num_video_samples),
+ >>> )
+ """
+
+ def __init__(self, key: str, transform: Callable):
+ self._key = key
+ self._transform = transform
+
+ def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ x[self._key] = self._transform(x[self._key])
+ return x
+
+
+class RemoveKey(torch.nn.Module):
+ """
+ Removes the given key from the input dict. Useful for removing modalities from a
+ video clip that aren't needed.
+ """
+
+ def __init__(self, key: str):
+ super().__init__()
+ self._key = key
+
+ def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Args:
+ x (Dict[str, torch.Tensor]): video clip dict.
+ """
+ if self._key in x:
+ del x[self._key]
+ return x
+
+
+class UniformTemporalSubsample(torch.nn.Module):
+ """
+ ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.uniform_temporal_subsample``.
+ """
+
+ def __init__(self, num_samples: int, temporal_dim: int = -3):
+ """
+ Args:
+ num_samples (int): The number of equispaced samples to be selected
+ temporal_dim (int): dimension of temporal to perform temporal subsample.
+ """
+ super().__init__()
+ self._num_samples = num_samples
+ self._temporal_dim = temporal_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor with shape (C, T, H, W).
+ """
+ return pytorchvideo.transforms.functional.uniform_temporal_subsample(
+ x, self._num_samples, self._temporal_dim
+ )
+
+
+class UniformTemporalSubsampleRepeated(torch.nn.Module):
+ """
+ ``nn.Module`` wrapper for
+ ``pytorchvideo.transforms.functional.uniform_temporal_subsample_repeated``.
+ """
+
+ def __init__(self, frame_ratios: Tuple[int], temporal_dim: int = -3):
+ super().__init__()
+ self._frame_ratios = frame_ratios
+ self._temporal_dim = temporal_dim
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): video tensor with shape (C, T, H, W).
+ """
+ return pytorchvideo.transforms.functional.uniform_temporal_subsample_repeated(
+ x, self._frame_ratios, self._temporal_dim
+ )
+
+
+class ShortSideScale(torch.nn.Module):
+ """
+ ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.short_side_scale``.
+ """
+
+ def __init__(
+ self, size: int, interpolation: str = "bilinear", backend: str = "pytorch"
+ ):
+ super().__init__()
+ self._size = size
+ self._interpolation = interpolation
+ self._backend = backend
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor with shape (C, T, H, W).
+ """
+ return pytorchvideo.transforms.functional.short_side_scale(
+ x, self._size, self._interpolation, self._backend
+ )
+
+
+class RandomShortSideScale(torch.nn.Module):
+ """
+ ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.short_side_scale``. The size
+ parameter is chosen randomly in [min_size, max_size].
+ """
+
+ def __init__(
+ self,
+ min_size: int,
+ max_size: int,
+ interpolation: str = "bilinear",
+ backend: str = "pytorch",
+ ):
+ super().__init__()
+ self._min_size = min_size
+ self._max_size = max_size
+ self._interpolation = interpolation
+ self._backend = backend
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor with shape (C, T, H, W).
+ """
+ size = torch.randint(self._min_size, self._max_size + 1, (1,)).item()
+ return pytorchvideo.transforms.functional.short_side_scale(
+ x, size, self._interpolation, self._backend
+ )
+
+
+class UniformCropVideo(torch.nn.Module):
+ """
+ ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.uniform_crop``.
+ """
+
+ def __init__(
+ self, size: int, video_key: str = "video", aug_index_key: str = "aug_index"
+ ):
+ super().__init__()
+ self._size = size
+ self._video_key = video_key
+ self._aug_index_key = aug_index_key
+
+ def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Args:
+ x (Dict[str, torch.Tensor]): video clip dict.
+ """
+ x[self._video_key] = pytorchvideo.transforms.functional.uniform_crop(
+ x[self._video_key], self._size, x[self._aug_index_key]
+ )
+ return x
+
+
+class Normalize(torchvision.transforms.Normalize):
+ """
+ Normalize the (CTHW) video clip by mean subtraction and division by standard deviation
+
+ Args:
+ mean (3-tuple): pixel RGB mean
+ std (3-tuple): pixel RGB standard deviation
+ inplace (boolean): whether do in-place normalization
+ """
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor with shape (C, T, H, W).
+ """
+ vid = x.permute(1, 0, 2, 3) # C T H W to T C H W
+ vid = super().forward(vid)
+ vid = vid.permute(1, 0, 2, 3) # T C H W to C T H W
+ return vid
+
+
+class ConvertFloatToUint8(torch.nn.Module):
+ """
+ Converts a video from dtype float32 to dtype uint8.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.convert_func = torchvision.transforms.ConvertImageDtype(torch.uint8)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor with shape (C, T, H, W).
+ """
+ assert (
+ x.dtype == torch.float or x.dtype == torch.half
+ ), "image must have dtype torch.uint8"
+ return self.convert_func(x)
+
+
+class ConvertUint8ToFloat(torch.nn.Module):
+ """
+ Converts a video from dtype uint8 to dtype float32.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.convert_func = torchvision.transforms.ConvertImageDtype(torch.float32)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor with shape (C, T, H, W).
+ """
+ assert x.dtype == torch.uint8, "image must have dtype torch.uint8"
+ return self.convert_func(x)
+
+
+class MoveChannelRear(torch.nn.Module):
+ """
+ A Scriptable version to perform C X Y Z -> X Y Z C.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor whose dimensions are to be permuted.
+ """
+ x = x.permute([1, 2, 3, 0])
+ return x
+
+
+class MoveChannelFront(torch.nn.Module):
+ """
+ A Scriptable version to perform X Y Z C -> C X Y Z.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor whose dimensions are to be permuted.
+ """
+ x = x.permute([3, 0, 1, 2])
+ return x
+
+
+class RandomResizedCrop(torch.nn.Module):
+ """
+ ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.random_resized_crop``.
+ """
+
+ def __init__(
+ self,
+ target_height: int,
+ target_width: int,
+ scale: Tuple[float, float],
+ aspect_ratio: Tuple[float, float],
+ shift: bool = False,
+ log_uniform_ratio: bool = True,
+ interpolation: str = "bilinear",
+ num_tries: int = 10,
+ ) -> None:
+
+ super().__init__()
+ self._target_height = target_height
+ self._target_width = target_width
+ self._scale = scale
+ self._aspect_ratio = aspect_ratio
+ self._shift = shift
+ self._log_uniform_ratio = log_uniform_ratio
+ self._interpolation = interpolation
+ self._num_tries = num_tries
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): Input video tensor with shape (C, T, H, W).
+ """
+ return pytorchvideo.transforms.functional.random_resized_crop(
+ x,
+ self._target_height,
+ self._target_width,
+ self._scale,
+ self._aspect_ratio,
+ self._shift,
+ self._log_uniform_ratio,
+ self._interpolation,
+ self._num_tries,
+ )
+
+
+class Permute(torch.nn.Module):
+ """
+ Permutes the dimensions of a video.
+ """
+
+ def __init__(self, dims: Tuple[int]):
+ """
+ Args:
+ dims (Tuple[int]): The desired ordering of dimensions.
+ """
+ assert all(
+ d in dims for d in range(len(dims))
+ ), "dims must contain every dimension (0, 1, 2, ...)"
+
+ super().__init__()
+ self._dims = dims
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): video tensor whose dimensions are to be permuted.
+ """
+ return x.permute(*self._dims)
+
+
+class OpSampler(torch.nn.Module):
+ """
+ Given a list of transforms with weights, OpSampler applies weighted sampling to
+ select n transforms, which are then applied sequentially to the input.
+ """
+
+ def __init__(
+ self,
+ transforms_list: List[Callable],
+ transforms_prob: Optional[List[float]] = None,
+ num_sample_op: int = 1,
+ randomly_sample_depth: bool = False,
+ replacement: bool = False,
+ ):
+ """
+ Args:
+ transforms_list (List[Callable]): A list of all available transforms
+ to sample from.
+ transforms_prob (Optional[List[float]]): The probabilities associated with
+ each transform in transforms_list. If not provided, the sampler assumes a
+ uniform distribution over all transforms. The weights do not need to sum
+ to one, but they must be positive.
+ num_sample_op (int): Number of transforms to sample and apply to input.
+ randomly_sample_depth (bool): If randomly_sample_depth is True, then uniformly
+ sample the number of transforms to apply, between 1 and num_sample_op.
+ replacement (bool): If replacement is True, transforms are drawn with replacement.
+ """
+ super().__init__()
+ assert len(transforms_list) > 0, "Argument transforms_list cannot be empty."
+ assert num_sample_op > 0, "Need to sample at least one transform."
+ assert num_sample_op <= len(
+ transforms_list
+ ), "Argument num_sample_op cannot be greater than number of available transforms."
+
+ if transforms_prob is not None:
+ assert len(transforms_list) == len(
+ transforms_prob
+ ), "Argument transforms_prob needs to have the same length as transforms_list."
+
+ assert (
+ min(transforms_prob) > 0
+ ), "Argument transforms_prob needs to be greater than 0."
+
+ self.transforms_list = transforms_list
+ self.transforms_prob = torch.FloatTensor(
+ transforms_prob
+ if transforms_prob is not None
+ else [1] * len(transforms_list)
+ )
+ self.num_sample_op = num_sample_op
+ self.randomly_sample_depth = randomly_sample_depth
+ self.replacement = replacement
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): Input tensor.
+ """
+ depth = (
+ torch.randint(1, self.num_sample_op + 1, (1,)).item()
+ if self.randomly_sample_depth
+ else self.num_sample_op
+ )
+ index_list = torch.multinomial(
+ self.transforms_prob, depth, replacement=self.replacement
+ )
+
+ for index in index_list:
+ x = self.transforms_list[index](x)
+
+ return x
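+
+
+# Editor's sketch (not part of the upstream patch): OpSampler drawing one of two
+# weighted transforms per call; the 3:1 weights are arbitrary assumptions.
+def _example_op_sampler() -> None:
+ sampler = OpSampler(
+ transforms_list=[lambda x: x + 1, lambda x: x * 2],
+ transforms_prob=[3.0, 1.0],
+ num_sample_op=1,
+ )
+ out = sampler(torch.zeros(1))
+ # Depending on the sampled op, the result is either 1.0 (x + 1) or 0.0 (x * 2).
+ assert out.item() in (0.0, 1.0)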
+
+
+class Div255(torch.nn.Module):
+ """
+ ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.div_255``.
+ """
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Scale clip frames from [0, 255] to [0, 1].
+ Args:
+ x (Tensor): A tensor of the clip's RGB frames with shape:
+ (C, T, H, W).
+ Returns:
+ x (Tensor): The input tensor scaled by dividing it by 255.
+ """
+ return torchvision.transforms.Lambda(
+ pytorchvideo.transforms.functional.div_255
+ )(x)
diff --git a/code/pytorchvideo/pytorchvideo/transforms/transforms_factory.py b/code/pytorchvideo/pytorchvideo/transforms/transforms_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f9f6283f74eff6d8a96892dd59fb5f46b110840
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo/transforms/transforms_factory.py
@@ -0,0 +1,274 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from pytorchvideo.transforms import (
+ ApplyTransformToKey,
+ AugMix,
+ ConvertUint8ToFloat,
+ Normalize,
+ Permute,
+ RandAugment,
+ RandomResizedCrop,
+ RandomShortSideScale,
+ RemoveKey,
+ ShortSideScale,
+ UniformTemporalSubsample,
+)
+from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip
+
+
+_RANDAUG_DEFAULT_PARAS = {
+ "magnitude": 9,
+ "num_layers": 2,
+ "prob": 0.5,
+ "transform_hparas": None,
+ "sampling_type": "gaussian",
+ "sampling_hparas": None,
+}
+
+_AUGMIX_DEFAULT_PARAS = {
+ "magnitude": 3,
+ "alpha": 1.0,
+ "width": 3,
+ "depth": -1,
+ "transform_hparas": None,
+ "sampling_hparas": None,
+}
+
+_RANDOM_RESIZED_CROP_DEFAULT_PARAS = {
+ "scale": (0.08, 1.0),
+ "aspect_ratio": (3.0 / 4.0, 4.0 / 3.0),
+}
+
+
+def _get_augmentation(
+ aug_type: str, aug_paras: Optional[Dict[str, Any]] = None
+) -> List[Callable]:
+ """
+ Initializes a list of callable transforms for video augmentation.
+
+ Args:
+ aug_type (str): Currently supports 'default', 'randaug', or 'augmix'.
+ Returns an empty list when aug_type is 'default'. Returns a list
+ of transforms containing RandAugment when aug_type is 'randaug'
+ and a list containing AugMix when aug_type is 'augmix'.
+ aug_paras (Dict[str, Any], optional): A dictionary that contains the necessary
+ parameters for the augmentation set in aug_type. If any parameters are
+ missing or if None, default parameters will be used. Default is None.
+
+ Returns:
+ aug (List[Callable]): List of callable transforms with the specified augmentation.
+ """
+
+ if aug_paras is None:
+ aug_paras = {}
+
+ if aug_type == "default":
+ aug = []
+ elif aug_type == "randaug":
+ aug = [
+ Permute((1, 0, 2, 3)),
+ RandAugment(
+ magnitude=aug_paras.get(
+ "magnitude", _RANDAUG_DEFAULT_PARAS["magnitude"]
+ ),
+ num_layers=aug_paras.get(
+ "num_layers", _RANDAUG_DEFAULT_PARAS["num_layers"]
+ ),
+ prob=aug_paras.get("prob", _RANDAUG_DEFAULT_PARAS["prob"]),
+ sampling_type=aug_paras.get(
+ "sampling_type", _RANDAUG_DEFAULT_PARAS["sampling_type"]
+ ),
+ sampling_hparas=aug_paras.get(
+ "sampling_hparas", _RANDAUG_DEFAULT_PARAS["sampling_hparas"]
+ ),
+ ),
+ Permute((1, 0, 2, 3)),
+ ]
+ elif aug_type == "augmix":
+ aug = [
+ Permute((1, 0, 2, 3)),
+ AugMix(
+ magnitude=aug_paras.get(
+ "magnitude", _AUGMIX_DEFAULT_PARAS["magnitude"]
+ ),
+ alpha=aug_paras.get("alpha", _AUGMIX_DEFAULT_PARAS["alpha"]),
+ width=aug_paras.get("width", _AUGMIX_DEFAULT_PARAS["width"]),
+ depth=aug_paras.get("depth", _AUGMIX_DEFAULT_PARAS["depth"]),
+ ),
+ Permute((1, 0, 2, 3)),
+ ]
+ else:
+ raise NotImplementedError
+
+ return aug
+
+
+def create_video_transform(
+ mode: str,
+ video_key: Optional[str] = None,
+ remove_key: Optional[List[str]] = None,
+ num_samples: Optional[int] = 8,
+ convert_to_float: bool = True,
+ video_mean: Tuple[float, float, float] = (0.45, 0.45, 0.45),
+ video_std: Tuple[float, float, float] = (0.225, 0.225, 0.225),
+ min_size: int = 256,
+ max_size: int = 320,
+ crop_size: Union[int, Tuple[int, int]] = 224,
+ horizontal_flip_prob: float = 0.5,
+ aug_type: str = "default",
+ aug_paras: Optional[Dict[str, Any]] = None,
+ random_resized_crop_paras: Optional[Dict[str, Any]] = None,
+) -> Union[
+ Callable[[torch.Tensor], torch.Tensor],
+ Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]],
+]:
+ """
+ Function that returns a factory default callable video transform, with default
+ parameters that can be modified. The transform that is returned depends on the
+ ``mode`` parameter: when in "train" mode, we use randomized transformations,
+ and when in "val" mode, we use the corresponding deterministic transformations.
+ Depending on whether ``video_key`` is set, the input to the transform can either
+ be a video tensor or a dict containing ``video_key`` that maps to a video
+ tensor. The video tensor should be of shape (C, T, H, W).
+
+ "train" mode "val" mode
+
+ (UniformTemporalSubsample) (UniformTemporalSubsample)
+ ↓
+ (RandAugment/AugMix) ↓
+ ↓
+ (ConvertUint8ToFloat) (ConvertUint8ToFloat)
+ ↓ ↓
+ Normalize Normalize
+ ↓ ↓
+ RandomResizedCrop/RandomShortSideScale+RandomCrop ShortSideScale+CenterCrop
+ ↓
+ RandomHorizontalFlip
+
+ (transform) = transform can be included or excluded in the returned
+ composition of transformations
+
+ Args:
+ mode (str): 'train' or 'val'. We use randomized transformations in
+ 'train' mode, and we use the corresponding deterministic transformation
+ in 'val' mode.
+ video_key (str, optional): Optional key for video value in dictionary input.
+ When video_key is None, the input is assumed to be a torch.Tensor.
+ Default is None.
+ remove_key (List[str], optional): Optional key to remove from a dictionary input.
+ Default is None.
+ num_samples (int, optional): The number of equispaced samples to be selected in
+ UniformTemporalSubsample. If None, then UniformTemporalSubsample will not be
+ used. Default is 8.
+ convert_to_float (bool): If True, converts images from uint8 to float.
+ Otherwise, leaves the image as is. Default is True.
+ video_mean (Tuple[float, float, float]): Sequence of means for each channel to
+ normalize to zero mean and unit variance. Default is (0.45, 0.45, 0.45).
+ video_std (Tuple[float, float, float]): Sequence of standard deviations for each
+ channel to normalize to zero mean and unit variance.
+ Default is (0.225, 0.225, 0.225).
+ min_size (int): Minimum size that the shorter side is scaled to for
+ RandomShortSideScale. If in "val" mode, this is the exact size
+ the shorter side is scaled to for ShortSideScale.
+ Default is 256.
+ max_size (int): Maximum size that the shorter side is scaled to for
+ RandomShortSideScale. Default is 320.
+ crop_size (int or Tuple[int, int]): Desired output size of the crop for RandomCrop
+ in "train" mode and CenterCrop in "val" mode. If size is an int instead
+ of sequence like (h, w), a square crop (size, size) is made. Default is 224.
+ horizontal_flip_prob (float): Probability of the video being flipped in
+ RandomHorizontalFlip. Default value is 0.5.
+ aug_type (str): Currently supports 'default', 'randaug', or 'augmix'. No
+ augmentations other than RandomShortSideScale and RandomCrop are performed
+ when aug_type is 'default'. RandAugment is used when aug_type is 'randaug'
+ and AugMix is used when aug_type is 'augmix'. Default is 'default'.
+ aug_paras (Dict[str, Any], optional): A dictionary that contains the necessary
+ parameters for the augmentation set in aug_type. If any parameters are
+ missing or if None, default parameters will be used. Default is None.
+ random_resized_crop_paras (Dict[str, Any], optional): A dictionary that contains
+ the necessary parameters for Inception-style cropping. This crops the given
+ videos to random size and aspect ratio. A crop of random size relative to the
+ original size and a random aspect ratio is made. This crop is finally resized
+ to given size. This is popularly used to train the Inception networks. If any
+ parameters are missing or if None, default parameters in
+ _RANDOM_RESIZED_CROP_DEFAULT_PARAS will be used. If None, RandomShortSideScale
+ and RandomCrop will be used as a fallback. Default is None.
+
+ Returns:
+ A factory-default callable composition of transforms.
+ """
+
+ if isinstance(crop_size, int):
+ assert crop_size <= min_size, "crop_size must be less than or equal to min_size"
+ elif isinstance(crop_size, tuple):
+ assert (
+ max(crop_size) <= min_size
+ ), "the height and width in crop_size must be less than or equal to min_size"
+ else:
+ raise TypeError
+ if video_key is None:
+ assert remove_key is None, "remove_key should be None if video_key is None"
+ if aug_type == "default":
+ assert aug_paras is None, "aug_paras should be None for ``default`` aug_type"
+
+ if random_resized_crop_paras is not None:
+ random_resized_crop_paras["target_height"] = crop_size
+ random_resized_crop_paras["target_width"] = crop_size
+ if "scale" not in random_resized_crop_paras:
+ random_resized_crop_paras["scale"] = _RANDOM_RESIZED_CROP_DEFAULT_PARAS[
+ "scale"
+ ]
+ if "aspect_ratio" not in random_resized_crop_paras:
+ random_resized_crop_paras[
+ "aspect_ratio"
+ ] = _RANDOM_RESIZED_CROP_DEFAULT_PARAS["aspect_ratio"]
+
+ transform = Compose(
+ (
+ []
+ if num_samples is None
+ else [UniformTemporalSubsample(num_samples=num_samples)]
+ )
+ + (
+ _get_augmentation(aug_type=aug_type, aug_paras=aug_paras)
+ if mode == "train"
+ else []
+ )
+ + ([ConvertUint8ToFloat()] if convert_to_float else [])
+ + [Normalize(mean=video_mean, std=video_std)]
+ + (
+ (
+ [RandomResizedCrop(**random_resized_crop_paras)]
+ if random_resized_crop_paras is not None
+ else [
+ RandomShortSideScale(
+ min_size=min_size,
+ max_size=max_size,
+ ),
+ RandomCrop(size=crop_size),
+ ]
+ + [RandomHorizontalFlip(p=horizontal_flip_prob)]
+ )
+ if mode == "train"
+ else [
+ ShortSideScale(size=min_size),
+ CenterCrop(size=crop_size),
+ ]
+ )
+ )
+
+ if video_key is None:
+ return transform
+
+ return Compose(
+ [
+ ApplyTransformToKey(
+ key=video_key,
+ transform=transform,
+ )
+ ]
+ + ([] if remove_key is None else [RemoveKey(k) for k in remove_key])
+ )
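+
+
+# Editor's sketch (not part of the upstream patch): building a "train" transform
+# that subsamples 8 frames, applies RandAugment, normalizes, and crops to 224;
+# the input clip shape below is an arbitrary assumption.
+def _example_create_video_transform() -> None:
+ transform = create_video_transform(
+ mode="train",
+ num_samples=8,
+ crop_size=224,
+ aug_type="randaug",
+ )
+ video = torch.randint(0, 256, (3, 16, 256, 320), dtype=torch.uint8)
+ assert transform(video).shape == (3, 8, 224, 224)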
diff --git a/code/pytorchvideo/pytorchvideo_trainer/README.md b/code/pytorchvideo/pytorchvideo_trainer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..46886c3384324a362bcb3411eeaa2187433eca49
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/README.md
@@ -0,0 +1,39 @@
+## PyTorchVideo Trainer
+
+A [PyTorch-Lightning](https://github.com/Lightning-AI/pytorch-lightning) based trainer supporting PyTorchVideo models and dataloaders for various video understanding tasks.
+
+Currently supported tasks include:
+
+- Video Action Recognition: ResNets, SlowFast, X3D, and MViT models
+- Video Self-Supervised Learning: SimCLR, BYOL, MoCo
+- (Planned) Video Action Detection
+
+## Installation
+
+These instructions assume that both pytorch and torchvision are already installed
+using the instructions in [INSTALL.md](https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md#requirements)
+
+Install the required additional dependency `recipes` by running the following command,
+```
+pip install "git+https://github.com/facebookresearch/recipes.git"
+```
+
+Then install PyTorchVideo Trainer by running:
+```
+git clone https://github.com/facebookresearch/pytorchvideo.git
+cd pytorchvideo/pytorchvideo_trainer
+pip install -e .
+
+# For developing and testing
+pip install -e ".[test,dev]"
+```
+
+## Testing
+
+Before running the tests, please ensure that you have installed the necessary additional test dependencies.
+
+Use the following command to run the tests:
+```
+# From the current directory
+python -m unittest discover -v -s ./tests
+```
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/__init__.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e164fdeeb52e22ed976f9759eedd732ced19ed2
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+
+def register_components() -> None:
+ """
+ Calls register_components() for all subfolders so we can register
+ subcomponents to Hydra's ConfigStore.
+ """
+ import pytorchvideo_trainer.datamodule.datamodule # noqa
+ import pytorchvideo_trainer.module.byol # noqa
+ import pytorchvideo_trainer.module.lr_policy # noqa
+ import pytorchvideo_trainer.module.moco_v2 # noqa
+ import pytorchvideo_trainer.module.optimizer # noqa
+ import pytorchvideo_trainer.module.simclr # noqa
+ import pytorchvideo_trainer.module.video_classification # noqa
+ import pytorchvideo_trainer.train_app # noqa
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cd87adf77f4e0c6ce58dc79efde5f91e92e3948
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .precise_batchnorm import PreciseBn # noqa
+
+
+__all__ = [
+ "PreciseBn",
+]
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b716ff843e5cae423ff2129a77cae9d000aac04
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py
@@ -0,0 +1,70 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from typing import Generator
+
+import torch
+from fvcore.nn.precise_bn import update_bn_stats
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.trainer.trainer import Trainer
+from torch.utils.data import DataLoader
+
+
+class PreciseBn(Callback):
+ """
+ Recompute and update the batch norm stats to make them more precise. During
+ training, both the BN stats and the model weights change after every iteration,
+ so the running average cannot precisely reflect the actual stats of the
+ current model.
+ In this callback, the BN stats are recomputed with fixed weights to make
+ the running average more precise during the training phase. Specifically, it
+ computes the true average of per-batch mean/variance instead of the
+ running average. See Sec. 3 of the paper "Rethinking Batch in BatchNorm"
+ for details.
+ """
+
+ def __init__(self, num_batches: int) -> None:
+ """
+ Args:
+ num_batches (int): Number of steps / mini-batches to
+ sample for updating the precise batchnorm
+ stats.
+ """
+ self.num_batches = num_batches
+
+ def _get_precise_bn_loader(
+ self, data_loader: DataLoader, pl_module: LightningModule
+ ) -> Generator[torch.Tensor, None, None]:
+ for batch in data_loader:
+ inputs = batch[pl_module.modality_key]
+ if isinstance(inputs, list):
+ inputs = [x.to(pl_module.device) for x in inputs]
+ else:
+ inputs = inputs.to(pl_module.device)
+ yield inputs
+
+ def on_train_epoch_end(
+ self,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ ) -> None:
+ """
+ Called at the end of every epoch only during the training
+ phase.
+
+ Args:
+ trainer (Trainer): A PyTorch-Lightning trainer object.
+ pl_module (LightningModule): A PyTorch-Lightning module.
+ Typically supported modules include -
+ pytorchvideo_trainer.module.VideoClassificationModule, etc.
+ """
+ # pyre-ignore[16]
+ dataloader = trainer.datamodule.train_dataloader()
+ precise_bn_loader = self._get_precise_bn_loader(
+ data_loader=dataloader, pl_module=pl_module
+ )
+ update_bn_stats(
+ model=pl_module.model, # pyre-ignore[6]
+ data_loader=precise_bn_loader,
+ num_iters=self.num_batches,
+ )
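For reference, a minimal usage sketch for this callback: in the bundled recipes it is attached through the `callbacks: precise_bn` Hydra group (with `num_batches: 200`), but it can equally be passed to a Lightning `Trainer` by hand. The `lightning_module` and `datamodule` names below are placeholders for any compatible module exposing `model` and `modality_key` (e.g. the video classification module) and its datamodule.

```python
import pytorch_lightning as pl
from pytorchvideo_trainer.callbacks import PreciseBn

trainer = pl.Trainer(
    max_epochs=196,
    callbacks=[PreciseBn(num_batches=200)],  # recompute BN stats over 200 mini-batches
)
# trainer.fit(lightning_module, datamodule=datamodule)
```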
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5143209989eaec87959b18a835df24c51bb3bb00
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import torchrecipes.core.conf # noqa
+
+# Components to register with this config
+from pytorchvideo_trainer import register_components
+
+
+register_components()
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9a53d1af2c41ad552a75ab4efa1fd6be191644da
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml
@@ -0,0 +1,28 @@
+_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp
+
+defaults:
+ - schema/module: byol_module_conf
+ - schema/module/optim: optim_conf
+ - schema/datamodule: ptv_video_classification_data_module_conf
+ - datamodule/dataloader: kinetics_contrastive
+ - logger: ptl
+ - datamodule/transforms: kinetics_contrastive
+ - module/knn_memory: kinetics_k400
+ - module/model: slow_r50_byol
+ - module/loss: similarity
+ - module/optim: sgd_ssl
+ - module/metrics: accuracy
+ - schema/trainer: trainer
+ - trainer: cpu
+ - callbacks: null
+ - _self_
+trainer:
+ sync_batchnorm: false # set this to true for training
+
+module:
+ momentum_anneal_cosine: true
+
+hydra:
+ searchpath:
+ - pkg://pytorchvideo_trainer.conf
+ - pkg://torchrecipes.core.conf
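Note on `momentum_anneal_cosine: true`: the EMA momentum of the BYOL target network (initialized from `mmt: 0.996` in the `slow_r50_byol` model config) is annealed toward 1.0 over training. The exact implementation lives in `pytorchvideo_trainer/module/byol.py`, which is not part of this excerpt; the sketch below is the standard BYOL formulation, shown only to make the flag concrete.

```python
import math

def annealed_momentum(base_mmt: float, cur_step: int, max_steps: int) -> float:
    """Cosine-anneal the EMA momentum from base_mmt (at step 0) toward 1.0 (at the end)."""
    return 1.0 - (1.0 - base_mmt) * (math.cos(math.pi * cur_step / max_steps) + 1.0) / 2.0

assert abs(annealed_momentum(0.996, 0, 100) - 0.996) < 1e-9
assert abs(annealed_momentum(0.996, 100, 100) - 1.0) < 1e-9
```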
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b0934d86fe5a0b7dc0f552c6d722e0b00d4b451
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml
@@ -0,0 +1,3 @@
+precise_bn:
+ _target_: pytorchvideo_trainer.callbacks.precise_batchnorm.PreciseBn
+ num_batches: null
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ac10982e1622c943e81ba50249365d44de7f0c4e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml
@@ -0,0 +1,71 @@
+_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp
+
+defaults:
+ - schema/module: video_classification_module_conf_vision_transformer
+ - schema/module/optim: optim_conf
+ - schema/datamodule: ptv_video_classification_data_module_conf
+ - datamodule/dataloader: kinetics_classification
+ - logger: ptl
+ - datamodule/transforms: kinetics_classification_mvit_16x4
+ - module/model: mvit_base_16x4
+ - module/loss: soft_cross_entropy
+ - module/optim: adamw
+ - module/metrics: accuracy
+ - module/lr_scheduler: cosine_with_warmup
+ - schema/trainer: trainer
+ - trainer: multi_gpu
+ - _self_
+
+module:
+ clip_gradient_norm: 1.0
+ ensemble_method: "sum"
+ lr_scheduler:
+ max_iters: 200
+ warmup_start_lr: 1.6e-05
+ warmup_iters: 30
+ cosine_after_warmup: true
+ cosine_end_lr: 1.6e-05
+ optim:
+ lr: 0.0016
+ weight_decay: 0.05
+ method: adamw
+ zero_weight_decay_1d_param: true
+ batch_transform:
+ _target_: pytorchvideo_trainer.datamodule.transforms.MixVideoBatchWrapper
+ mixup_alpha: 0.8
+ cutmix_prob: 0.5
+ cutmix_alpha: 1.0
+ label_smoothing: 0.1
+
+datamodule:
+ dataloader:
+ train:
+ batch_size: 2
+ dataset:
+ clip_sampler:
+ clip_duration: 2.13
+ collate_fn:
+ _target_: pytorchvideo_trainer.datamodule.collators.build_collator_from_name
+ name: multiple_samples_collate
+ val:
+ batch_size: 8
+ dataset:
+ clip_sampler:
+ clip_duration: 2.13
+ test:
+ batch_size: 8
+ dataset:
+ clip_sampler:
+ clip_duration: 2.13
+
+trainer:
+ num_nodes: 16
+ gpus: 8
+ max_epochs: 200
+ sync_batchnorm: False
+ replace_sampler_ddp: False
+
+hydra:
+ searchpath:
+ - pkg://pytorchvideo_trainer.conf
+ - pkg://torchrecipes.core.conf
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5d246c91f535daa60745f3833d9a5d6442ec697b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml
@@ -0,0 +1,45 @@
+_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp
+
+defaults:
+ - schema/module: video_classification_module_conf
+ - schema/module/optim: optim_conf
+ - schema/datamodule: ptv_video_classification_data_module_conf
+ - datamodule/dataloader: kinetics_classification
+ - logger: ptl
+ - datamodule/transforms: kinetics_classification_slow
+ - module/model: slow_r50
+ - module/loss: cross_entropy
+ - module/optim: sgd
+ - module/metrics: accuracy
+ - module/lr_scheduler: cosine_with_warmup
+ - schema/trainer: trainer
+ - trainer: multi_gpu
+ - callbacks: precise_bn
+ - _self_
+
+module:
+ ensemble_method: "sum"
+ lr_scheduler:
+ max_iters: 196
+ warmup_start_lr: 0.01
+ warmup_iters: 34
+ optim:
+ lr: 0.8
+ nesterov: true
+
+callbacks:
+ precise_bn:
+ num_batches: 200
+
+trainer:
+ num_nodes: 8
+ gpus: 8
+ max_epochs: 196
+ sync_batchnorm: False
+ replace_sampler_ddp: False
+
+
+hydra:
+ searchpath:
+ - pkg://pytorchvideo_trainer.conf
+ - pkg://torchrecipes.core.conf
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a204081c85c1acb7cae2de87d79c64ad444e4fa8
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml
@@ -0,0 +1,45 @@
+_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp
+
+defaults:
+ - schema/module: video_classification_module_conf
+ - schema/module/optim: optim_conf
+ - schema/datamodule: ptv_video_classification_data_module_conf
+ - datamodule/dataloader: kinetics_classification
+ - logger: ptl
+ - datamodule/transforms: kinetics_classification_slowfast
+ - module/model: slowfast_r50
+ - module/loss: cross_entropy
+ - module/optim: sgd
+ - module/metrics: accuracy
+ - module/lr_scheduler: cosine_with_warmup
+ - schema/trainer: trainer
+ - trainer: multi_gpu
+ - callbacks: precise_bn
+ - _self_
+
+module:
+ ensemble_method: "sum"
+ lr_scheduler:
+ max_iters: 196
+ warmup_start_lr: 0.01
+ warmup_iters: 34
+ optim:
+ lr: 0.8
+ nesterov: true
+
+callbacks:
+ precise_bn:
+ num_batches: 200
+
+trainer:
+ num_nodes: 8
+ gpus: 8
+ max_epochs: 196
+ sync_batchnorm: False
+ replace_sampler_ddp: False
+
+
+hydra:
+ searchpath:
+ - pkg://pytorchvideo_trainer.conf
+ - pkg://torchrecipes.core.conf
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_x3d_xs.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_x3d_xs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..422e042b29ae6d2c482e2944178ae30f60acd1f1
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_x3d_xs.yaml
@@ -0,0 +1,64 @@
+_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp
+
+defaults:
+ - schema/module: video_classification_module_conf
+ - schema/module/optim: optim_conf
+ - schema/datamodule: ptv_video_classification_data_module_conf
+ - datamodule/dataloader: kinetics_classification
+ - logger: ptl
+ - datamodule/transforms: kinetics_classification_x3d_xs
+ - module/model: x3d_xs
+ - module/loss: cross_entropy
+ - module/optim: sgd
+ - module/metrics: accuracy
+ - module/lr_scheduler: cosine_with_warmup
+ - schema/trainer: trainer
+ - trainer: multi_gpu
+ - callbacks: precise_bn
+ - _self_
+
+module:
+ ensemble_method: "sum"
+ lr_scheduler:
+ max_iters: 300
+ warmup_start_lr: 0.01
+ warmup_iters: 35
+ optim:
+ lr: 0.8
+ nesterov: true
+ weight_decay: 5e-5
+
+datamodule:
+ dataloader:
+ train:
+ batch_size: 16
+ dataset:
+ clip_sampler:
+ clip_duration: 1.6
+ val:
+ batch_size: 16
+ dataset:
+ clip_sampler:
+ clip_duration: 1.6
+ test:
+ batch_size: 16
+ dataset:
+ clip_sampler:
+ clip_duration: 1.6
+
+callbacks:
+ precise_bn:
+ num_batches: 200
+
+trainer:
+ num_nodes: 8
+ gpus: 8
+ max_epochs: 300
+ sync_batchnorm: False
+ replace_sampler_ddp: False
+
+
+hydra:
+ searchpath:
+ - pkg://pytorchvideo_trainer.conf
+ - pkg://torchrecipes.core.conf
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1d428802c45a9caf18f25b5b1df2b6eed311bfae
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml
@@ -0,0 +1,43 @@
+train:
+ dataset:
+ _target_: pytorchvideo.data.Kinetics
+ data_path: ???
+ video_path_prefix: ???
+ clip_sampler:
+ _target_: pytorchvideo.data.clip_sampling.RandomClipSampler
+ clip_duration: 2.13
+
+ shuffle: True
+ batch_size: 8
+ num_workers: 8
+ pin_memory: True
+
+val:
+ dataset:
+ _target_: pytorchvideo.data.Kinetics
+ data_path: ???
+ video_path_prefix: ???
+ clip_sampler:
+ _target_: pytorchvideo.data.clip_sampling.UniformClipSampler
+ clip_duration: 2.13
+
+ shuffle: False
+ batch_size: 8
+ num_workers: 8
+ pin_memory: True
+
+test:
+ dataset:
+ _target_: pytorchvideo.data.Kinetics
+ data_path: ???
+ video_path_prefix: ???
+ clip_sampler:
+ _target_: pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler
+ clip_duration: 2.13
+ clips_per_video: 10 #num_ensemble_views
+ augs_per_clip: 3 # num_spatial_crops
+
+ shuffle: False
+ batch_size: 8
+ num_workers: 8
+ pin_memory: True
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_contrastive.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_contrastive.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4208b9a53ea25c22ce785e0f14566cddb8cf3306
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_contrastive.yaml
@@ -0,0 +1,41 @@
+train:
+ dataset:
+ _target_: pytorchvideo.data.Kinetics
+ data_path: ???
+ video_path_prefix: ???
+ clip_sampler:
+ _target_: pytorchvideo.data.clip_sampling.RandomMultiClipSampler
+ clip_duration: 2.0
+ num_clips: 2
+
+ shuffle: True
+ batch_size: 8
+ num_workers: 8
+
+val:
+ dataset:
+ _target_: pytorchvideo.data.Kinetics
+ data_path: ???
+ video_path_prefix: ???
+ clip_sampler:
+ _target_: pytorchvideo.data.clip_sampling.UniformClipSampler
+ clip_duration: 2.0
+
+ shuffle: False
+ batch_size: 8
+ num_workers: 8
+
+test:
+ dataset:
+ _target_: pytorchvideo.data.Kinetics
+ data_path: ???
+ video_path_prefix: ???
+ clip_sampler:
+ _target_: pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler
+ clip_duration: 2.0
+ clips_per_video: 10 #num_ensemble_views
+ augs_per_clip: 3 # num_spatial_crops
+
+ shuffle: False
+ batch_size: 8
+ num_workers: 8
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7924a22a5e24c19197dcd41fd2436ea69d36e20c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml
@@ -0,0 +1,70 @@
+train:
+ - _target_: pytorchvideo_trainer.datamodule.transforms.RepeatandConverttoList
+ repeat_num: 2
+ - _target_: pytorchvideo_trainer.datamodule.transforms.ApplyTransformToKeyOnList
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 16
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Permute
+ dims: [1,0,2,3]
+ - _target_: pytorchvideo.transforms.rand_augment.RandAugment
+ magnitude: 7
+ num_layers: 4
+ - _target_: pytorchvideo.transforms.Permute
+ dims: [1,0,2,3]
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.RandomResizedCrop
+ target_height: 224
+ target_width: 224
+ scale: [0.08, 1.0]
+ aspect_ratio: [0.75, 1.3333]
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ p: 0.5
+ - _target_: pytorchvideo.transforms.Permute
+ dims: [1,0,2,3]
+ - _target_: pytorchvideo_trainer.datamodule.rand_erase_transform.RandomErasing
+ probability: 0.25
+ mode: "pixel"
+ max_count: 1
+ num_splits: 1
+ device: "cpu"
+ - _target_: pytorchvideo.transforms.Permute
+ dims: [1,0,2,3]
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+val:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 16
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 224
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+test:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 16
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 224
+ key: video
+ - _target_: pytorchvideo.transforms.UniformCropVideo
+ size: 224
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cc69c698739884f9da02f01e365893b99c0f7e74
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml
@@ -0,0 +1,51 @@
+train:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.RandomShortSideScale
+ min_size: 256
+ max_size: 320
+ - _target_: torchvision.transforms.RandomCrop
+ size: 224
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ p: 0.5
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+val:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ - _target_: torchvision.transforms.CenterCrop
+ size: 256
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+test:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ key: video
+ - _target_: pytorchvideo.transforms.UniformCropVideo
+ size: 256
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slowfast.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slowfast.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5388fb052d588ea689a838091319a99b474763de
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slowfast.yaml
@@ -0,0 +1,60 @@
+train:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 32
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.RandomShortSideScale
+ min_size: 256
+ max_size: 320
+ - _target_: torchvision.transforms.RandomCrop
+ size: 224
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ p: 0.5
+ - _target_: pytorchvideo_trainer.datamodule.transforms.SlowFastPackPathway
+ alpha: 4
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+val:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 32
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ - _target_: torchvision.transforms.CenterCrop
+ size: 256
+ - _target_: pytorchvideo_trainer.datamodule.transforms.SlowFastPackPathway
+ alpha: 4
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+test:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 32
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ key: video
+ - _target_: pytorchvideo.transforms.UniformCropVideo
+ size: 256
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo_trainer.datamodule.transforms.SlowFastPackPathway
+ alpha: 4
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
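The `SlowFastPackPathway` transform (with `alpha: 4`) turns the single 32-frame clip into the two-pathway input expected by `create_slowfast`. Its implementation lives in `pytorchvideo_trainer/datamodule/transforms.py`, which is not shown in this excerpt; the sketch below assumes the same convention as PyTorchVideo's SlowFast tutorial, where the slow pathway keeps every `alpha`-th frame.

```python
from typing import List

import torch

def pack_pathway(frames: torch.Tensor, alpha: int = 4) -> List[torch.Tensor]:
    """frames: (C, T, H, W) -> [slow (C, T // alpha, H, W), fast (C, T, H, W)]."""
    fast_pathway = frames
    slow_pathway = torch.index_select(
        frames,
        1,
        torch.linspace(0, frames.shape[1] - 1, frames.shape[1] // alpha).long(),
    )
    return [slow_pathway, fast_pathway]

clip = torch.randn(3, 32, 224, 224)
slow, fast = pack_pathway(clip)  # slow: (3, 8, 224, 224), fast: (3, 32, 224, 224)
```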
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_x3d_xs.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_x3d_xs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e80f20e66fe4715bc1c786b35d9a5e1a62692d9e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_x3d_xs.yaml
@@ -0,0 +1,51 @@
+train:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 4
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.RandomShortSideScale
+ min_size: 182
+ max_size: 228
+ - _target_: torchvision.transforms.RandomCrop
+ size: 160
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ p: 0.5
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+val:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 4
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 182
+ - _target_: torchvision.transforms.CenterCrop
+ size: 182
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+test:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 4
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 182
+ key: video
+ - _target_: pytorchvideo.transforms.UniformCropVideo
+ size: 182
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_contrastive.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_contrastive.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..238971871ac7cd8be886c4cccfae983c4f30c350
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_contrastive.yaml
@@ -0,0 +1,56 @@
+train:
+ - _target_: pytorchvideo_trainer.datamodule.transforms.ApplyTransformToKeyOnList
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo_trainer.datamodule.transforms.ColorJitterVideoSSl
+ bri_con_sat: [0.6, 0.6, 0.6]
+ hue: 0.15
+ p_color_jitter: 0.8
+ p_convert_gray: 0.2
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.RandomResizedCrop
+ target_height: 224
+ target_width: 224
+ scale: [0.2, 0.766]
+ aspect_ratio: [0.75, 1.3333]
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ p: 0.5
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+val:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ - _target_: torchvision.transforms.CenterCrop
+ size: 256
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+test:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ key: video
+ - _target_: pytorchvideo.transforms.UniformCropVideo
+ size: 256
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_moco_v2.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_moco_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2919905a5779f2024b700ad259178c2cb9943bf9
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_moco_v2.yaml
@@ -0,0 +1,56 @@
+train:
+ - _target_: pytorchvideo_trainer.datamodule.transforms.ApplyTransformToKeyOnList
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo_trainer.datamodule.transforms.ColorJitterVideoSSl
+ bri_con_sat: [0.4, 0.4, 0.4]
+ hue: 0.4
+ p_color_jitter: 0.8
+ p_convert_gray: 0.2
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.RandomResizedCrop
+ target_height: 224
+ target_width: 224
+ scale: [0.2, 0.766]
+ aspect_ratio: [0.75, 1.3333]
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ p: 0.5
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+val:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ - _target_: torchvision.transforms.CenterCrop
+ size: 256
+ key: video
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+test:
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 8
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 256
+ key: video
+ - _target_: pytorchvideo.transforms.UniformCropVideo
+ size: 256
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/logger/ptl.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/logger/ptl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..352afc06d0c0838233186d04c817b8b04d4d66fe
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/logger/ptl.yaml
@@ -0,0 +1,4 @@
+_target_: pytorch_lightning.loggers.TensorBoardLogger
+save_dir: ???
+name: default
+version: null
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/moco_v2_train_app_conf.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/moco_v2_train_app_conf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d464b742674d70506d2469f546b184b5a9250b5b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/moco_v2_train_app_conf.yaml
@@ -0,0 +1,31 @@
+_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp
+
+defaults:
+ - schema/module: moco_v2_module_conf
+ - schema/module/optim: optim_conf
+ - schema/datamodule: ptv_video_classification_data_module_conf
+ - datamodule/dataloader: kinetics_contrastive
+ - logger: ptl
+ - datamodule/transforms: kinetics_moco_v2
+ - module/knn_memory: kinetics_k400
+ - module/model: slow_r50_moco_v2
+ - module/loss: contrastive
+ - module/optim: sgd_ssl
+ - module/metrics: accuracy
+ - schema/trainer: trainer
+ - trainer: cpu
+ - callbacks: null
+ - _self_
+trainer:
+ sync_batchnorm: false # set this to true for training
+
+module:
+ dim: ${module.model.backbone_embed_dim}
+ k: 65536
+ batch_shuffle: true
+ local_shuffle_bn: true
+
+hydra:
+ searchpath:
+ - pkg://pytorchvideo_trainer.conf
+ - pkg://torchrecipes.core.conf
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/knn_memory/kinetics_k400.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/knn_memory/kinetics_k400.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..edecec617c70f550700b5ee6e4244188fdf4cfcb
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/knn_memory/kinetics_k400.yaml
@@ -0,0 +1,7 @@
+_target_: pytorchvideo_trainer.module.ssl_helper.KnnMemory
+temperature: ${module.loss.temperature}
+dim: ${module.model.backbone_embed_dim}
+length: 239975
+downstream_classes: 400
+knn_k: 200
+momentum: 1.0
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/contrastive.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/contrastive.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..669a328d64bbe10f3ede3961f313979a377fb48a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/contrastive.yaml
@@ -0,0 +1,2 @@
+_target_: pytorchvideo_trainer.module.losses.ContrastiveLoss
+temperature: 0.1
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/cross_entropy.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/cross_entropy.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f381cd87d2815c8b1fd3ebcf4951ccbea06cca84
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/cross_entropy.yaml
@@ -0,0 +1,2 @@
+# @package _group_
+_target_: torch.nn.CrossEntropyLoss
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/nt_xent.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/nt_xent.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..11df106a5b5ea06ade3caf4f9cfc37f4a2ca683e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/nt_xent.yaml
@@ -0,0 +1,3 @@
+# @package _group_
+_target_: pytorchvideo_trainer.module.losses.NtxentLoss
+temperature: 0.1
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/similarity.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/similarity.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c483cfd73cea5493d961f5ba0d903149626e852b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/similarity.yaml
@@ -0,0 +1,3 @@
+# @package _group_
+_target_: pytorchvideo_trainer.module.losses.SimilarityLoss
+temperature: 0.1
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/soft_cross_entropy.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/soft_cross_entropy.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2df3319cc0217226ec0cd66541b8df5d0b4ba995
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/soft_cross_entropy.yaml
@@ -0,0 +1,2 @@
+# @package _group_
+_target_: pytorchvideo_trainer.module.losses.SoftTargetCrossEntropy
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/lr_scheduler/cosine_with_warmup.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/lr_scheduler/cosine_with_warmup.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b307f9fbcc0008ce40959ac72c6fce32a617cfa4
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/lr_scheduler/cosine_with_warmup.yaml
@@ -0,0 +1,7 @@
+lr_policy: 'cosine'
+cosine_after_warmup: False
+cosine_end_lr: 0
+warmup_iters: 34
+warmup_start_lr: 0.01
+max_iters: ${trainer.max_epochs}
+lr: ${module.optim.lr}
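This group configures a linear warmup from `warmup_start_lr` followed by a cosine decay toward `cosine_end_lr`, with `max_iters` tied to `trainer.max_epochs`. The policy itself is implemented in `pytorchvideo_trainer/module/lr_policy.py`, which is not part of this excerpt; the sketch below is a common formulation of the same schedule (ignoring the `cosine_after_warmup` flag) and is included only to make the fields concrete.

```python
import math

def lr_at(step: int, *, lr: float, warmup_start_lr: float, warmup_iters: int,
          max_iters: int, cosine_end_lr: float = 0.0) -> float:
    if step < warmup_iters:
        # Linear warmup from warmup_start_lr up to the base lr.
        return warmup_start_lr + (lr - warmup_start_lr) * step / warmup_iters
    # Cosine decay from the base lr down to cosine_end_lr.
    progress = (step - warmup_iters) / max(1, max_iters - warmup_iters)
    return cosine_end_lr + (lr - cosine_end_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
```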
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/accuracy.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/accuracy.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8bbd4268953b53337c25ae6fe9f89871269cd833
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/accuracy.yaml
@@ -0,0 +1,8 @@
+- name: accuracy_top1
+ config:
+ _target_: torchmetrics.Accuracy
+ top_k: 1
+- name: accuracy_top5
+ config:
+ _target_: torchmetrics.Accuracy
+ top_k: 5
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/average_precision.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/average_precision.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd8ad7a675b0eb5c45bd82d0d20cd9a27a515c85
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/average_precision.yaml
@@ -0,0 +1,3 @@
+- name: average_precision
+ config:
+ _target_: torchmetrics.AveragePrecision
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..00e30353ba11053a8befb1397e8807eac7dea247
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml
@@ -0,0 +1,2 @@
+_target_: pytorchvideo_trainer.module.video_classification.create_classification_model_from_lightning
+checkpoint_path: ???
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..422bdbdf58ab7266cd488fe6867c50b2a4df14d2
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml
@@ -0,0 +1,5 @@
+_target_: pytorchvideo_trainer.module.video_classification.create_classification_model_from_modelzoo
+checkpoint_path: manifold://fair_logging/tree/kalyanv/hub_models/SLOW_8x8_R50.pyth
+model:
+ _target_: pytorchvideo.models.hub.resnet.slow_r50
+ pretrained: False
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_ssl_checkpoint.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_ssl_checkpoint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..21c3b9d028421a675dce4964e5919643ad5e32d7
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_ssl_checkpoint.yaml
@@ -0,0 +1,11 @@
+_target_: pytorchvideo_trainer.module.ssl_helper.create_classification_model_from_ssl_checkpoint
+ssl_checkpoint_path: null
+checkpoint_type: simclr
+mlp:
+ _target_: pytorchvideo_trainer.module.byol.create_mlp_util
+ dim_in: null
+ dim_out: 400
+ mlp_dim: 256
+ num_layers: 1
+ norm: null
+detach_backbone: true
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/mvit_base_16x4.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/mvit_base_16x4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fa5d958a0ea464a39e34b20e03922e306c4cad05
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/mvit_base_16x4.yaml
@@ -0,0 +1,32 @@
+_target_: pytorchvideo.models.vision_transformers.create_multiscale_vision_transformers
+spatial_size: 224
+temporal_size: 16
+cls_embed_on: True
+sep_pos_embed: True
+depth: 16
+norm: "layernorm"
+input_channels: 3
+patch_embed_dim: 96
+conv_patch_embed_kernel: [3, 7, 7]
+conv_patch_embed_stride: [2, 4, 4]
+conv_patch_embed_padding: [1, 3, 3]
+enable_patch_embed_norm: False
+use_2d_patch: False
+# Attention block config.
+num_heads: 1
+mlp_ratio: 4.0
+qkv_bias: True
+dropout_rate_block: 0.0
+droppath_rate_block: 0.2
+pooling_mode: "conv"
+pool_first: False
+embed_dim_mul: [[1, 2.0], [3, 2.0], [14, 2.0]]
+atten_head_mul: [[1, 2.0], [3, 2.0], [14, 2.0]]
+pool_q_stride_size: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]]
+pool_kv_stride_size: null
+pool_kv_stride_adaptive: [1, 8, 8]
+pool_kvq_kernel: [3, 3, 3]
+# Head config.
+head_dropout_rate: 0.5
+head_activation: null
+head_num_classes: 400
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..05ba9a58f5970c3b70346a5f318ea535fff94c15
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50.yaml
@@ -0,0 +1,7 @@
+_target_: pytorchvideo.models.resnet.create_resnet
+input_channel: 3
+model_depth: 50
+model_num_class: 400
+dropout_rate: 0.5
+stem_conv_kernel_size: [1, 7, 7]
+head_pool_kernel_size: [8, 7, 7]
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_byol.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_byol.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..daeb90bcb785018498916210d706758d2794d746
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_byol.yaml
@@ -0,0 +1,3 @@
+_target_: pytorchvideo_trainer.module.byol.create_byol_resnet_50
+backbone_embed_dim: 128
+mmt: 0.996
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_moco_v2.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_moco_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aef3defe3faab8b93934656ba201a362edd79a8c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_moco_v2.yaml
@@ -0,0 +1,3 @@
+_target_: pytorchvideo_trainer.module.moco_v2.create_moco_resnet_50
+backbone_embed_dim: 128
+mmt: 0.994
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_simclr.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_simclr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4e9593f51ce60b3cbe7c38d6bd42026850e80b7
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_simclr.yaml
@@ -0,0 +1,4 @@
+_target_: pytorchvideo_trainer.module.simclr.create_simclr_resnet_50
+backbone_embed_dim: 128
+mlp_depth: 1
+mlp_inner_dim: 2048
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slowfast_r50.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slowfast_r50.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..08e5335f88b48d2f235b0ed148ef561a9037ae24
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slowfast_r50.yaml
@@ -0,0 +1,6 @@
+_target_: pytorchvideo.models.slowfast.create_slowfast
+input_channels: [3,3]
+model_depth: 50
+model_num_class: 400
+dropout_rate: 0.5
+slowfast_fusion_conv_kernel_size: [7, 1, 1]
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/x3d_xs.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/x3d_xs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c645cda90b6067c29eea31e9cb33e911b3f51a60
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/x3d_xs.yaml
@@ -0,0 +1,8 @@
+_target_: pytorchvideo.models.x3d.create_x3d
+input_channel: 3
+model_num_class: 400
+dropout_rate: 0.5
+input_clip_length: 4
+input_crop_size: 160
+depth_factor: 2.2
+head_activation: null
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adam.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adam.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a9239a738db4eaadbceb7176d8a5e4cd4fc565d3
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adam.yaml
@@ -0,0 +1,3 @@
+method: 'adam'
+lr: 0.001
+weight_decay: 0
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adamw.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adamw.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ef29b4dabdf7c27084b9ced3e3b05cc5409b8d0a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adamw.yaml
@@ -0,0 +1,3 @@
+method: 'adamw'
+lr: 0.001
+weight_decay: 0.01
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..011eaef57d58ef007eb950115068045970f1ddda
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd.yaml
@@ -0,0 +1,5 @@
+method: 'sgd'
+lr: 0.1
+weight_decay: 1e-4
+momentum: 0.9
+nesterov: True
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd_ssl.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd_ssl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a0dc88342243d9c86872a5a1d9ec9ab1a5483f19
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd_ssl.yaml
@@ -0,0 +1,5 @@
+method: 'sgd'
+lr: 0.6
+weight_decay: 1e-6
+momentum: 0.9
+nesterov: True
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/simclr_train_app_conf.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/simclr_train_app_conf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..319d1c3f17c2e3882fd5890166a50f0534c4310a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/simclr_train_app_conf.yaml
@@ -0,0 +1,25 @@
+_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp
+
+defaults:
+ - schema/module: simclr_module_conf
+ - schema/module/optim: optim_conf
+ - schema/datamodule: ptv_video_classification_data_module_conf
+ - datamodule/dataloader: kinetics_contrastive
+ - logger: ptl
+ - datamodule/transforms: kinetics_moco_v2
+ - module/knn_memory: kinetics_k400
+ - module/model: slow_r50_simclr
+ - module/loss: nt_xent
+ - module/optim: sgd_ssl
+ - module/metrics: accuracy
+ - schema/trainer: trainer
+ - trainer: cpu
+ - callbacks: null
+ - _self_
+trainer:
+ sync_batchnorm: false # set this to true for training
+
+hydra:
+ searchpath:
+ - pkg://pytorchvideo_trainer.conf
+ - pkg://torchrecipes.core.conf
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/submitit_conf/fair_cluster.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/submitit_conf/fair_cluster.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87ca47f3f3f9284170e1dd9093eb975bdb14f0d8
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/submitit_conf/fair_cluster.yaml
@@ -0,0 +1,9 @@
+# @package _group_
+log_save_dir: null
+name: "ptv_trainer_job"
+time: "72:00:00"
+cpus_per_task: 10
+partition: "learnlab"
+mem: "470GB"
+constraint: "volta32gb"
+mode: "prod"
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/cpu.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/cpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91fea054da21d7fabc21dfdbac7c79c359b545da
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/cpu.yaml
@@ -0,0 +1,2 @@
+# @package _group_
+max_epochs: 1
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/multi_gpu.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/multi_gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9d406fbdd50ee22e0704adfdd7273e83253b2958
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/multi_gpu.yaml
@@ -0,0 +1,6 @@
+# @package _group_
+gpus: 8
+strategy: ddp
+max_epochs: 1
+num_sanity_val_steps: 0
+log_every_n_steps: 10
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/single_gpu.yaml b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/single_gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..431c0e35794bae877e9ce5bdd32887f7d4e0fa4d
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/single_gpu.yaml
@@ -0,0 +1,3 @@
+# @package _group_
+gpus: 1
+max_epochs: 1
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/__init__.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..729d15104e9d4135c456fdd8d731a6fe74496ec9
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .datamodule import PyTorchVideoDataModule # noqa
+
+
+__all__ = [
+ "PyTorchVideoDataModule",
+]
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/collators.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/collators.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e14275ebbfa9d48b6d4274074b5af88655c4d13
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/collators.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import Any, Callable, Dict, List
+
+from torch.utils.data._utils.collate import default_collate
+
+
+# pyre-ignore[2]
+def multiple_samples_collate(batch: List[Dict[str, List[Any]]]) -> Dict[str, Any]:
+ """
+ Collate function for repeated augmentation. Each instance in the batch has
+ more than one sample.
+
+ To be used when working with,
+ `pytorchvideo_trainer.datamodule.transforms.RepeatandConverttoList`
+ """
+ batch_dict = {}
+ for k in batch[0].keys():
+ v_iter = []
+ for sample_dict in batch:
+ v_iter += sample_dict[k]
+ batch_dict[k] = default_collate(v_iter)
+
+ return batch_dict
+
+
+# pyre-ignore[24]
+_COLLATORS: Dict[str, Callable] = {
+ "multiple_samples_collate": multiple_samples_collate,
+}
+
+
+def build_collator_from_name(name: str) -> Callable: # pyre-ignore[24]
+ """
+ A utility function that returns the function handles to specific collators
+ in `_COLLATORS` dictionary object based on the queried key. Used in
+ `pytorchvideo_trainer.datamodule.PyTorchVideoDataModule`, etc.
+
+ Args:
+ name (str): name of the queried collator. The key should be present in
+ the `_COLLATORS` dictionary object.
+ """
+ assert (
+ name in _COLLATORS
+ ), f"Inavalid Collator method. Available methods are {_COLLATORS.keys()}"
+ return _COLLATORS[name]
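A small, self-contained illustration of what `multiple_samples_collate` does when each dataset item carries repeated-augmentation samples (tensor shapes are arbitrary and only for demonstration):

```python
import torch

from pytorchvideo_trainer.datamodule.collators import build_collator_from_name

collate = build_collator_from_name("multiple_samples_collate")

# Two dataset items, each holding two samples per key, as produced by the
# RepeatandConverttoList transform with repeat_num=2.
batch = [
    {"video": [torch.randn(3, 16, 32, 32), torch.randn(3, 16, 32, 32)], "label": [4, 4]},
    {"video": [torch.randn(3, 16, 32, 32), torch.randn(3, 16, 32, 32)], "label": [7, 7]},
]
out = collate(batch)
print(out["video"].shape)  # torch.Size([4, 3, 16, 32, 32])
print(out["label"])        # tensor([4, 4, 7, 7])
```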
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..b52b5cf18472fffff62ff6995fe5ee5d81608ddd
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py
@@ -0,0 +1,226 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional
+
+import hydra
+import pytorch_lightning as pl
+import pytorchvideo.data
+import torch
+from hydra.core.config_store import ConfigStore
+
+# @manual "fbsource//third-party/pypi/omegaconf:omegaconf"
+from omegaconf import MISSING
+from pytorchvideo_trainer.datamodule.transforms import build_transforms
+from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data.distributed import DistributedSampler
+from torchrecipes.core.conf import DataModuleConf
+from torchrecipes.utils.config_utils import get_class_name_str
+
+
+class PyTorchVideoDataModule(pl.LightningDataModule):
+ """
+ A PyTorch-Lightning DataModule supporting all the dataloaders
+ in PyTorchVideo for the different phases (train, validation and testing) of
+ Lightning training.
+
+ Supports loading any arbitrary iterable or map-style PyTorchVideo dataset
+ that follows the config schema detailed below.
+
+ Args:
+ dataloader (DataLoaderConf):
+ An OmegaConf / Hydra config object consisting of the dataloader
+ config for each phase, i.e., train, val and test.
+
+ The Hydra schema for this config is as defined in
+ `pytorchvideo_trainer.datamodule.datamodule.DataLoaderConf`
+
+ One such example config can be found at
+ `pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml`
+
+ transforms (TransformsConf):
+ An OmegaConf / Hydra config object consisting of the transforms
+ config for each phase, i.e., train, val and test.
+
+ The Hydra schema for this config is as defined in
+ `pytorchvideo_trainer.datamodule.datamodule.TransformsConf`
+
+ One such example config, used for training a ResNet-50 video model, can be found at
+ `pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml`
+ """
+
+ def __init__(
+ self,
+ dataloader: DataLoaderConf,
+ transforms: TransformsConf,
+ ) -> None:
+ super().__init__()
+ self.config: Dict[str, Any] = {
+ "train": dataloader.train,
+ "val": dataloader.val,
+ "test": dataloader.test,
+ }
+ self.transforms: Dict[str, Any] = {
+ "train": build_transforms(transforms.train),
+ "val": build_transforms(transforms.val),
+ "test": build_transforms(transforms.test),
+ }
+ self.datasets: dict[str, Any] = {"train": None, "val": None, "test": None}
+
+ def setup(self, stage: Optional[str] = None) -> None:
+
+ if stage == "fit" or stage is None:
+ self.datasets["train"] = self._get_dataset(
+ phase="train", transforms=self.transforms["train"]
+ )
+ self.datasets["val"] = self._get_dataset(
+ phase="val", transforms=self.transforms["val"]
+ )
+ if stage == "test" or stage is None:
+ self.datasets["test"] = self._get_dataset(
+ phase="test", transforms=self.transforms["test"]
+ )
+
+ def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
+ """
+ Defines the train DataLoader that the PyTorch Lightning Trainer uses.
+ """
+ if (
+ self.trainer
+ and torch.distributed.is_available()
+ and torch.distributed.is_initialized()
+ ):
+ self.datasets["train"].video_sampler.set_epoch(self.trainer.current_epoch)
+
+ return self._get_dataloader("train")
+
+ def val_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
+ """
+ Defines the val DataLoader that the PyTorch Lightning Trainer uses.
+ """
+ return self._get_dataloader("val")
+
+ def test_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
+ """
+ Defines the test DataLoader that the PyTorch Lightning Trainer uses.
+ """
+ return self._get_dataloader("test")
+
+ def _get_dataloader(self, phase: str) -> DataLoader:
+ assert self.datasets[phase] is not None, "Failed to get the {} dataset!".format(
+ phase
+ )
+
+ if isinstance(self.datasets[phase], torch.utils.data.IterableDataset):
+ return torch.utils.data.DataLoader(
+ self.datasets[phase],
+ batch_size=self.config[phase].batch_size,
+ num_workers=self.config[phase].num_workers,
+ pin_memory=self.config[phase].pin_memory,
+ drop_last=self.config[phase].drop_last,
+ collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn),
+ worker_init_fn=hydra.utils.instantiate(
+ self.config[phase].worker_init_fn
+ ),
+ )
+ else:
+ sampler = None
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ logging.info(
+ "Distributed Environmnet detected, using DistributedSampler for dataloader."
+ )
+ sampler = DistributedSampler(self.datasets[phase])
+
+ return torch.utils.data.DataLoader(
+ self.datasets[phase],
+ batch_size=self.config[phase].batch_size,
+ num_workers=self.config[phase].num_workers,
+ pin_memory=self.config[phase].pin_memory,
+ drop_last=self.config[phase].drop_last,
+ sampler=sampler,
+ shuffle=(False if sampler else self.config[phase].shuffle),
+ collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn),
+ worker_init_fn=hydra.utils.instantiate(
+ self.config[phase].worker_init_fn
+ ),
+ )
+
+ def _get_dataset(
+ self,
+ phase: str,
+ transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ ) -> pytorchvideo.data.LabeledVideoDataset:
+
+ video_sampler = RandomSampler
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ logging.info(
+ "Distributed Environmnet detected, using DistributedSampler for dataset."
+ )
+ video_sampler = DistributedSampler
+
+ dataset = hydra.utils.instantiate(
+ self.config[phase].dataset,
+ transform=transforms,
+ video_sampler=video_sampler,
+ )
+ return dataset
+
+
+@dataclass
+class PhaseDataLoaderConf:
+
+ num_workers: int = 0
+ pin_memory: bool = False
+ drop_last: bool = False
+ batch_size: int = MISSING
+ shuffle: bool = True
+
+ # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+ collate_fn: Optional[Any] = None
+ # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+ worker_init_fn: Optional[Any] = None
+
+ ## Dataset Related
+ # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+ dataset: Any = MISSING
+
+
+@dataclass
+class DataLoaderConf:
+ train: PhaseDataLoaderConf = MISSING
+ val: PhaseDataLoaderConf = MISSING
+ test: PhaseDataLoaderConf = MISSING
+
+
+@dataclass
+class TransformsConf:
+
+ # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+ train: List[Any] = MISSING
+
+ # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+ val: List[Any] = MISSING
+
+ # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+ test: List[Any] = MISSING
+
+
+@dataclass
+class VideoClassificationDataModuleConf(DataModuleConf):
+ _target_: str = get_class_name_str(PyTorchVideoDataModule)
+
+ dataloader: DataLoaderConf = MISSING
+ transforms: TransformsConf = MISSING
+
+
+cs = ConfigStore()
+
+cs.store(
+ group="schema/datamodule",
+ name="ptv_video_classification_data_module_conf",
+ node=VideoClassificationDataModuleConf,
+ package="datamodule",
+)
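
A minimal sketch of composing the per-phase dataloader schema outside of a full training run, assuming the dataclasses above are importable as `pytorchvideo_trainer.datamodule.datamodule`; the batch size and worker count are illustrative.

```python
from omegaconf import OmegaConf

# Assumed import path, based on the package layout in this diff.
from pytorchvideo_trainer.datamodule.datamodule import (
    DataLoaderConf,
    PhaseDataLoaderConf,
)

# `dataset` stays MISSING ("???") until a concrete dataset config is composed
# in by Hydra (e.g. one of the LabeledVideoDataset targets the datamodule uses).
phase = PhaseDataLoaderConf(batch_size=8, num_workers=4)
cfg = OmegaConf.structured(DataLoaderConf(train=phase, val=phase, test=phase))
print(OmegaConf.to_yaml(cfg))  # prints the per-phase YAML layout Hydra expects
```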
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/rand_erase_transform.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/rand_erase_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecae54fd8fa8b4cd7c7c66323c75c2eec383ad6c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/rand_erase_transform.py
@@ -0,0 +1,196 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""
+This implementation is based on
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
+published under an Apache License 2.0.
+COMMENT FROM ORIGINAL:
+Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
+Copyright Zhun Zhong & Liang Zheng
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import random
+from typing import Optional, Tuple
+
+import torch
+
+
+def _get_pixels(
+ per_pixel: bool,
+ rand_color: bool,
+ patch_size: Tuple[int],
+ dtype: torch.dtype = torch.float32,
+ device: str = "cuda",
+) -> torch.Tensor:
+ """
+ A utility function that generates image patches for RandomErasing transform
+ """
+ if per_pixel:
+ return torch.empty(patch_size, dtype=dtype, device=device).normal_()
+ elif rand_color:
+ return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
+ else:
+ return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
+
+
+class RandomErasing:
+ """
+    This variant of RandomErasing is intended to be applied to a video tensor, i.e.,
+    a batch of images, after it has been normalized by the dataset mean and std.
+
+ Randomly selects a rectangle region in an image and erases its pixels.
+ 'Random Erasing Data Augmentation' by Zhong et al.
+ See https://arxiv.org/pdf/1708.04896.pdf
+
+ Args:
+ probability (float): Probability that the Random Erasing operation will be performed.
+ min_area (float): Minimum percentage of erased area wrt input image area.
+ max_area (float): Maximum percentage of erased area wrt input image area.
+        min_aspect (float): Minimum aspect ratio of erased area.
+        max_aspect (float): Maximum aspect ratio of erased area. Defaults to the
+            reciprocal of min_aspect when not provided.
+        mode (str): pixel color mode, one of 'const', 'rand', or 'pixel'
+            'const' - erase block is constant color of 0 for all channels
+            'rand' - erase block is same per-channel random (normal) color
+            'pixel' - erase block is per-pixel random (normal) color
+        max_count (int): maximum number of erasing blocks per image, area per box is scaled by
+            count. Per-image count is randomly chosen between 1 and this value.
+        min_count (int): minimum number of erasing blocks per image, area per box is scaled by
+            count. Per-image count is randomly chosen between 1 and this value.
+        num_splits (int): if greater than 1, the first batch_size // num_splits frames
+            of each clip are kept unerased (the "clean" portion of the samples).
+        device (str): Device to perform the transform on.
+        cube (bool): if True, erases the same spatial region across all selected frames
+            of the clip rather than sampling a new region per frame.
+ """
+
+ def __init__(
+ self,
+ probability: float = 0.5,
+ min_area: float = 0.02,
+ max_area: float = 1 / 3,
+ min_aspect: float = 0.3,
+ max_aspect: Optional[float] = None,
+ mode: str = "const",
+ min_count: int = 1,
+ max_count: Optional[int] = None,
+ num_splits: int = 0,
+ device: str = "cuda",
+ cube: bool = True,
+ ) -> None:
+ self.probability = probability
+ self.min_area = min_area
+ self.max_area = max_area
+ max_aspect = max_aspect or 1 / min_aspect
+ self.log_aspect_ratio: Tuple[float, float] = (
+ math.log(min_aspect),
+ math.log(max_aspect),
+ )
+ self.min_count = min_count
+ self.max_count: int = max_count or min_count
+ self.num_splits = num_splits
+ mode = mode.lower()
+ self.rand_color: bool = False
+ self.per_pixel: bool = False
+ self.cube = cube
+ if mode == "rand":
+ self.rand_color = True # per block random normal
+ elif mode == "pixel":
+ self.per_pixel = True # per pixel random normal
+ else:
+ assert not mode or mode == "const"
+ self.device = device
+
+ def _erase(
+ self, img: torch.Tensor, chan: int, height: int, width: int, dtype: torch.dtype
+ ) -> None:
+ if random.random() > self.probability:
+ return
+ area = height * width
+ count = (
+ self.min_count
+ if self.min_count == self.max_count
+ else random.randint(self.min_count, self.max_count)
+ )
+ for _ in range(count):
+ for _ in range(10):
+ target_area = (
+ random.uniform(self.min_area, self.max_area) * area / count
+ )
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < width and h < height:
+ top = random.randint(0, height - h)
+ left = random.randint(0, width - w)
+ img[:, top : top + h, left : left + w] = _get_pixels(
+ self.per_pixel,
+ self.rand_color,
+ (chan, h, w), # pyre-ignore[6]
+ dtype=dtype,
+ device=self.device,
+ )
+ break
+
+ def _erase_cube(
+ self,
+ video: torch.Tensor,
+ batch_start: int,
+ batch_size: int,
+ chan: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ ) -> None:
+ if random.random() > self.probability:
+ return
+ area = height * width
+ count = (
+ self.min_count
+ if self.min_count == self.max_count
+ else random.randint(self.min_count, self.max_count)
+ )
+ for _ in range(count):
+ for _ in range(100):
+ target_area = (
+ random.uniform(self.min_area, self.max_area) * area / count
+ )
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < width and h < height:
+ top = random.randint(0, height - h)
+ left = random.randint(0, width - w)
+ for i in range(batch_start, batch_size):
+ img_instance = video[i]
+ img_instance[:, top : top + h, left : left + w] = _get_pixels(
+ self.per_pixel,
+ self.rand_color,
+ (chan, h, w), # pyre-ignore[6]
+ dtype=dtype,
+ device=self.device,
+ )
+ break
+
+ def __call__(self, frames: torch.Tensor) -> torch.Tensor:
+ """
+        Args:
+            frames (tensor): frames of images sampled from the video. The
+                dimension is `num frames` x `channel` x `height` x `width`.
+        Returns:
+            frames (tensor): frames of images sampled from the video. The
+                dimension is `num frames` x `channel` x `height` x `width`.
+        """
+        # Expects frames of shape T, C, H, W
+ batch_size, chan, height, width = frames.size()
+ # skip first slice of batch if num_splits is set (for clean portion of samples)
+ batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
+ if self.cube:
+ self._erase_cube(
+ frames,
+ batch_start,
+ batch_size,
+ chan,
+ height,
+ width,
+ frames.dtype,
+ )
+ else:
+ for i in range(batch_start, batch_size):
+ self._erase(frames[i], chan, height, width, frames.dtype)
+ return frames
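
A quick sanity-check sketch for the transform above: it erases one region per clip on CPU, with all shapes and arguments chosen for illustration.

```python
import torch

from pytorchvideo_trainer.datamodule.rand_erase_transform import RandomErasing

# probability=1.0 so the erase always fires; "pixel" fills the region with
# per-pixel normal noise, and cube=True reuses the same region for every frame.
erase = RandomErasing(probability=1.0, mode="pixel", device="cpu", cube=True)

frames = torch.randn(8, 3, 224, 224)  # T x C x H x W, already normalized
erased = erase(frames)                # same shape, with one rectangle overwritten
print(erased.shape)
```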
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/transforms.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f6969bd19f565c0584e3e4dd14384444caa090e
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/transforms.py
@@ -0,0 +1,287 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import random
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence
+
+import hydra
+import torch
+import torchvision
+from PIL import Image, ImageFilter
+from pytorchvideo.transforms import MixVideo
+from torchvision.transforms import Compose
+
+
+def build_transforms(transforms_config: Iterable[Mapping[str, Any]]) -> Compose:
+ """
+    A utility function to build data transforms from a list of Hydra/Omega Conf
+    objects. This utility method is called by the
+    `pytorchvideo_trainer.datamodule.PyTorchVideoDataModule` class to build the
+    sequence of transforms applied during each phase (train, val and test).
+
+    Uses torchvision.transforms.Compose to build a sequence of transforms.
+
+    Examples of config objects used by this method can be found in
+    `pytorchvideo_trainer/conf/datamodule/transforms/`
+
+    Args:
+        transforms_config: A list of hydra config objects wherein each element
+            represents the config associated with a single transform.
+
+ An example of this would be,
+ ```
+ - _target_: pytorchvideo.transforms.ApplyTransformToKey
+ transform:
+ - _target_: pytorchvideo.transforms.UniformTemporalSubsample
+ num_samples: 16
+ - _target_: pytorchvideo.transforms.Div255
+ - _target_: pytorchvideo.transforms.Normalize
+ mean: [0.45, 0.45, 0.45]
+ std: [0.225, 0.225, 0.225]
+ - _target_: pytorchvideo.transforms.ShortSideScale
+ size: 224
+ key: video
+ - _target_: pytorchvideo.transforms.UniformCropVideo
+ size: 224
+ - _target_: pytorchvideo.transforms.RemoveKey
+ key: audio
+ ```
+ """
+ transform_list = [build_single_transform(config) for config in transforms_config]
+ transform = Compose(transform_list)
+ return transform
+
+
+def build_single_transform(config: Mapping[str, Any]) -> Callable[..., object]:
+ """
+ A utility method to build a single transform from hydra / omega conf objects.
+
+    If the key "transform" is present in the given config, it recursively builds
+ and composes transforms using the `torchvision.transforms.Compose` method.
+ """
+ config = dict(config)
+ if "transform" in config:
+ assert isinstance(config["transform"], Sequence)
+ transform_list = [
+ build_single_transform(transform) for transform in config["transform"]
+ ]
+ transform = Compose(transform_list)
+ config.pop("transform")
+ return hydra.utils.instantiate(config, transform=transform)
+ return hydra.utils.instantiate(config)
+
+
+class ApplyTransformToKeyOnList:
+ """
+ Applies transform to key of dictionary input wherein input is a list
+
+ Args:
+ key (str): the dictionary key the transform is applied to
+ transform (callable): the transform that is applied
+
+ Example:
+ >>> transforms.ApplyTransformToKeyOnList(
+ >>> key='video',
+ >>> transform=UniformTemporalSubsample(num_video_samples),
+ >>> )
+ """
+
+ def __init__(self, key: str, transform: Callable) -> None: # pyre-ignore[24]
+ self._key = key
+ self._transform = transform
+
+ def __call__(
+ self, x: Dict[str, List[torch.Tensor]]
+ ) -> Dict[str, List[torch.Tensor]]:
+ x[self._key] = [self._transform(a) for a in x[self._key]]
+ return x
+
+
+class SlowFastPackPathway:
+ """
+    Transform for converting a video clip into a list of 2 clips with
+    different temporal granularity as needed by the SlowFast video
+    model.
+
+    For more details, refer to the paper,
+ Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
+ "SlowFast networks for video recognition."
+ https://arxiv.org/pdf/1812.03982.pdf
+
+ Args:
+ alpha (int): Number of frames to sub-sample from the given clip
+ to create the second clip.
+ """
+
+ def __init__(self, alpha: int) -> None:
+ super().__init__()
+ self.alpha = alpha
+
+ def __call__(self, frames: torch.Tensor) -> List[torch.Tensor]:
+ """
+ Args:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `channel` x `num frames` x `height` x `width`.
+ Returns:
+ frame_list (list): list of tensors with the dimension of
+ `channel` x `num frames` x `height` x `width`.
+ """
+ fast_pathway = frames
+ # Perform temporal sampling from the fast pathway.
+ slow_pathway = torch.index_select(
+ frames,
+ 1,
+ torch.linspace(
+ 0, frames.shape[1] - 1, frames.shape[1] // self.alpha
+ ).long(),
+ )
+ frame_list = [slow_pathway, fast_pathway]
+ return frame_list
+
+
+class RepeatandConverttoList:
+ """
+    A utility transform that repeats each value in a
+    key, value-style minibatch and replaces it with a list of values.
+
+    Useful for performing multiple augmentations.
+    An example of such a use case can be found in
+    `pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml`
+
+    Args:
+        repeat_num (int): Number of times to repeat each value.
+ """
+
+ def __init__(self, repeat_num: int) -> None:
+ super().__init__()
+ self.repeat_num = repeat_num
+
+ # pyre-ignore[3]
+ def __call__(self, sample_dict: Dict[str, Any]) -> Dict[str, List[Any]]:
+ for k, v in sample_dict.items():
+ sample_dict[k] = self.repeat_num * [v]
+ return sample_dict
+
+
+class MixVideoBatchWrapper:
+ def __init__(
+ self,
+ mixup_alpha: float,
+ cutmix_prob: float,
+ cutmix_alpha: float,
+ label_smoothing: float,
+ ) -> None:
+ """
+        A wrapper for the MixVideo (CutMix or Mixup) transform in pytorchvideo.transforms.
+        Extends the MixVideo transform to work on batch dictionary objects.
+
+ The dictionary object should consist of keys "video" and "label" representing
+ video clips and their associated labels.
+ """
+
+ self.mix_video_transform = MixVideo(
+ mixup_alpha=mixup_alpha,
+ cutmix_prob=cutmix_prob,
+ cutmix_alpha=cutmix_alpha,
+ label_smoothing=label_smoothing,
+ )
+
+ def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+
+ batch["video"], batch["label"] = self.mix_video_transform(
+ batch["video"], batch["label"]
+ )
+ return batch
+
+
+class ColorJitterVideoSSl:
+ """
+ A custom sequence of transforms that randomly performs Color jitter,
+ Gaussian Blur and Grayscaling on the given clip.
+
+ Particularly useful for the SSL tasks like SimCLR, MoCoV2, BYOL, etc.
+
+ Args:
+        bri_con_sat (list[float]): A list of 3 floats representing brightness,
+            contrast and saturation coefficients to use for the
+            `torchvision.transforms.ColorJitter` transform.
+        hue (float): Hue value to use in the `torchvision.transforms.ColorJitter`
+            transform.
+        p_color_jitter (float): The probability with which the Color jitter transform
+            is randomly applied on the given clip.
+        p_convert_gray (float): The probability with which the given clip is randomly
+            converted into grayscale.
+        p_gaussian_blur (float): The probability with which the Gaussian transform
+            is randomly applied on the given clip.
+        gaussian_blur_sigma (list[float]): A list of 2 floats within which
+            the blur radius is randomly sampled for the Gaussian blur transform.
+ """
+
+ def __init__(
+ self,
+ bri_con_sat: List[float],
+ hue: float,
+ p_color_jitter: float,
+ p_convert_gray: float,
+ p_gaussian_blur: float = 0.5,
+ gaussian_blur_sigma: List[float] = (0.1, 2.0),
+ ) -> None:
+
+ self.color_jitter = torchvision.transforms.Compose(
+ [
+ torchvision.transforms.ToPILImage(),
+ torchvision.transforms.RandomApply(
+ [
+ torchvision.transforms.ColorJitter(
+ bri_con_sat[0], bri_con_sat[1], bri_con_sat[2], hue
+ )
+ ],
+ p=p_color_jitter,
+ ),
+ torchvision.transforms.RandomGrayscale(p=p_convert_gray),
+ torchvision.transforms.RandomApply(
+ [GaussianBlur(gaussian_blur_sigma)], p=p_gaussian_blur
+ ),
+ torchvision.transforms.ToTensor(),
+ ]
+ )
+
+ def __call__(self, frames: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `channel` x `num frames` x `height` x `width`.
+ Returns:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `channel` x `num frames` x `height` x `width`.
+ """
+ c, t, h, w = frames.shape
+ frames = frames.view(c, t * h, w)
+ frames = self.color_jitter(frames) # pyre-ignore[6,9]
+ frames = frames.view(c, t, h, w)
+
+ return frames
+
+
+class GaussianBlur(object):
+ """
+ A PIL image version of Gaussian blur augmentation as
+ in SimCLR https://arxiv.org/abs/2002.05709
+
+ Args:
+        sigma (list[float]): A list of 2 floats within which
+            the blur radius is randomly sampled during each step.
+ """
+
+ def __init__(self, sigma: List[float] = (0.1, 2.0)) -> None:
+ self.sigma = sigma
+
+ def __call__(self, img: Image.Image) -> Image.Image:
+ """
+ img (Image): A PIL image with single or 3 color channels.
+ """
+ sigma = self.sigma[0]
+ if len(self.sigma) == 2:
+ sigma = random.uniform(self.sigma[0], self.sigma[1])
+
+ img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
+ return img
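
The builder above can also be driven from an in-memory config rather than the YAML files under `conf/datamodule/transforms/`; a minimal sketch, with the particular transforms chosen only for illustration.

```python
from omegaconf import OmegaConf

from pytorchvideo_trainer.datamodule.transforms import build_transforms

cfg = OmegaConf.create(
    [
        {
            "_target_": "pytorchvideo.transforms.ApplyTransformToKey",
            "key": "video",
            "transform": [
                {"_target_": "pytorchvideo.transforms.UniformTemporalSubsample", "num_samples": 8},
                {"_target_": "pytorchvideo.transforms.Div255"},
                {"_target_": "pytorchvideo.transforms.ShortSideScale", "size": 224},
            ],
        },
        {"_target_": "pytorchvideo.transforms.RemoveKey", "key": "audio"},
    ]
)
train_transform = build_transforms(cfg)  # a torchvision Compose over the listed transforms
```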
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/__init__.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d9a460caa9753a72236d6014047b91ae1d2f305
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from .byol import BYOLModule # noqa
+from .moco_v2 import MOCOV2Module # noqa
+from .simclr import SimCLRModule # noqa
+from .video_classification import VideoClassificationModule # noqa
+
+
+__all__ = [
+ "VideoClassificationModule",
+ "SimCLRModule",
+ "BYOLModule",
+ "MOCOV2Module",
+]
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d726fdf389acd0eacef537ba0e80c51ce0083b
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py
@@ -0,0 +1,329 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from dataclasses import dataclass
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING
+from pytorchvideo.models.resnet import create_resnet
+from pytorchvideo.models.weight_init import init_net_weights
+from pytorchvideo_trainer.module.ssl_helper import create_mlp_util, SSLBaseModule
+from pytorchvideo_trainer.module.video_classification import (
+ Batch,
+ BatchKey,
+ EnsembleMethod,
+)
+from torchrecipes.core.conf import ModuleConf
+from torchrecipes.utils.config_utils import get_class_name_str
+
+
+class BYOL(nn.Module):
+ """
+ Bootstrap Your Own Latent A New Approach to Self-Supervised Learning
+ Details can be found in:
+ https://arxiv.org/pdf/2006.07733.pdf
+ """
+
+ def __init__(
+ self,
+ mmt: float,
+ backbone: nn.Module,
+ predictor: nn.Module,
+ backbone_mmt: nn.Module,
+ projector: Optional[nn.Module] = None,
+ projector_mmt: Optional[nn.Module] = None,
+ ) -> None:
+ """
+ Args:
+ backbone (nn.Module): backbone for byol, input shape depends on the forward
+ input size. Standard inputs include `B x C`, `B x C x H x W`, and
+ `B x C x T x H x W`.
+ projector (nn.Module): An mlp with 2 to 3 hidden layers,
+ with (synchronized) BatchNorm and ReLU activation.
+ backbone_mmt (nn.Module): backbone for byol, input shape depends on the forward
+ input size. Standard inputs include `B x C`, `B x C x H x W`, and
+ `B x C x T x H x W`.
+            projector_mmt (nn.Module): An mlp with 2 to 3 hidden layers,
+ with (synchronized) BatchNorm and ReLU activation.
+ predictor (nn.Module): predictor MLP of BYOL of similar structure as the
+ projector MLP.
+ mmt (float): momentum update ratio for the momentum backbone.
+ """
+ super().__init__()
+
+ self.mmt: float = mmt
+ if projector is not None:
+ backbone = nn.Sequential(
+ backbone,
+ projector,
+ )
+ init_net_weights(backbone)
+ self.backbone = backbone
+
+ if projector_mmt is not None:
+ backbone_mmt = nn.Sequential(
+ backbone_mmt,
+ projector_mmt,
+ )
+ init_net_weights(backbone_mmt)
+ self.backbone_mmt = backbone_mmt
+
+ for p in self.backbone_mmt.parameters():
+ p.requires_grad = False
+
+ init_net_weights(predictor)
+ self.predictor = predictor
+
+ self._copy_weights_to_backbone_mmt()
+
+ def _copy_weights_to_backbone_mmt(self) -> None:
+ dist = {}
+ for name, p in self.backbone.named_parameters():
+ dist[name] = p
+ for name, p in self.backbone_mmt.named_parameters():
+ p.data.copy_(dist[name].data)
+
+ @torch.no_grad()
+ def momentum_update_backbone(self) -> None:
+ """
+ Momentum update on the backbone.
+ """
+ m = self.mmt
+ dist = {}
+ for name, p in self.backbone.named_parameters():
+ dist[name] = p
+ for name, p in self.backbone_mmt.named_parameters():
+ # pyre-ignore[41]
+ p.data = dist[name].data * (1.0 - m) + p.data * m
+
+ @torch.no_grad()
+ def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward momentum backbone.
+ Args:
+ x (tensor): input to be forwarded of shape N x C x T x H x W
+ """
+ with torch.no_grad():
+ proj = self.backbone_mmt(x)
+ return F.normalize(proj, dim=1)
+
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+ """
+ Args:
+ x (tensor): input to be forwarded of shape N x C x T x H x W
+ """
+ if not self.training:
+ x = self.backbone(x)
+ x = F.normalize(x, dim=1)
+ return x
+
+ proj = self.backbone(x)
+ pred = self.predictor(proj)
+ pred = F.normalize(pred, dim=1)
+
+ out_proj = F.normalize(proj, dim=1)
+
+ return out_proj, pred # pyre-ignore[7]
+
+
+def create_byol_resnet_50(
+ # Backbone
+ backbone_creator: Callable = create_resnet, # pyre-ignore[24]
+ backbone_embed_dim: int = 128,
+ head_pool: Callable = nn.AdaptiveAvgPool3d, # pyre-ignore[24]
+ head_output_size: Tuple[int, int, int] = (1, 1, 1),
+ head_activation: Callable = None, # pyre-ignore[9,24]
+ dropout_rate: float = 0.0,
+ # Projector
+ projector_dim_in: int = 2048,
+ projector_inner_dim: int = 4096,
+ projector_depth: int = 2,
+ # Predictor
+ predictor_inner_dim: int = 4096,
+ predictor_depth: int = 2,
+ predictor_norm: Callable = nn.BatchNorm1d, # pyre-ignore[24]
+ projector_norm: Callable = nn.BatchNorm1d, # pyre-ignore[24]
+ mmt: float = 0.99,
+) -> BYOL:
+ """
+    Builds a ResNet video backbone, projector and predictor models for
+    the BYOL SSL task.
+ """
+
+ def _make_bacbone_and_projector(): # pyre-ignore[3]
+ backbone = backbone_creator(
+ dropout_rate=dropout_rate,
+ head_activation=head_activation,
+ head_output_with_global_average=True,
+ head_pool=head_pool,
+ head_output_size=head_output_size,
+ )
+
+        backbone.blocks[-1].proj = None  # Overwrite head projection
+ projector = create_mlp_util(
+ projector_dim_in,
+ backbone_embed_dim,
+ projector_inner_dim,
+ projector_depth,
+ norm=projector_norm,
+ )
+ return backbone, projector
+
+ backbone, projector = _make_bacbone_and_projector()
+ backbone_mmt, projector_mmt = _make_bacbone_and_projector()
+
+ predictor = create_mlp_util(
+ backbone_embed_dim,
+ backbone_embed_dim,
+ predictor_inner_dim,
+ predictor_depth,
+ norm=predictor_norm,
+ )
+ byol_model = BYOL(
+ mmt=mmt,
+ backbone=backbone,
+ projector=projector,
+ predictor=predictor,
+ backbone_mmt=backbone_mmt,
+ projector_mmt=projector_mmt,
+ )
+ return byol_model
+
+
+class BYOLModule(SSLBaseModule):
+ """
+ The Lightning Base module for BYOL SSL video task.
+
+ For more details refer to,
+ 1. Bootstrap your own latent: A new approach to self-supervised Learning:
+ https://arxiv.org/abs/2006.07733
+ 2. A Large-Scale Study on Unsupervised Spatiotemporal Representation Learning
+
+ Args:
+        model (OmegaConf): An omega conf object initializing the neural-network model.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/model`
+        loss (OmegaConf): An omega conf object initializing the loss function.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/loss`
+        optim (OmegaConf): An omega conf object for constructing the optimizer object.
+            The associated config schema can be found at
+            `pytorchvideo_trainer.module.optimizer.OptimizerConf`.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/optim`
+        metrics (OmegaConf): The metrics to track, which will be used for train,
+            validation and test. Example configs can be found in
+            `pytorchvideo_trainer/conf/module/metrics`
+        lr_scheduler (OmegaConf): An omega conf object associated with the learning rate
+            scheduler used during training.
+            The associated config schema can be found at
+            `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler`
+        modality_key (str): The modality key used in data processing, default: "video".
+        ensemble_method (str): The data ensembling method to control how we accumulate
+            the testing results at video level, which is optional. Users may choose from
+            ["sum", "max", None]. If it is set to None, no data ensembling will be applied.
+        knn_memory (OmegaConf): An optional hydra / omega conf; if set, initializes the KNN
+            Memory module to use. An example config can be found at
+            `pytorchvideo_trainer/conf/module/knn_memory`.
+        momentum_anneal_cosine (bool): For MoCo and BYOL tasks, if set to true, cosine
+            anneals the momentum term used for updating the backbone-history model.
+        num_sync_devices (int): Number of gpus to sync batchnorm over. Only works if
+            pytorch lightning trainer's sync_batchnorm parameter is set to false.
+ """
+
+ def __init__(
+ self,
+ model: Any, # pyre-ignore[2]
+ loss: Any, # pyre-ignore[2]
+ optim: Any, # pyre-ignore[2]
+ metrics: List[Any], # pyre-ignore[2]
+ lr_scheduler: Optional[Any] = None, # pyre-ignore[2]
+ modality_key: BatchKey = "video",
+ ensemble_method: Optional[EnsembleMethod] = None,
+ knn_memory: Optional[Any] = None, # pyre-ignore[2]
+ momentum_anneal_cosine: bool = False,
+ num_sync_devices: int = 1,
+ ) -> None:
+ super().__init__(
+ model=model,
+ loss=loss,
+ optim=optim,
+ metrics=metrics,
+ lr_scheduler=lr_scheduler,
+ modality_key=modality_key,
+ ensemble_method=ensemble_method,
+ knn_memory=knn_memory,
+ momentum_anneal_cosine=momentum_anneal_cosine,
+ num_sync_devices=num_sync_devices,
+ )
+
+ def training_step(
+ self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
+ ) -> None:
+ self.cur_epoch_step += 1 # pyre-ignore[16]
+
+ if self.momentum_anneal_cosine:
+ self._cosine_anneal_momentum()
+
+ self.manual_zero_opt_grad()
+ self.manual_update_lr()
+
+ inputs = batch[self.modality_key] # pyre-ignore[6]
+
+ self.model.momentum_update_backbone() # pyre-ignore[29]
+ keys = self._compute_keys(inputs)
+
+ partial_loss = 0.0
+ for k, vids in enumerate(inputs):
+ other_keys = keys[:k] + keys[k + 1 :]
+ assert len(other_keys) > 0, "Length of keys cannot be zero"
+
+ proj, pred = self.model(vids)
+ loss_k = self.loss(pred, other_keys[0])
+ for i in range(1, len(other_keys)):
+ loss_k += self.loss(pred, other_keys[i])
+ loss_k /= len(other_keys)
+
+ self.manual_backward(loss_k)
+ partial_loss += loss_k.detach()
+
+ if self.knn_memory is not None:
+ self.knn_memory.update(proj, batch["video_index"]) # pyre-ignore[29,61]
+
+ partial_loss /= len(inputs) * 2.0 # to have same loss as symmetric loss
+ self.log("Losses/train_loss", partial_loss, on_step=True, on_epoch=True)
+
+ self.manual_opt_step()
+
+ @torch.no_grad()
+ def _compute_keys(self, x: torch.Tensor) -> List[torch.Tensor]:
+ keys = []
+ for sub_x in x:
+ # pyre-ignore[29]
+ keys.append(self.model.forward_backbone_mmt(sub_x).detach())
+ return keys
+
+
+@dataclass
+class BYOLModuleConf(ModuleConf):
+ _target_: str = get_class_name_str(BYOLModule)
+ model: Any = MISSING # pyre-ignore[4]
+ loss: Any = MISSING # pyre-ignore[4]
+ optim: Any = MISSING # pyre-ignore[4]
+ metrics: List[Any] = MISSING # pyre-ignore[4]
+ lr_scheduler: Optional[Any] = None # pyre-ignore[4]
+ modality_key: str = "video"
+ ensemble_method: Optional[str] = None
+ num_sync_devices: Optional[int] = 1
+ knn_memory: Optional[Any] = None # pyre-ignore[4]
+ momentum_anneal_cosine: bool = False
+
+
+cs = ConfigStore()
+cs.store(
+ group="schema/module",
+ name="byol_module_conf",
+ node=BYOLModuleConf,
+ package="module",
+)
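
A minimal sketch of building the default BYOL ResNet-50 above and running it in both modes; the batch and clip sizes are illustrative.

```python
import torch

from pytorchvideo_trainer.module.byol import create_byol_resnet_50

model = create_byol_resnet_50()        # ResNet-50 backbone plus projector/predictor MLPs
clip = torch.randn(2, 3, 8, 224, 224)  # N x C x T x H x W

model.train()
proj, pred = model(clip)               # training mode: (projection, prediction), both L2-normalized

model.eval()
with torch.no_grad():
    emb = model(clip)                  # eval mode: normalized backbone+projector embedding

print(proj.shape, pred.shape, emb.shape)  # each N x 128 with the default embedding dim
```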
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/distributed_utils.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/distributed_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..78549a15a7410fa1dd5ec82a5f232be00e7d302c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/distributed_utils.py
@@ -0,0 +1,331 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""Distributed helpers."""
+import functools
+import logging
+import pickle
+from typing import Any, List, Optional, Tuple, TypeVar
+
+import torch
+import torch.distributed as dist
+
+
+DistProcessGroup = TypeVar("ProcessGroup")
+
+
+def all_gather(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
+ """
+ All gathers the provided tensors from all processes across machines.
+
+ Args:
+ tensors (list): tensors to perform all gather across all processes in
+ all machines.
+ """
+
+ gather_list = []
+ output_tensor = []
+ world_size = dist.get_world_size()
+ for tensor in tensors:
+ tensor_placeholder = [torch.ones_like(tensor) for _ in range(world_size)]
+ dist.all_gather(tensor_placeholder, tensor, async_op=False)
+ gather_list.append(tensor_placeholder)
+ for gathered_tensor in gather_list:
+ output_tensor.append(torch.cat(gathered_tensor, dim=0))
+ return output_tensor
+
+
+def cat_all_gather(
+ tensors: torch.Tensor, process_group: Optional[DistProcessGroup] = None
+) -> torch.Tensor:
+ """
+    Performs the concatenated all_gather operation on the provided tensors.
+ """
+ if process_group is not None:
+ gather_sz = get_process_group_size(process_group)
+ else:
+ gather_sz = dist.get_world_size()
+ tensors_gather = [torch.ones_like(tensors) for _ in range(gather_sz)]
+ dist.all_gather(
+ tensors_gather,
+ tensors,
+ async_op=False,
+ group=process_group,
+ )
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+def all_reduce(tensors: List[torch.Tensor], average: bool = True) -> List[torch.Tensor]:
+ """
+ All reduce the provided tensors from all processes across machines.
+
+ Args:
+ tensors (list): tensors to perform all reduce across all processes in
+ all machines.
+ average (bool): scales the reduced tensor by the number of overall
+ processes across all machines.
+ """
+
+ for tensor in tensors:
+ dist.all_reduce(tensor, async_op=False)
+ if average:
+ world_size = dist.get_world_size()
+ for tensor in tensors:
+ tensor.mul_(1.0 / world_size)
+ return tensors
+
+
+def init_process_group(
+ local_rank: int,
+ local_world_size: int,
+ shard_id: int,
+ num_shards: int,
+ init_method: str,
+ dist_backend: str = "nccl",
+) -> None:
+ """
+ Initializes the default process group.
+
+ Args:
+ local_rank (int): the rank on the current local machine.
+ local_world_size (int): the world size (number of processes running) on
+ the current local machine.
+ shard_id (int): the shard index (machine rank) of the current machine.
+ num_shards (int): number of shards for distributed training.
+ init_method (string): supporting three different methods for
+ initializing process groups:
+ "file": use shared file system to initialize the groups across
+ different processes.
+            "tcp": use tcp address to initialize the groups across different
+                processes.
+        dist_backend (string): backend to use for distributed training. Options
+ includes gloo, mpi and nccl, the details can be found here:
+ https://pytorch.org/docs/stable/distributed.html
+ """
+ # Sets the GPU to use.
+ torch.cuda.set_device(local_rank)
+ # Initialize the process group.
+ proc_rank = local_rank + shard_id * local_world_size
+ world_size = local_world_size * num_shards
+ dist.init_process_group(
+ backend=dist_backend,
+ init_method=init_method,
+ world_size=world_size,
+ rank=proc_rank,
+ )
+
+
+def get_world_size() -> int:
+ """
+ Get the size of the world.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ """
+ Get the rank of the current process.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def synchronize() -> None:
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group() -> List[int]:
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+
+ Returns:
+ (group): pytorch dist group.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+# pyre-ignore [2]
+def _serialize_to_tensor(data: Any, group: List[int]) -> torch.Tensor:
+ """
+    Serialize arbitrary picklable data to a ByteTensor. Note that only the `gloo`
+    and `nccl` backends are supported.
+
+    Args:
+        data (data): data to be serialized.
+        group (group): pytorch dist group.
+    Returns:
+        tensor (ByteTensor): the serialized tensor.
+ """
+
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024**3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024**3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(
+ tensor: torch.Tensor, group: List[int]
+) -> Tuple[List[int], torch.Tensor]:
+ """
+    Pads all the tensors from different GPUs to the size of the largest one.
+
+ Args:
+ tensor (tensor): tensor to pad.
+ group (group): pytorch dist group.
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device)
+ for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros(
+ (max_size - local_size,), dtype=torch.uint8, device=tensor.device
+ )
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+# pyre-ignore [2,3]
+def all_gather_unaligned(data: Any, group: Optional[List[int]] = None) -> List[Any]:
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+ for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def get_process_group_size(process_group: DistProcessGroup) -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=process_group)
+
+
+def get_local_rank(process_group: DistProcessGroup) -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+
+ return dist.get_rank(group=process_group)
+
+
+class AllGatherWithGradient(torch.autograd.Function):
+ """
+ Support distributed all_gather for any arbitrary tensor while
+ preserving its gradient.
+ """
+
+ @staticmethod
+ # pyre-ignore [2,14]
+ def forward(ctx: Any, input: torch.Tensor) -> torch.Tensor:
+ world_size = dist.get_world_size()
+ x_gather = [torch.ones_like(input) for _ in range(world_size)]
+ dist.all_gather(x_gather, input, async_op=False)
+ x_gather = torch.cat(x_gather, dim=0)
+ return x_gather
+
+ @staticmethod
+ # pyre-ignore [2,14]
+ def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
+
+ reduction = dist.all_reduce(grad_output, async_op=True)
+ reduction.wait()
+
+ world_size = dist.get_world_size()
+ N = grad_output.size(0)
+ mini_batchsize = N // world_size
+ cur_gpu = dist.get_rank()
+ grad_output = grad_output[
+ cur_gpu * mini_batchsize : (cur_gpu + 1) * mini_batchsize
+ ]
+ return grad_output
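
A minimal usage sketch for the helpers above, assuming torch.distributed has already been initialized (for example by the Lightning trainer) with a CUDA-capable backend; tensor sizes are illustrative.

```python
import torch

import pytorchvideo_trainer.module.distributed_utils as du

embeddings = torch.randn(4, 128, device="cuda", requires_grad=True)

# Differentiable all_gather: the backward pass routes gradients back to the local shard.
gathered = du.AllGatherWithGradient.apply(embeddings)  # (world_size * 4, 128)
gathered.sum().backward()

# Gather arbitrary picklable objects; a gloo group is used under the hood when on nccl.
per_rank_meta = du.all_gather_unaligned({"rank": du.get_rank()})
```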
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/losses.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..d371307266382045b12b8585ebb20d38fd566e4c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/losses.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from typing import List
+
+import pytorchvideo_trainer.module.distributed_utils as du
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorchvideo.layers.utils import set_attributes
+
+
+class SoftTargetCrossEntropy(nn.Module):
+ """
+ Cross entropy loss with soft target.
+ """
+
+ def __init__(self, reduction: str = "mean") -> None:
+ """
+ Args:
+ reduction (str): specifies reduction to apply to the output.
+ It can be "mean" (default) or "none".
+ """
+ super(SoftTargetCrossEntropy, self).__init__()
+ self.reduction = reduction
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+
+ loss = torch.sum(-y * F.log_softmax(x, dim=-1), dim=-1)
+ if self.reduction == "mean":
+ return loss.mean()
+ elif self.reduction == "none":
+ return loss
+ else:
+ raise NotImplementedError
+
+
+class NtxentLoss(nn.Module):
+ """
+ NT-Xent loss for SimCLR Self-Supervised learning approach -
+ https://arxiv.org/abs/2002.05709
+
+ Args:
+ temperature (float): scalar value to scale the loss by.
+ """
+
+ def __init__(
+ self,
+ temperature: float,
+ ) -> None:
+ super().__init__()
+ set_attributes(self, locals()) # pyre-ignore[6]
+
+ def forward(self, x_list: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Args:
+ x_list (list[torch.tensor]): A list of two tensors of shape N x C.
+ Where, N is the batch size and C is the SSL model's embedding size.
+ """
+ assert (
+ len(x_list) == 2
+        ), f"Invalid list input to SimCLR. Expected a list of length 2 but received {len(x_list)}"
+
+ out_1, out_2 = x_list
+
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ out_1 = du.AllGatherWithGradient.apply(out_1) # pyre-ignore[16]
+ out_2 = du.AllGatherWithGradient.apply(out_2)
+ out = torch.cat([out_1, out_2], dim=0)
+ # [2*B, 2*B]
+ sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / self.temperature)
+ mask = (
+ torch.ones_like(sim_matrix)
+ - torch.eye(out.shape[0], device=sim_matrix.device)
+ ).bool()
+ # [2*B, 2*B-1]
+ sim_matrix = sim_matrix.masked_select(mask).view(out.shape[0], -1)
+ # compute loss
+ pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature)
+ # [2*B]
+ pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
+ loss = (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()
+
+ return loss
+
+
+class SimilarityLoss(nn.Module):
+ """
+ Temperature-scaled Similarity loss for BYOL Self-Supervised learning
+ approach - https://arxiv.org/abs/2006.07733
+
+ Args:
+ temperature (float): scalar value to scale the loss by.
+ """
+
+ def __init__(self, temperature: float) -> None:
+ super().__init__()
+ self.temperature = temperature
+
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ q and k (nn.tensor): inputs to calculate the similarity, expected to have
+ the same shape of `N x C`. Where N is the batch size and C
+ is the SSL model's embedding size.
+ """
+ similarity = torch.einsum("nc,nc->n", [q, k])
+ similarity /= self.temperature
+ loss = -similarity.mean()
+ return loss
+
+
+class ContrastiveLoss(nn.Module):
+ """
+ Temperature-scaled Contrastive loss for MoCo and other Self-Supervised learning
+ approaches - https://arxiv.org/abs/1911.05722
+
+ Args:
+ temperature (float): scalar value to scale the loss by.
+ """
+
+ def __init__(self, reduction: str = "mean", temperature: float = 0.1) -> None:
+ super(ContrastiveLoss, self).__init__()
+ self.reduction = reduction
+ self.temperature = temperature
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ inputs (nn.tensor): Expected to have the same shape of `N x C`.
+ Where, N is the batch size and C is the SSL model's embedding size.
+ """
+ inputs = torch.div(inputs, self.temperature)
+ targets = torch.zeros(inputs.shape[0], dtype=torch.long).to(inputs.device)
+ loss = nn.CrossEntropyLoss(reduction=self.reduction).cuda()(inputs, targets)
+ return loss
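
A minimal sketch exercising the losses above on random embeddings; the batch size, embedding dimension, and class count are illustrative.

```python
import torch
import torch.nn.functional as F

from pytorchvideo_trainer.module.losses import (
    NtxentLoss,
    SimilarityLoss,
    SoftTargetCrossEntropy,
)

q = F.normalize(torch.randn(8, 128), dim=1)  # two augmented views of the same batch
k = F.normalize(torch.randn(8, 128), dim=1)

simclr_loss = NtxentLoss(temperature=0.1)([q, k])  # NT-Xent over the two views
byol_loss = SimilarityLoss(temperature=0.5)(q, k)  # negative scaled similarity

logits = torch.randn(8, 400)
soft_targets = F.one_hot(torch.randint(0, 400, (8,)), num_classes=400).float()
ce_loss = SoftTargetCrossEntropy()(logits, soft_targets)
```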
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/lr_policy.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/lr_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ef9f5d3411b7df335df20fef0c16c366b12cf1
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/lr_policy.py
@@ -0,0 +1,156 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""Learning rate policy."""
+import math
+from dataclasses import dataclass
+from typing import Callable, List
+
+import torch
+from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING
+
+
+@dataclass
+class LRSchedulerConf:
+ # common
+ lr_policy: str = MISSING
+ lr: float = MISSING
+ max_iters: int = MISSING
+ warmup_iters: int = MISSING
+ warmup_start_lr: float = MISSING
+
+ # cosine
+ cosine_end_lr: float = MISSING
+ cosine_after_warmup: bool = MISSING
+
+ # LRS
+ steps: List[int] = MISSING
+ lrs: List[float] = MISSING
+
+
+cs = ConfigStore()
+cs.store(
+ group="schema/module/lr_scheduler",
+ name="lr_scheduler_conf",
+ node=LRSchedulerConf,
+ package="module.lr_scheduler",
+)
+
+
+def get_lr_at_epoch(cfg: LRSchedulerConf, cur_epoch: float) -> float:
+ """
+ Retrieve the learning rate of the current epoch with the option to perform
+ warm up in the beginning of the training stage.
+
+ Args:
+ cfg (LRSchedulerConf): Hydra / omega conf object associated with
+            the learning rate scheduler. The schema can be found in
+ `LRSchedulerConf` and the example configs can be found in
+ `pytorchvideo_trainer/conf/module/lr_scheduler`.
+ cur_epoch (float): the number of epoch of the current training stage.
+ """
+ lr = get_lr_func(cfg.lr_policy)(cfg, cur_epoch)
+ # Perform warm up.
+ if cur_epoch < cfg.warmup_iters:
+ lr_start = cfg.warmup_start_lr
+ lr_end = get_lr_func(cfg.lr_policy)(cfg, cfg.warmup_iters)
+ alpha = (lr_end - lr_start) / cfg.warmup_iters
+ lr = cur_epoch * alpha + lr_start
+ return lr
+
+
+def lr_func_cosine(cfg: LRSchedulerConf, cur_epoch: float) -> float:
+ """
+ Retrieve the learning rate to specified values at specified epoch with the
+ cosine learning rate schedule. Details can be found in:
+    Ilya Loshchilov and Frank Hutter, SGDR: Stochastic Gradient
+    Descent With Warm Restarts.
+
+ Args:
+ cfg (CfgNode): Hydra / omega conf object associated with
+            the learning rate scheduler. The schema can be found in
+ `LRSchedulerConf` and the example configs can be found in
+ `pytorchvideo_trainer/conf/module/lr_scheduler`.
+ cur_epoch (float): the number of epoch of the current training stage.
+ """
+ offset = cfg.warmup_iters if cfg.cosine_after_warmup else 0.0
+ assert cfg.cosine_end_lr < cfg.lr
+ return (
+ cfg.cosine_end_lr
+ + (cfg.lr - cfg.cosine_end_lr)
+ * (math.cos(math.pi * (cur_epoch - offset) / (cfg.max_iters - offset)) + 1.0)
+ * 0.5
+ )
+
+
+def lr_func_steps_with_relative_lrs(cfg: LRSchedulerConf, cur_epoch: float) -> float:
+ """
+ Retrieve the learning rate to specified values at specified epoch with the
+ steps with relative learning rate schedule.
+
+ Args:
+        cfg (CfgNode): Hydra / omega conf object associated with
+            the learning rate scheduler. The schema can be found in
+ `LRSchedulerConf` and the example configs can be found in
+ `pytorchvideo_trainer/conf/module/lr_scheduler`.
+ cur_epoch (float): the number of epoch of the current training stage.
+ """
+ ind = get_step_index(cfg, cur_epoch)
+ return cfg.lrs[ind] * cfg.lr
+
+
+def get_step_index(cfg: LRSchedulerConf, cur_epoch: float) -> int:
+ """
+ Retrieves the lr step index for the given epoch.
+
+ Args:
+ cfg (CfgNode): Hydra / omega conf object associated with
+            the learning rate scheduler. The schema can be found in
+ `LRSchedulerConf` and the example configs can be found in
+ `pytorchvideo_trainer/conf/module/lr_scheduler`.
+ cur_epoch (float): the number of epoch of the current training stage.
+ """
+ steps = cfg.steps + [cfg.max_iters]
+ for ind, step in enumerate(steps): # NoQA
+ if cur_epoch < step:
+ break
+ return ind - 1
+
+
+def get_lr_func(lr_policy: str) -> Callable: # pyre-ignore[24]
+ """
+ Given the configs, retrieve the specified lr policy function.
+
+ Args:
+ lr_policy (string): the learning rate policy to use for the job.
+ """
+ policy = "lr_func_" + lr_policy
+ if policy not in globals():
+ raise NotImplementedError("Unknown LR policy: {}".format(lr_policy))
+ else:
+ return globals()[policy]
+
+
+def get_epoch_lr(cur_epoch: float, cfg: LRSchedulerConf) -> float:
+ """
+ Retrieves the lr for the given epoch (as specified by the lr policy).
+
+ Args:
+ cfg (config): Hydra / omega conf object associated with
+            the learning rate scheduler. The schema can be found in
+ `LRSchedulerConf` and the example configs can be found in
+ `pytorchvideo_trainer/conf/module/lr_scheduler`.
+ cur_epoch (float): the number of epoch of the current training stage.
+ """
+ return get_lr_at_epoch(cfg, cur_epoch)
+
+
+def set_lr(optimizer: torch.optim.Optimizer, new_lr: float) -> None:
+ """
+ Sets the optimizer lr to the specified value.
+ Args:
+ optimizer (optim): the optimizer using to optimize the current network.
+ new_lr (float): the new learning rate to set.
+ """
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = new_lr
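
A minimal sketch querying the cosine policy with warmup for a handful of epochs; all hyper-parameter values are illustrative.

```python
from omegaconf import OmegaConf

from pytorchvideo_trainer.module.lr_policy import get_epoch_lr

lr_cfg = OmegaConf.create(
    {
        "lr_policy": "cosine",
        "lr": 0.1,
        "max_iters": 100,        # scheduling horizon, in the same units as cur_epoch
        "warmup_iters": 10,
        "warmup_start_lr": 0.01,
        "cosine_end_lr": 0.0,
        "cosine_after_warmup": False,
        "steps": [],
        "lrs": [],
    }
)

for epoch in (0, 5, 10, 50, 99):
    print(epoch, round(get_epoch_lr(epoch, lr_cfg), 5))
```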
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdda0b6d3819c2192fdf3ae0026fb9fd4fb18379
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py
@@ -0,0 +1,456 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import pytorchvideo_trainer.module.distributed_utils as du
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING
+from pytorchvideo.models.resnet import create_resnet
+from pytorchvideo.models.weight_init import init_net_weights
+from pytorchvideo_trainer.module.ssl_helper import create_mlp_util, SSLBaseModule
+from pytorchvideo_trainer.module.video_classification import (
+ Batch,
+ BatchKey,
+ EnsembleMethod,
+)
+from torchrecipes.core.conf import ModuleConf
+from torchrecipes.utils.config_utils import get_class_name_str
+
+
+def create_moco_resnet_50(
+ # Backbone
+ backbone_creator: Callable = create_resnet, # pyre-ignore[24]
+ backbone_embed_dim: int = 128,
+ head_pool: Callable = nn.AdaptiveAvgPool3d, # pyre-ignore[24]
+ head_output_size: Tuple[int, int, int] = (1, 1, 1),
+ head_activation: Callable = None, # pyre-ignore[9,24]
+ dropout_rate: float = 0.0,
+ # Projector
+ projector_dim_in: int = 2048,
+ projector_inner_dim: int = 2048,
+ projector_depth: int = 3,
+ projector_norm: Optional[Callable] = None, # pyre-ignore[24]
+ mmt: float = 0.994,
+) -> nn.Module:
+ def _make_bacbone_and_projector(): # pyre-ignore[3]
+ backbone = backbone_creator(
+ dropout_rate=dropout_rate,
+ head_activation=head_activation,
+ head_output_with_global_average=True,
+ head_pool=head_pool,
+ head_output_size=head_output_size,
+ stem_conv_kernel_size=(1, 7, 7),
+ head_pool_kernel_size=(8, 7, 7),
+ )
+
+        backbone.blocks[-1].proj = None  # Overwrite head projection
+ projector = create_mlp_util(
+ projector_dim_in,
+ backbone_embed_dim,
+ projector_inner_dim,
+ projector_depth,
+ norm=projector_norm, # pyre-ignore[6]
+ )
+ return backbone, projector
+
+ backbone, projector = _make_bacbone_and_projector()
+ backbone_mmt, projector_mmt = _make_bacbone_and_projector()
+
+ moco_model = MOCO(
+ mmt=mmt,
+ backbone=backbone,
+ projector=projector,
+ backbone_mmt=backbone_mmt,
+ projector_mmt=projector_mmt,
+ )
+ return moco_model
+
+
+class MOCO(nn.Module):
+ """
+ Momentum Contrast for unsupervised Visual Representation Learning
+ Details can be found in:
+ https://arxiv.org/abs/1911.05722
+ """
+
+ def __init__(
+ self,
+ mmt: float,
+ backbone: nn.Module,
+ backbone_mmt: nn.Module,
+ projector: Optional[nn.Module] = None,
+ projector_mmt: Optional[nn.Module] = None,
+ ) -> None:
+ """
+ Args:
+            backbone (nn.Module): backbone for moco, input shape depends on the forward
+                input size. Standard inputs include `B x C`, `B x C x H x W`, and
+                `B x C x T x H x W`.
+            projector (nn.Module): An mlp with 2 to 3 hidden layers,
+                with (synchronized) BatchNorm and ReLU activation.
+            backbone_mmt (nn.Module): momentum backbone for moco, input shape depends on the forward
+                input size. Standard inputs include `B x C`, `B x C x H x W`, and
+                `B x C x T x H x W`.
+            projector_mmt (nn.Module): An mlp with 2 to 3 hidden layers,
+                with (synchronized) BatchNorm and ReLU activation.
+            mmt (float): momentum update ratio for the momentum backbone.
+ """
+ super().__init__()
+
+ self.mmt: float = mmt
+
+ if projector is not None:
+ backbone = nn.Sequential(
+ backbone,
+ projector,
+ )
+ init_net_weights(backbone)
+ self.backbone = backbone
+
+ if projector_mmt is not None:
+ backbone_mmt = nn.Sequential(
+ backbone_mmt,
+ projector_mmt,
+ )
+ init_net_weights(backbone_mmt)
+ self.backbone_mmt = backbone_mmt
+
+ for p in self.backbone_mmt.parameters():
+ p.requires_grad = False
+
+ self._copy_weights_to_backbone_mmt()
+
+ def _copy_weights_to_backbone_mmt(self) -> None:
+ dist = {}
+ for name, p in self.backbone.named_parameters():
+ dist[name] = p
+ for name, p in self.backbone_mmt.named_parameters():
+ p.data.copy_(dist[name].data)
+
+ @torch.no_grad()
+ def momentum_update_backbone(self) -> None:
+ """
+ Momentum update on the backbone.
+ """
+ m = self.mmt
+ dist = {}
+ for name, p in self.backbone.named_parameters():
+ dist[name] = p
+ for name, p in self.backbone_mmt.named_parameters():
+ # pyre-ignore[41]
+ p.data = dist[name].data * (1.0 - m) + p.data * m
+
+ @torch.no_grad()
+ def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward momentum backbone.
+ Args:
+ x (tensor): input to be forwarded of shape N x C x T x H x W
+ """
+ with torch.no_grad():
+ proj = self.backbone_mmt(x)
+ out_proj = F.normalize(proj, dim=1)
+ return out_proj
+
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+ """
+ Args:
+ x (tensor): input to be forwarded of shape N x C x T x H x W
+ """
+ proj = self.backbone(x)
+ out_proj = F.normalize(proj, dim=1)
+ return out_proj
+
+
+class MOCOV2Module(SSLBaseModule):
+ """
+ The Lightning Base module for MoCo SSL video task.
+
+ For more details refer to,
+ 1. Momentum Contrast for unsupervised Visual Representation Learning:
+ https://arxiv.org/abs/1911.05722
+ 2. A Large-Scale Study on Unsupervised Spatiotemporal Representation Learning
+
+ Args:
+        model (OmegaConf): An omega conf object initializing the neural-network model.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/model`
+        loss (OmegaConf): An omega conf object initializing the loss function.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/loss`
+        optim (OmegaConf): An omega conf object for constructing the optimizer object.
+            The associated config schema can be found at
+            `pytorchvideo_trainer.module.optimizer.OptimizerConf`.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/optim`
+        metrics (OmegaConf): The metrics to track, which will be used for train,
+            validation and test. Example configs can be found in
+            `pytorchvideo_trainer/conf/module/metrics`
+        dim (int): Dimensionality of features in the stored queue. Set to be the same as
+            the embedding dimension of the SSL model.
+        k (int): Queue size for stored features.
+        batch_shuffle (bool): If true, performs shuffling of the computed keys.
+        local_shuffle_bn (bool): If true, only performs shuffling of keys within the
+            current node.
+        lr_scheduler (OmegaConf): An omega conf object associated with the learning rate
+            scheduler used during training.
+            The associated config schema can be found at
+            `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`.
+            Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler`
+        modality_key (str): The modality key used in data processing, default: "video".
+        ensemble_method (str): The data ensembling method to control how we accumulate
+            the testing results at video level, which is optional. Users may choose from
+            ["sum", "max", None]. If it is set to None, no data ensembling will be applied.
+        knn_memory (OmegaConf): An optional hydra / omega conf; if set, initializes the KNN
+            Memory module to use. An example config can be found at
+            `pytorchvideo_trainer/conf/module/knn_memory`.
+        momentum_anneal_cosine (bool): For MoCo and BYOL tasks, if set to true, cosine
+            anneals the momentum term used for updating the backbone-history model.
+        num_sync_devices (int): Number of gpus to sync batchnorm over. Only works if
+            pytorch lightning trainer's sync_batchnorm parameter is set to false.
+ """
+
+ def __init__(
+ self,
+ model: Any, # pyre-ignore[2]
+ loss: Any, # pyre-ignore[2]
+ optim: Any, # pyre-ignore[2]
+ metrics: List[Any], # pyre-ignore[2]
+ dim: int,
+ k: int,
+ batch_shuffle: bool,
+ local_shuffle_bn: bool,
+ lr_scheduler: Optional[Any] = None, # pyre-ignore[2]
+ modality_key: BatchKey = "video",
+ ensemble_method: Optional[EnsembleMethod] = None,
+ knn_memory: Optional[Any] = None, # pyre-ignore[2]
+ momentum_anneal_cosine: bool = False,
+ num_sync_devices: int = 1,
+ ) -> None:
+ super().__init__(
+ model=model,
+ loss=loss,
+ optim=optim,
+ metrics=metrics,
+ lr_scheduler=lr_scheduler,
+ modality_key=modality_key,
+ ensemble_method=ensemble_method,
+ knn_memory=knn_memory,
+ momentum_anneal_cosine=momentum_anneal_cosine,
+ num_sync_devices=num_sync_devices,
+ )
+
+ self.dim: int = dim
+ self.k: int = k
+ self.batch_shuffle_on = batch_shuffle
+ self.local_shuffle_bn = local_shuffle_bn
+ self.register_buffer("ptr", torch.tensor([0]))
+ self.ptr.requires_grad = False
+ stdv = 1.0 / math.sqrt(self.dim / 3)
+ self.register_buffer(
+ "queue_x",
+ torch.rand(self.k, self.dim).mul_(2 * stdv).add_(-stdv),
+ )
+ self.queue_x.requires_grad = False
+ self.local_process_group = None # pyre-ignore[4]
+
+ def on_fit_start(self) -> None:
+ """Called at the very beginning of fit.
+ If on DDP it is called on every process
+ """
+ dataloader = self.trainer.datamodule.train_dataloader()
+ if self.knn_memory is not None:
+ self.knn_memory.init_knn_labels(dataloader) # pyre-ignore[29]
+
+ world_size = self.trainer.world_size
+ if (
+ torch.distributed.is_available()
+ and torch.distributed.is_initialized()
+ and self.local_shuffle_bn
+ and self.batch_shuffle_on
+ ):
+ self._create_local_process_group()
+
+        # TODO: For ad's dataloader this might be different
+ # pyre-ignore[16]
+ self.no_update_iters = self.k // world_size // dataloader.batch_size
+
+ def _create_local_process_group(self) -> None:
+ assert self.trainer.num_gpus > 1, "Error creating local process group in MoCo"
+
+ for i in range(self.trainer.num_nodes):
+ ranks_on_i = list(
+ range(i * self.trainer.num_gpus, (i + 1) * self.trainer.num_gpus)
+ )
+ pg = torch.distributed.new_group(ranks=ranks_on_i)
+ if i == torch.distributed.get_rank() // self.trainer.num_gpus:
+ self.local_process_group = pg
+
+ def training_step(
+ self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
+ ) -> None:
+
+ self.cur_epoch_step += 1 # pyre-ignore[16]
+
+ if self.momentum_anneal_cosine:
+ self._cosine_anneal_momentum()
+
+ self.manual_zero_opt_grad()
+ self.manual_update_lr()
+
+ inputs = batch[self.modality_key] # pyre-ignore[6]
+
+ self.model.momentum_update_backbone() # pyre-ignore[29]
+ keys = self._compute_keys(inputs)
+
+ partial_loss = 0.0
+ for k, vids in enumerate(inputs):
+ other_keys = keys[:k] + keys[k + 1 :]
+ assert len(other_keys) > 0, "Length of keys cannot be zero"
+
+ proj = self.model(vids)
+ q_knn = proj
+ queue_neg = torch.einsum("nc,kc->nk", [proj, self.queue_x.clone().detach()])
+
+ for k, key in enumerate(other_keys):
+ out_pos = torch.einsum("nc,nc->n", [proj, key]).unsqueeze(-1)
+ lgt_k = torch.cat([out_pos, queue_neg], dim=1)
+ if k == 0:
+ logits = lgt_k
+ else:
+ logits = torch.cat([logits, lgt_k], dim=0)
+ loss_k = self.loss(logits) # pyre-ignore[61]
+
+ self.manual_backward(loss_k)
+ partial_loss += loss_k.detach()
+
+ if self.knn_memory is not None:
+ self.knn_memory.update(q_knn, batch["video_index"]) # pyre-ignore[29,61]
+
+ partial_loss /= len(inputs) * 2.0 # to have same loss as symmetric loss
+ self.log("Losses/train_loss", partial_loss, on_step=True, on_epoch=True)
+ self._dequeue_and_enqueue(keys)
+
+ if (
+ self.trainer.current_epoch == 0
+ and self.cur_epoch_step < self.no_update_iters
+ ):
+ print(
+ f"No update: Epoch {self.trainer.current_epoch}"
+ + f" Step {self.cur_epoch_step}/{self.no_update_iters}"
+ )
+ return
+
+ self.manual_opt_step()
+
+ @torch.no_grad()
+ def _compute_keys(self, x: torch.Tensor) -> List[torch.Tensor]:
+ keys = []
+ for sub_x in x:
+ if self.batch_shuffle_on:
+ with torch.no_grad():
+ sub_x, idx_restore = self._batch_shuffle(sub_x)
+ with torch.no_grad():
+ # pyre-ignore[29]
+ key = self.model.forward_backbone_mmt(sub_x).detach()
+
+ if self.batch_shuffle_on:
+ key = self._batch_unshuffle(key, idx_restore).detach()
+ keys.append(key)
+ return keys
+
+ @torch.no_grad()
+ def _batch_shuffle(self, x: torch.Tensor): # pyre-ignore[3]
+ world_size = self.trainer.world_size
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ if self.local_shuffle_bn:
+ assert self.local_process_group is not None
+ x = du.cat_all_gather(x, self.local_process_group)
+ gpu_idx = du.get_local_rank(self.local_process_group)
+ world_size = self.trainer.num_gpus
+ else:
+ x = du.cat_all_gather(x)
+ gpu_idx = torch.distributed.get_rank()
+
+ idx_randperm = torch.randperm(x.shape[0]).to(self.device)
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ torch.distributed.broadcast(idx_randperm, src=0)
+ else:
+ gpu_idx = 0
+ idx_randperm = idx_randperm.view(world_size, -1)
+ x = x[idx_randperm[gpu_idx, :]] # pyre-ignore[61]
+ idx_restore = torch.argsort(idx_randperm.view(-1))
+ idx_restore = idx_restore.view(world_size, -1)
+
+ return x, idx_restore
+
+ @torch.no_grad()
+ def _batch_unshuffle(
+ self, x: torch.Tensor, idx_restore: torch.Tensor
+ ) -> torch.Tensor:
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ if self.local_shuffle_bn:
+ assert self.local_process_group is not None
+ x = du.cat_all_gather(x, self.local_process_group)
+ gpu_idx = du.get_local_rank(self.local_process_group)
+ else:
+ x = du.cat_all_gather(x)
+ gpu_idx = torch.distributed.get_rank()
+ else:
+ gpu_idx = 0
+
+ idx = idx_restore[gpu_idx, :]
+ x = x[idx]
+ return x
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(
+ self,
+ keys: List[torch.Tensor],
+ ) -> None:
+ assert len(keys) > 0, "need to have multiple views for adding them to queue"
+ ptr = int(self.ptr.item())
+ for key in keys:
+ # write the current feat into queue, at pointer
+ num_items = int(key.size(0))
+ assert (
+ self.k % num_items == 0
+ ), "Queue size should be a multiple of batchsize"
+ assert ptr + num_items <= self.k
+ self.queue_x[ptr : ptr + num_items, :] = key
+ # move pointer
+ ptr += num_items
+ # reset pointer
+ if ptr == self.k:
+ ptr = 0
+ self.ptr[0] = ptr
+
+
+@dataclass
+class MOCOV2ModuleConf(ModuleConf):
+ _target_: str = get_class_name_str(MOCOV2Module)
+ model: Any = MISSING # pyre-ignore[4]
+ loss: Any = MISSING # pyre-ignore[4]
+ optim: Any = MISSING # pyre-ignore[4]
+ metrics: List[Any] = MISSING # pyre-ignore[4]
+ lr_scheduler: Optional[Any] = None # pyre-ignore[4]
+ modality_key: str = "video"
+ ensemble_method: Optional[str] = None
+ num_sync_devices: Optional[int] = 1
+ knn_memory: Optional[Any] = None # pyre-ignore[4]
+ momentum_anneal_cosine: bool = False
+ dim: int = MISSING
+ k: int = MISSING
+ batch_shuffle: bool = MISSING
+ local_shuffle_bn: bool = MISSING
+
+
+cs = ConfigStore()
+cs.store(
+ group="schema/module",
+ name="moco_v2_module_conf",
+ node=MOCOV2ModuleConf,
+ package="module",
+)
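+
+
+# Illustrative sketch (not used by the trainer): mirrors the circular-buffer
+# update performed by `MOCOV2Module._dequeue_and_enqueue` above, assuming the
+# queue size `k` is a multiple of the batch size. All names and sizes below
+# are made up for illustration.
+def _example_queue_update() -> None:
+    k, dim, batch_size = 8, 4, 2
+    queue = torch.zeros(k, dim)
+    ptr = 0
+    for _ in range(6):
+        # stand-in for the momentum-encoder keys computed in `_compute_keys`
+        key = torch.randn(batch_size, dim)
+        queue[ptr : ptr + batch_size, :] = key
+        # advance the pointer and wrap around once the queue is full
+        ptr = (ptr + batch_size) % k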
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/optimizer.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de8eb2fdea56ac0557cda61108cfa1f5cf7246a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/optimizer.py
@@ -0,0 +1,257 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+# pyre-ignore-all-errors
+
+from dataclasses import dataclass
+
+import torch
+from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING
+
+
+@dataclass
+class OptimizerConf:
+ method: str = MISSING
+ lr: float = MISSING
+ weight_decay: float = 1e-4
+ bn_weight_decay: float = 0.0
+ momentum: float = 0.9
+ dampening: float = 0.0
+ nesterov: bool = True
+ zero_weight_decay_1d_param: bool = False
+ lars_on: bool = False
+
+
+# TODO: Refactor construct_optimizer to torch.optim conf + construct_param_group
+def construct_optimizer(
+ model: torch.nn.Module, cfg: OptimizerConf # noqa
+) -> torch.optim.Optimizer:
+ """
+ Constructs a stochastic gradient descent, ADAM or ADAMw optimizer
+ with momentum, i.e., constructs a torch.optim.Optimizer with support for
+ zero weight decay on BatchNorm and/or 1-D parameters, based on the config.
+
+ Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
+ (LARS): https://arxiv.org/abs/1708.03888
+
+ Args:
+ model (nn.Module): model to perform stochastic gradient descent or
+ ADAM optimization on.
+ cfg (OptimizerConf): Hydra/Omega conf object consisting of the hyper-parameters
+ of SGD or ADAM, including base learning rate, momentum, weight_decay,
+ dampening, etc. The supported config schema is `OptimizerConf`.
+ Example config files can be found at,
+ `pytorchvideo_trainer/conf/module/optim`
+ """
+ bn_parameters = []
+ non_bn_parameters = []
+ zero_parameters = []
+ no_grad_parameters = []
+ skip = {}
+
+ if hasattr(model, "no_weight_decay"):
+ skip = model.no_weight_decay() # pyre-ignore[29]
+
+ for name, m in model.named_modules():
+ is_bn = isinstance(m, torch.nn.modules.batchnorm._NormBase)
+ for p in m.parameters(recurse=False):
+ if not p.requires_grad:
+ no_grad_parameters.append(p)
+ elif is_bn:
+ bn_parameters.append(p)
+ elif name in skip:
+ zero_parameters.append(p)
+ elif cfg.zero_weight_decay_1d_param and (
+ len(p.shape) == 1 or name.endswith(".bias")
+ ):
+ zero_parameters.append(p)
+ else:
+ non_bn_parameters.append(p)
+
+ optim_params = [
+ {
+ "params": bn_parameters,
+ "weight_decay": cfg.bn_weight_decay,
+ "apply_LARS": False,
+ },
+ {
+ "params": non_bn_parameters,
+ "weight_decay": cfg.weight_decay,
+ "apply_LARS": cfg.lars_on,
+ },
+ {
+ "params": zero_parameters,
+ "weight_decay": 0.0,
+ "apply_LARS": cfg.lars_on,
+ },
+ ]
+ optim_params = [x for x in optim_params if len(x["params"])] # pyre-ignore[6]
+
+ # Check all parameters will be passed into optimizer.
+ assert len(list(model.parameters())) == len(non_bn_parameters) + len(
+ bn_parameters
+ ) + len(zero_parameters) + len(
+ no_grad_parameters
+ ), "parameter size does not match: {} + {} + {} + {} != {}".format(
+ len(non_bn_parameters),
+ len(bn_parameters),
+ len(zero_parameters),
+ len(no_grad_parameters),
+ len(list(model.parameters())),
+ )
+ print(
+ "bn {}, non bn {}, zero {} no grad {}".format(
+ len(bn_parameters),
+ len(non_bn_parameters),
+ len(zero_parameters),
+ len(no_grad_parameters),
+ )
+ )
+
+ if cfg.method == "sgd":
+ optimizer = torch.optim.SGD(
+ optim_params,
+ lr=cfg.lr,
+ momentum=cfg.momentum,
+ weight_decay=cfg.weight_decay,
+ dampening=cfg.dampening,
+ nesterov=cfg.nesterov,
+ )
+ elif cfg.method == "adam":
+ optimizer = torch.optim.Adam(
+ optim_params,
+ lr=cfg.lr,
+ betas=(0.9, 0.999),
+ weight_decay=cfg.weight_decay,
+ )
+ elif cfg.method == "adamw":
+ optimizer = torch.optim.AdamW(
+ optim_params,
+ lr=cfg.lr,
+ eps=1e-08,
+ weight_decay=cfg.weight_decay,
+ )
+ else:
+ raise NotImplementedError("Does not support {} optimizer".format(cfg.method))
+
+ if cfg.lars_on:
+ optimizer = LARS(optimizer=optimizer, trust_coefficient=0.001, clip=False)
+ return optimizer
+
+
+cs = ConfigStore()
+cs.store(
+ group="schema/module/optim",
+ name="optim_conf",
+ node=OptimizerConf,
+ package="module.optim",
+)
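+
+
+# A minimal usage sketch (not part of the training pipeline): shows how
+# `construct_optimizer` could be called directly with an `OptimizerConf`,
+# outside of Hydra. The toy model below is an assumption for illustration;
+# BatchNorm parameters land in the `bn_weight_decay` group, everything else
+# in the regular `weight_decay` group.
+def _example_construct_optimizer() -> torch.optim.Optimizer:
+    model = torch.nn.Sequential(
+        torch.nn.Conv3d(3, 8, kernel_size=3),
+        torch.nn.BatchNorm3d(8),
+        torch.nn.ReLU(),
+    )
+    cfg = OptimizerConf(method="sgd", lr=0.1, weight_decay=1e-4, bn_weight_decay=0.0)
+    return construct_optimizer(model, cfg)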
+
+
+class LARS(torch.optim.Optimizer):
+ """
+ This class is adapted from
+ https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py to
+ include ignoring LARS application specific parameters (e.g. 1D params)
+
+ Args:
+ optimizer (torch.optim): Pytorch optimizer to wrap and modify learning rate for.
+ trust_coefficient (float): Trust coefficient for calculating the lr.
+ See https://arxiv.org/abs/1708.03888
+ clip (bool): Decides between clipping or scaling mode of LARS. If `clip=True` the
+ learning rate is set to `min(optimizer_lr, local_lr)` for each parameter.
+ If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
+ eps (float): epsilon kludge to help with numerical stability while calculating
+ adaptive_lr.
+ ignore_1d_param (bool): If true, does not apply LARS scaling to 1-dimensional parameters (e.g. biases).
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ trust_coefficient=0.02,
+ clip=True,
+ eps=1e-8,
+ ignore_1d_param=True,
+ ) -> None:
+ self.optim = optimizer
+ self.trust_coefficient = trust_coefficient
+ self.eps = eps
+ self.clip = clip
+ self.ignore_1d_param = ignore_1d_param
+
+ self.defaults = self.optim.defaults
+
+ def __getstate__(self):
+ return self.optim.__getstate__()
+
+ def __setstate__(self, state):
+ self.optim.__setstate__(state)
+
+ @property
+ def state(self):
+ return self.optim.state
+
+ def __repr__(self):
+ return self.optim.__repr__()
+
+ @property
+ def param_groups(self):
+ return self.optim.param_groups
+
+ @param_groups.setter
+ def param_groups(self, value):
+ self.optim.param_groups = value
+
+ def state_dict(self):
+ return self.optim.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self.optim.load_state_dict(state_dict)
+
+ def zero_grad(self):
+ self.optim.zero_grad()
+
+ def add_param_group(self, param_group):
+ self.optim.add_param_group(param_group)
+
+ def step(self, closure=None):
+ with torch.no_grad():
+ weight_decays = []
+ for group in self.optim.param_groups:
+ # absorb weight decay control from optimizer
+ weight_decay = group["weight_decay"] if "weight_decay" in group else 0
+ weight_decays.append(weight_decay)
+ apply_LARS = group["apply_LARS"] if "apply_LARS" in group else True
+ if not apply_LARS:
+ continue
+ group["weight_decay"] = 0
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ if self.ignore_1d_param and p.ndim == 1: # ignore bias
+ continue
+ param_norm = torch.norm(p.data)
+ grad_norm = torch.norm(p.grad.data)
+
+ if param_norm != 0 and grad_norm != 0:
+ # calculate adaptive lr + weight decay
+ adaptive_lr = (
+ self.trust_coefficient
+ * (param_norm)
+ / (grad_norm + param_norm * weight_decay + self.eps)
+ )
+
+ # clip learning rate for LARS
+ if self.clip:
+ # calculation of adaptive_lr so that when multiplied
+ # by lr it equals `min(adaptive_lr, lr)`
+ adaptive_lr = min(adaptive_lr / group["lr"], 1)
+
+ p.grad.data += weight_decay * p.data
+ p.grad.data *= adaptive_lr
+
+ self.optim.step()
+ # return weight decay control to optimizer
+ for i, group in enumerate(self.optim.param_groups):
+ group["weight_decay"] = weight_decays[i]
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/simclr.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..f22ebf98902e99c89bba4bc65b53b3bb0c8c62ab
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/simclr.py
@@ -0,0 +1,229 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from dataclasses import dataclass
+from typing import Any, Callable, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING
+from pytorchvideo.models.resnet import create_resnet
+from pytorchvideo.models.weight_init import init_net_weights
+from pytorchvideo_trainer.module.byol import create_mlp_util
+from pytorchvideo_trainer.module.ssl_helper import SSLBaseModule
+from pytorchvideo_trainer.module.video_classification import (
+ Batch,
+ BatchKey,
+ EnsembleMethod,
+)
+from torchrecipes.core.conf import ModuleConf
+from torchrecipes.utils.config_utils import get_class_name_str
+
+
+class SimCLR(nn.Module):
+ """
+ Skeletal NN.Module for the SimCLR model that supports
+ arbitrary bacbone and projector models.
+ """
+
+ def __init__(
+ self,
+ backbone: nn.Module,
+ projector: Optional[nn.Module] = None,
+ ) -> None:
+ """
+ Args:
+ backbone (nn.Module): backbone for simclr, input shape depends on the forward
+ input size. Standard inputs include `B x C`, `B x C x H x W`, and
+ `B x C x T x H x W`.
+ projector (nn.Module): An mlp with 2 to 3 hidden layers,
+ with (synchronized) BatchNorm and ReLU activation.
+ """
+ super().__init__()
+
+ if projector is not None:
+ backbone = nn.Sequential(
+ backbone,
+ projector,
+ )
+ init_net_weights(backbone)
+ self.backbone = backbone
+
+ def forward(
+ self, x_list: Union[torch.Tensor, List[torch.Tensor]]
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """
+ Args:
+ x_list (list(tensor) or tensor): Expects a list of 2 tensors
+ for the training phase and a single tensor for the test and val
+ phases. Here all tensors are expected to be of the shape
+ N x C x T x H x W.
+ """
+ if not self.training:
+ assert isinstance(
+ x_list, torch.Tensor
+ ), "Expected tensor for test/val phase in SimCLR"
+ if self.backbone is not None:
+ x_list = self.backbone(x_list)
+ x_list = F.normalize(x_list, p=2, dim=1)
+ return x_list
+
+ assert (
+ isinstance(x_list, list) and len(x_list) == 2
+ ), f"Invalid list input to SimCLR. Expected len 2 but received {len(x_list)}"
+
+ for i, x in enumerate(x_list):
+ if self.backbone is not None:
+ x = self.backbone(x)
+ x = F.normalize(x, p=2, dim=1)
+ x_list[i] = x
+
+ return x_list
+
+
+def create_simclr_resnet_50(
+ # Backbone
+ backbone_creator: Callable = create_resnet, # pyre-ignore[24]
+ backbone_embed_dim: int = 128,
+ dim_in: int = 2048,
+ # Projector
+ # TODO: Standardize projector conf across all SSL tasks
+ mlp_activation: Callable = nn.ReLU, # pyre-ignore[24]
+ mlp_inner_dim: int = 2048,
+ mlp_depth: int = 1,
+ mlp_norm: Optional[Callable] = None, # pyre-ignore[24]
+) -> SimCLR:
+ """
+ Builds a ResNet video model with a projector for the SimCLR
+ SSL training task.
+ """
+ backbone = backbone_creator(
+ model_num_class=backbone_embed_dim,
+ dropout_rate=0.0,
+ )
+ backbone.blocks[-1].proj = None
+ projector = create_mlp_util(
+ dim_in,
+ backbone_embed_dim,
+ mlp_inner_dim,
+ mlp_depth,
+ norm=mlp_norm, # pyre-ignore[6]
+ )
+ simclr = SimCLR(
+ backbone=backbone,
+ projector=projector,
+ )
+ return simclr
+
+
+class SimCLRModule(SSLBaseModule):
+ """
+ The Lightning Base module for SimCLR SSL video task.
+
+ For more details, refer to:
+ 1. A Simple Framework for Contrastive Learning of Visual Representations :
+ https://arxiv.org/abs/2002.05709
+ 2. A Large-Scale Study on Unsupervised Spatiotemporal Representation Learning
+
+ Args:
+ model (OmegaConf): An omega conf object initializing the neural-network model.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/model`
+ loss (OmegaConf): An omega conf object initializing the loss function.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/loss`
+ optim (OmegaConf): An omega conf object for constructing the optimizer object.
+ The associated config schema can be found at
+ `pytorchvideo_trainer.module.optimizer.OptimizerConf`.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/optim`
+ metrics (OmegaConf): The metrics to track, which will be used for train,
+ validation and test. Example configs can be found in
+ `pytorchvideo_trainer/conf/module/metricx`
+ lr_scheduler (OmegaConf): An omega conf object associated with the learning rate
+ scheduler used during training.
+ The associated config schema can be found at
+ `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler`
+ modality_key (str): The modality key used in data processing, default: "video".
+ ensemble_method (str): The data ensembling method to control how we accumulate
+ the testing results at video level, which is optional. Users may choose from
+ ["sum", "max", None]. If it is set to None, no data ensembling will be applied.
+ knn_memory (OmegaConf): An optional hydra / omega conf; if set, initializes the KNN
+ Memory module to use. Example config can be found at
+ `pytorchvideo_trainer/conf/module/knn_memory`.
+ num_sync_devices (int): Number of gpus to sync batchnorm over. Only used if
+ pytorch lightning trainer's sync_batchnorm parameter is set to false.
+ """
+
+ def __init__(
+ self,
+ model: Any, # pyre-ignore[2]
+ loss: Any, # pyre-ignore[2]
+ optim: Any, # pyre-ignore[2]
+ metrics: List[Any], # pyre-ignore[2]
+ lr_scheduler: Optional[Any] = None, # pyre-ignore[2]
+ modality_key: BatchKey = "video",
+ ensemble_method: Optional[EnsembleMethod] = None,
+ knn_memory: Optional[Any] = None, # pyre-ignore[2]
+ num_sync_devices: int = 1,
+ ) -> None:
+ super().__init__(
+ model=model,
+ loss=loss,
+ optim=optim,
+ metrics=metrics,
+ lr_scheduler=lr_scheduler,
+ modality_key=modality_key,
+ ensemble_method=ensemble_method,
+ knn_memory=knn_memory,
+ momentum_anneal_cosine=False,
+ num_sync_devices=num_sync_devices,
+ )
+
+ def training_step(
+ self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
+ ) -> None:
+
+ self.cur_epoch_step += 1 # pyre-ignore[16]
+
+ self.manual_zero_opt_grad()
+ self.manual_update_lr()
+
+ inputs = batch[self.modality_key] # pyre-ignore[6]
+ partial_loss = 0.0
+ for i in range(len(inputs) - 1):
+ y_hat = self(inputs[i : i + 2])
+ loss = self.loss(y_hat)
+ self.manual_backward(loss)
+ partial_loss += loss.detach()
+
+ partial_loss /= len(inputs) - 1
+ self.log("Losses/train_loss", partial_loss, on_step=True, on_epoch=True)
+
+ if self.knn_memory is not None:
+ # pyre-ignore[29]
+ self.knn_memory.update(y_hat[0], batch["video_index"])
+
+ self.manual_opt_step()
+
+
+@dataclass
+class SimCLRModuleConf(ModuleConf):
+ _target_: str = get_class_name_str(SimCLRModule)
+ model: Any = MISSING # pyre-ignore[4]
+ loss: Any = MISSING # pyre-ignore[4]
+ optim: Any = MISSING # pyre-ignore[4]
+ metrics: List[Any] = MISSING # pyre-ignore[4]
+ lr_scheduler: Optional[Any] = None # pyre-ignore[4]
+ modality_key: str = "video"
+ ensemble_method: Optional[str] = None
+ num_sync_devices: Optional[int] = 1
+ knn_memory: Optional[Any] = None # pyre-ignore[4]
+
+
+cs = ConfigStore()
+cs.store(
+ group="schema/module",
+ name="simclr_module_conf",
+ node=SimCLRModuleConf,
+ package="module",
+)
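+
+
+# A minimal sketch (illustrative only; the toy backbone below is an assumption,
+# not the real video model): shows how the two augmented views flow through
+# `SimCLR.forward` during training. Each view is embedded by the backbone and
+# L2-normalized, and the resulting pair would then be handed to a contrastive loss.
+def _example_simclr_two_view_forward() -> None:
+    toy_backbone = nn.Sequential(nn.Flatten(), nn.Linear(8 * 4 * 4 * 4, 16))
+    model = SimCLR(backbone=toy_backbone)
+    model.train()
+    # two augmented "clips" of the same batch, shape B x C x T x H x W
+    views = [torch.randn(2, 8, 4, 4, 4), torch.randn(2, 8, 4, 4, 4)]
+    z1, z2 = model(views)
+    assert z1.shape == z2.shape == (2, 16)  # unit-norm embeddings along dim=1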
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/ssl_helper.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/ssl_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f63d2df5c40e851843480b68deb5642bac67a944
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/ssl_helper.py
@@ -0,0 +1,473 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import math
+from typing import Any, Callable, Dict, List, Optional
+
+import numpy as np
+import pytorchvideo_trainer.module.distributed_utils as du
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from hydra.utils import instantiate
+from pytorch_lightning.trainer import Trainer
+from pytorchvideo_trainer.module.video_classification import (
+ Batch,
+ BatchKey,
+ EnsembleMethod,
+ VideoClassificationModule,
+)
+
+
+def create_mlp_util(
+ dim_in: int,
+ dim_out: int,
+ mlp_dim: int,
+ num_layers: int,
+ norm: Callable, # pyre-ignore[24]
+ bias: bool = True,
+ xavier_init: bool = True,
+) -> nn.Module:
+ """
+ A utility method for creating the MLP that gets attached to the SSL
+ backbone network either in the form of the projector or predictor.
+
+ Consists of multiple sequences of "Linear -> Norm -> ReLU" layers.
+
+ Args:
+ dim_in (int): Input dimension size to the MLP.
+ dim_out (int): Output dimension size of MLP.
+ mlp_dim (int): Dimension size for the inner layers of the MLP.
+ num_layers (int): Number of layers in the MLP.
+ norm (callable): Type of normalization to apply between layers.
+ Examples include BatchNorm, SyncBatchNorm, etc.
+ bias (bool): If set to true, enables bias for the final layer.
+ xavier_init (bool): If set to true, performs Xavier weight
+ initialization for all linear layers.
+ """
+ if num_layers == 1:
+ return nn.Linear(dim_in, dim_out)
+
+ b = False if norm is not None else bias
+ mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)]
+ mlp_layers[-1].xavier_init = xavier_init
+ for i in range(1, num_layers):
+ if norm:
+ mlp_layers.append(norm(mlp_dim))
+ mlp_layers.append(nn.ReLU(inplace=True))
+ if i == num_layers - 1:
+ d = dim_out
+ b = bias
+ else:
+ d = mlp_dim
+ mlp_layers.append(nn.Linear(mlp_dim, d, bias=b))
+ mlp_layers[-1].xavier_init = xavier_init
+ return nn.Sequential(*mlp_layers)
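+
+
+# A small usage sketch (illustrative only; the sizes are assumptions): builds a
+# 2-layer projector with `create_mlp_util`, i.e. Linear -> BatchNorm1d -> ReLU
+# -> Linear, mapping 2048-d backbone features to a 128-d embedding.
+def _example_projection_mlp() -> nn.Module:
+    return create_mlp_util(
+        dim_in=2048,
+        dim_out=128,
+        mlp_dim=2048,
+        num_layers=2,
+        norm=nn.BatchNorm1d,
+    )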
+
+
+def create_classification_model_from_ssl_checkpoint(
+ ssl_checkpoint_path: str,
+ checkpoint_type: str,
+ mlp: Optional[nn.Module] = None,
+ detach_backbone: bool = False,
+) -> nn.Module:
+
+ """
+ A utility function for extracting the backbone from PyTorch Lightning's
+ SSL checkpoints. Used for supervised fine-tuning of the SSL pre-trained models
+ on the video classification task.
+
+ Extracts the backbone from the checkpoints of the SimCLR, BYOL and MoCoV2 SSL
+ tasks and attaches the given MLP to the backbone.
+
+ Args:
+ ssl_checkpoint_path (str): Path to the lightning checkpoint for the
+ said SSL task.
+ checkpoint_type (str): Type of the SSL task the checkpoint belongs to.
+ Should be one of ["simclr", "byol", "moco_v2"]
+ mlp (nn.Module): If specified, the MLP module to attach to the backbone
+ for the supervised fine-tuning phase.
+ detach_backbone (bool): If true, detaches the backbone so that no gradients are
+ tracked or updated for it. Only updates the MLP weights during fine-tuning.
+
+ Returns:
+ model (SSLFineTuningModel): Returns an instance of `SSLFineTuningModel`,
+ consisting of the backbone and mlp.
+ """
+
+ if checkpoint_type == "simclr":
+ from pytorchvideo_trainer.module.simclr import SimCLRModule as M
+
+ lightning_module = M.load_from_checkpoint(ssl_checkpoint_path)
+ backbone = lightning_module.model.backbone[0]
+ elif checkpoint_type == "byol":
+ from pytorchvideo_trainer.module.byol import BYOLModule as M
+
+ lightning_module = M.load_from_checkpoint(ssl_checkpoint_path)
+ backbone = lightning_module.model.backbone[0]
+ elif checkpoint_type == "moco_v2":
+ from pytorchvideo_trainer.module.moco_v2 import MOCOV2Module as M
+
+ lightning_module = M.load_from_checkpoint(ssl_checkpoint_path)
+ backbone = lightning_module.model.backbone[0]
+ else:
+ raise ValueError("Incorrect SSL checkpoint type.")
+
+ # pyre-ignore[6]
+ return SSLFineTuningModel(backbone, mlp, detach_backbone)
+
+
+class SSLFineTuningModel(nn.Module):
+ """
+ Model consisting of a backbone sequentially followed by an MLP.
+ Used for supervised finetuning of the SSL pre-trained models.
+
+ Args:
+ backbone (nn.Module): A model whose weights are conditionally
+ updated based on the detach_backbone parameter.
+ mlp (nn.Module): If specified, the MLP module to attach to the backbone
+ for the supervised fine-tuning phase.
+ detach_backbone (bool): If true, detaches the backbone so that no gradients are
+ tracked or updated for it. Only updates the MLP weights during fine-tuning.
+ """
+
+ def __init__(
+ self,
+ backbone: nn.Module,
+ mlp: nn.Module,
+ detach_backbone: bool,
+ ) -> None:
+ super().__init__()
+
+ self.backbone = backbone
+ self.mlp = mlp
+ self.detach_backbone = detach_backbone
+
+ for p in self.backbone.parameters():
+ p.requires_grad = not detach_backbone
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.backbone(x)
+ if self.detach_backbone:
+ x = x.detach()
+ if self.mlp is not None:
+ x = self.mlp(x)
+ return x
+
+
+class KnnMemory(nn.Module):
+ """
+ KNN Memory object that keeps track of the features generated by the SSL model
+ during the training phase and performs nearest-neighbour inference during the
+ test and validation phases for video classification.
+
+ KNN memory requires that you provide the labels and video indices for the
+ dataset used for the SSL training phase.
+
+ Args:
+ length (int): Size of the KNN memory. Set to be equal to the training dataset size.
+ dim (int): Feature dimension generated by the SSL model.
+ momentum (float): The rate at which to update the features in memory during the SSL
+ training phase.
+ downstream_classes (int): Number of classes in the dataset.
+ temperature (float): Temperature scaling to use during the inference phase. Typically,
+ set to the same value as the loss temperature used in SSL.
+ knn_k (int): Number of nearest neighbours to aggregate metrics over for inference.
+ device (str): Device to store the memory module on.
+ """
+
+ def __init__(
+ self,
+ length: int,
+ dim: int,
+ momentum: float = 1.0,
+ downstream_classes: int = 400,
+ temperature: float = 1.0,
+ knn_k: int = 200,
+ device: str = "cpu",
+ ) -> None:
+ super(KnnMemory, self).__init__()
+ self.length = length
+ self.dim = dim
+ self.momentum = momentum
+ self.temperature = temperature
+ self.downstream_classes = downstream_classes
+ self.knn_k = knn_k
+ stdv = 1.0 / math.sqrt(dim / 3)
+ self.device = device
+ self.register_buffer(
+ "memory",
+ torch.rand(length, dim, device=self.device).mul_(2 * stdv).add_(-stdv),
+ )
+
+ def resize(self, length: int, dim: int) -> None:
+ """
+ Resizes the memory and re-initializes it.
+
+ Args:
+ length (int): Size of the KNN memory. Set to be equal to the training
+ dataset size.
+ dim (int): Feature dimension generated by the SSL model.
+ """
+ self.length = length
+ self.dim = dim
+ stdv = 1.0 / math.sqrt(dim / 3)
+ del self.memory
+ self.memory = (
+ torch.rand(length, dim, device=self.device).mul_(2 * stdv).add_(-stdv)
+ )
+
+ @torch.no_grad()
+ def get(self, ind: torch.Tensor) -> torch.Tensor:
+ """
+ Fetches features from the memory based on the video index.
+
+ Args:
+ ind (tensor): Indices of the videos / clips for which to fetch the features.
+ """
+ batch_size = ind.size(0)
+ selected_mem = self.memory[ind.view(-1), :]
+ out = selected_mem.view(batch_size, -1, self.dim)
+ return out
+
+ @torch.no_grad()
+ def update(self, mem: torch.Tensor, ind: torch.Tensor) -> None:
+ """
+ Performs feature updates in the memory based on the new features generated by the
+ SSL model. Called during the SSL training phase.
+
+ Args:
+ mem (tensor): Features of shape N x C generated by the SSL model.
+ N is the batch size and C is the feature dimension generated by the
+ SSL Model.
+ ind (tensor): A 1-D tensor of video indices associated with the given features.
+ """
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ mem, ind = du.all_gather([mem, ind])
+ mem = mem.view(mem.size(0), 1, -1)
+ mem_old = self.get(ind).to(mem.device)
+
+ mem_update = mem * self.momentum + mem_old * (1 - self.momentum)
+ mem_update = F.normalize(mem_update, p=2, dim=1)
+ self.memory[ind.view(-1), :] = mem_update.squeeze().to(self.memory.device)
+
+ @torch.no_grad()
+ def init_knn_labels(self, train_loader: Trainer) -> None:
+ """
+ Called before training; initializes the KNN Memory and resizes it based on the
+ labels and number of samples in the train dataloader.
+
+ Args:
+ train_loader (dataloader): Training dataloader containing an attribute
+ `dataset._labeled_videos` which holds the mapping from video indices to
+ labels.
+ """
+ # TODO: Make sure all dataloaders have this property `dataset._labeled_videos`
+ self.num_imgs = len(train_loader.dataset._labeled_videos) # pyre-ignore[16]
+ # pyre-ignore[16]
+ self.train_labels = np.zeros((self.num_imgs,), dtype=np.int32)
+ for i in range(self.num_imgs): # pyre-ignore[6]
+ # pyre-ignore[29]
+ self.train_labels[i] = train_loader.dataset._labeled_videos[i][1]["label"]
+ self.train_labels = torch.LongTensor(self.train_labels).to(self.device)
+ if self.length != self.num_imgs:
+ self.resize(self.num_imgs, self.dim) # pyre-ignore[6]
+
+ def forward(self, inputs: torch.Tensor) -> None:
+ pass
+
+ @torch.no_grad()
+ def eval_knn(self, q_knn: torch.Tensor) -> torch.Tensor:
+ """
+ Performs KNN nearest-neighbour aggregation and returns predictions
+ for the queried features.
+
+ Args:
+ q_knn (tensor): Features generated by the SSL model during the inference
+ phase. Expected to be of shape N x C where N is the batch size and
+ C is the feature dimension generated by the SSL Model.
+ """
+ device = q_knn.device
+ batch_size = q_knn.size(0)
+ dist = torch.einsum(
+ "nc,mc->nm",
+ q_knn.view(batch_size, -1),
+ self.memory.view(self.memory.size(0), -1).to(device),
+ )
+ yd, yi = dist.topk(self.knn_k, dim=1, largest=True, sorted=True)
+
+ K = yi.shape[1]
+ C = self.downstream_classes
+ candidates = self.train_labels.view(1, -1).expand(batch_size, -1)
+ candidates = candidates.to(device)
+ yi = yi.to(device)
+ retrieval = torch.gather(candidates, 1, yi)
+ retrieval_one_hot = torch.zeros((batch_size * K, C)).to(device)
+ retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
+ yd_transform = (yd.clone().div_(self.temperature).exp_()).to(device)
+ probs = torch.mul(
+ retrieval_one_hot.view(batch_size, -1, C),
+ yd_transform.view(batch_size, -1, 1),
+ )
+ preds = torch.sum(probs, 1)
+ return preds
+
+
+class SSLBaseModule(VideoClassificationModule):
+ """
+ The Lightning Base module supporting SimCLR, MoCo and BYOL SSL tasks.
+
+ Args:
+ model (OmegaConf): An omega conf object initializing the neural-network model.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/model`
+ loss (OmegaConf): An omega conf object initializing the loss function.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/loss`
+ optim (OmegaConf): An omega conf object for constructing the optimizer object.
+ The associated config schema can be found at
+ `pytorchvideo_trainer.module.optimizer.OptimizerConf`.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/optim`
+ metrics (OmegaConf): The metrics to track, which will be used for train,
+ validation and test. Example configs can be found in
+ `pytorchvideo_trainer/conf/module/metricx`
+ lr_scheduler (OmegaConf): An omega conf object associated with the learning rate
+ scheduler used during training.
+ The associated config schema can be found at
+ `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler`
+ modality_key (str): The modality key used in data processing, default: "video".
+ ensemble_method (str): The data ensembling method to control how we accumulate
+ the testing results at video level, which is optional. Users may choose from
+ ["sum", "max", None]. If it is set to None, no data ensembling will be applied.
+ knn_memory (OmegaConf): An optional hydra / omega conf; if set, initializes the KNN
+ Memory module to use. Example config can be found at
+ `pytorchvideo_trainer/conf/module/knn_memory`.
+ momentum_anneal_cosine (bool): For MoCo and BYOL tasks, if set to true, cosine
+ anneals the momentum term used for updating the backbone-history model.
+ num_sync_devices (int): Number of gpus to sync batchnorm over. Only used if
+ pytorch lightning trainer's sync_batchnorm parameter is set to false.
+ """
+
+ def __init__(
+ self,
+ model: Any, # pyre-ignore[2]
+ loss: Any, # pyre-ignore[2]
+ optim: Any, # pyre-ignore[2]
+ metrics: List[Any], # pyre-ignore[2]
+ lr_scheduler: Optional[Any] = None, # pyre-ignore[2]
+ modality_key: BatchKey = "video",
+ ensemble_method: Optional[EnsembleMethod] = None,
+ knn_memory: Optional[Any] = None, # pyre-ignore[2]
+ momentum_anneal_cosine: bool = False, # TODO: Refactor out mmt from base class.
+ num_sync_devices: int = 1,
+ ) -> None:
+ super().__init__(
+ model=model,
+ loss=loss,
+ optim=optim,
+ metrics=metrics,
+ lr_scheduler=lr_scheduler,
+ modality_key=modality_key,
+ ensemble_method=ensemble_method,
+ num_sync_devices=num_sync_devices,
+ )
+
+ self.knn_memory: nn.Module = instantiate(knn_memory)
+ self.automatic_optimization = False
+ self.momentum_anneal_cosine = momentum_anneal_cosine
+ if self.momentum_anneal_cosine:
+ self.initial_mmt: float = self.model.mmt # pyre-ignore[8]
+
+ if ensemble_method is not None:
+ assert (
+ self.knn_memory is not None
+ ), "Test-Ensembling is only supported with KNN module"
+
+ def on_fit_start(self) -> None:
+ """
+ Called at the very beginning of fit.
+ If on DDP it is called on every process.
+
+ Performs conversion of model batchnorm layers into sync-batchnorm
+ and initializes the KNN module using the dataloader labels.
+ """
+
+ self._convert_to_sync_bn()
+ if self.knn_memory is not None:
+ dataloader = self.trainer.datamodule.train_dataloader()
+ self.knn_memory.init_knn_labels(dataloader) # pyre-ignore[29]
+
+ def _test_step_with_data_ensembling(self, batch: Batch, batch_idx: int) -> None:
+ """
+ Operates on a single batch of data from the test set.
+ """
+ assert (
+ isinstance(batch, dict)
+ and self.modality_key in batch
+ and "label" in batch
+ and "video_index" in batch
+ and self.knn_memory is not None
+ ), (
+ f"Returned batch [{batch}] is not a map with '{self.modality_key}' and"
+ + "'label' and 'video_index' keys"
+ )
+
+ y_hat = self(batch[self.modality_key])
+ y_hat = (
+ self.knn_memory.eval_knn(y_hat) if self.knn_memory is not None else y_hat
+ )
+ preds = torch.nn.functional.softmax(y_hat, dim=-1)
+ labels = batch["label"]
+ video_ids = torch.tensor(batch["video_index"], device=self.device)
+
+ self._ensemble_at_video_level(preds, labels, video_ids)
+
+ def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Dict[str, Any]:
+ """
+ If KNN Memory is enabled, evaluates metrics using the labels of neighbours
+ during the validation and test phases.
+ """
+ assert (
+ isinstance(batch, dict)
+ and self.modality_key in batch
+ and ("label" in batch or self.knn_memory is None)
+ and phase_type in ["val", "test"]
+ ), (
+ f"Returned batch [{batch}] is not a map with '{self.modality_key}' and"
+ + "'label' keys"
+ )
+
+ if self.knn_memory is not None:
+ y_hat = self(batch[self.modality_key])
+ y_hat = self.knn_memory.eval_knn(y_hat)
+ pred = torch.nn.functional.softmax(y_hat, dim=-1)
+ metrics_result = self._compute_metrics(pred, batch["label"], phase_type)
+ self.log_dict(metrics_result, on_epoch=True)
+
+ def training_step(
+ self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
+ ) -> None:
+ """Missing method implemented in subsequent derived SSL task modules."""
+ pass
+
+ @torch.no_grad()
+ def _cosine_anneal_momentum(self) -> None:
+ """
+ For MoCo and BYOL tasks, if self.momentum_anneal_cosine is set to true,
+ cosine anneals the momentum term used for updating the backbone-history
+ model.
+ """
+ # pyre-ignore[6]
+ exact_epoch = float(self.cur_epoch_step) / float(
+ self._num_training_steps_per_epoch()
+ )
+ exact_epoch += self.trainer.current_epoch
+ new_mmt = (
+ 1.0
+ - (1.0 - self.initial_mmt)
+ * (
+ math.cos(math.pi * float(exact_epoch) / float(self.trainer.max_epochs))
+ + 1.0
+ )
+ * 0.5
+ )
+ self.model.mmt = new_mmt # pyre-ignore[16]
+ self.log("MMT", new_mmt, on_step=True, prog_bar=True)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..d622af26833e9d7dfc4fe941bc20c8631e697c68
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py
@@ -0,0 +1,514 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+# pyre-strict
+
+from dataclasses import dataclass
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Tuple,
+ TypedDict,
+ Union,
+)
+
+import pytorch_lightning as pl
+import torch
+from hydra.core.config_store import ConfigStore
+from hydra.utils import instantiate
+from iopath.common.file_io import g_pathmgr
+
+# @manual "fbsource//third-party/pypi/omegaconf:omegaconf"
+from omegaconf import MISSING, OmegaConf
+from pytorch_lightning.utilities import rank_zero_info
+from pytorchvideo_trainer.datamodule.transforms import MixVideoBatchWrapper
+from pytorchvideo_trainer.module.lr_policy import get_epoch_lr, LRSchedulerConf, set_lr
+from pytorchvideo_trainer.module.optimizer import construct_optimizer
+from torch import nn
+from torch.optim.lr_scheduler import _LRScheduler
+from torchrecipes.core.conf import ModuleConf
+from torchrecipes.utils.config_utils import get_class_name_str
+
+
+class Batch(TypedDict):
+ """
+ PyTorchVideo batches are dictionaries containing each modality or metadata of
+ the batch collated video clips. For Kinetics it has the below keys and types.
+ """
+
+ video: torch.Tensor # (B, C, T, H, W)
+ audio: torch.Tensor # (B, S)
+ label: torch.Tensor # (B, 1)
+ video_index: List[int] # len(video_index) == B
+
+
+BatchKey = Literal["video", "audio", "label", "video_index"]
+EnsembleMethod = Literal["sum", "max"]
+
+
+class VideoClassificationModule(pl.LightningModule):
+ """
+ The Lightning module supporting the video classification task.
+
+ Args:
+ model (OmegaConf): An omega conf object initializing the neural-network model.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/model`
+ loss (OmegaConf): An omega conf object initializing the loss function.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/loss`
+ optim (OmegaConf): An omega conf object for constructing the optimizer object.
+ The associated config schema can be found at
+ `pytorchvideo_trainer.module.optimizer.OptimizerConf`.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/optim`
+ metrics (OmegaConf): The metrics to track, which will be used for train,
+ validation and test. Example configs can be found in
+ `pytorchvideo_trainer/conf/module/metricx`
+ lr_scheduler (OmegaConf): An omega conf object associated with the learning rate
+ scheduler used during training.
+ The associated config schema can be found at
+ `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`.
+ Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler`
+ modality_key (str): The modality key used in data processing, default: "video".
+ ensemble_method (str): The data ensembling method to control how we accumulate
+ the testing results at video level, which is optional. Users may choose from
+ ["sum", "max", None]. If it is set to None, no data ensembling will be applied.
+ num_classes (int): The number of classes in the dataset.
+ num_sync_devices (int): Number of gpus to sync batchnorm over. Only used if
+ pytorch lightning trainer's sync_batchnorm parameter is set to false.
+ batch_transform (OmegaConf): An optional omega conf object for constructing the
+ data transform that acts upon the entire mini-batch. Examples include the
+ MixVideo transform.
+ clip_gradient_norm (float): Performs gradient clipping if set to a positive value.
+ Since we use PyTorch Lightning's manual optimization approach, gradient clipping
+ has to be set in the lightning module instead of the Trainer object.
+ """
+
+ def __init__(
+ self,
+ model: Any, # pyre-ignore[2]
+ loss: Any, # pyre-ignore[2]
+ optim: Any, # pyre-ignore[2]
+ metrics: List[Any], # pyre-ignore[2]
+ lr_scheduler: Optional[Any] = None, # pyre-ignore[2]
+ modality_key: BatchKey = "video",
+ ensemble_method: Optional[EnsembleMethod] = None,
+ num_classes: int = 400,
+ num_sync_devices: int = 1,
+ batch_transform: Optional[Any] = None, # pyre-ignore[2]
+ clip_gradient_norm: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.automatic_optimization = False
+
+ self.model: nn.Module = instantiate(model, _convert_="all")
+ self.loss: nn.Module = instantiate(loss)
+ self.batch_transform = instantiate(batch_transform) # pyre-ignore[4]
+ rank_zero_info(OmegaConf.to_yaml(optim))
+ self.optim: torch.optim.Optimizer = construct_optimizer(self.model, optim)
+ self.lr_scheduler_conf: LRSchedulerConf = lr_scheduler
+ self.modality_key: BatchKey = modality_key
+ self.ensemble_method: Optional[EnsembleMethod] = ensemble_method
+ self.num_classes: int = num_classes
+ self.clip_gradient_norm = clip_gradient_norm
+
+ self.metrics: Mapping[str, nn.Module] = {
+ metric_conf.name: instantiate(metric_conf.config) for metric_conf in metrics
+ }
+
+ self.train_metrics: nn.ModuleDict = nn.ModuleDict()
+ self.val_metrics: nn.ModuleDict = nn.ModuleDict()
+ self.test_metrics: nn.ModuleDict = nn.ModuleDict()
+
+ self.save_hyperparameters()
+
+ # These are used for data ensembling in the test stage.
+ self.video_preds: Dict[int, torch.Tensor] = {}
+ self.video_labels: Dict[int, torch.Tensor] = {}
+ self.video_clips_cnts: Dict[int, int] = {}
+
+ # Sync BatchNorm
+ self.num_sync_devices = num_sync_devices
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ if stage == "fit":
+ self.train_metrics.update(self.metrics)
+ self.val_metrics.update(self.metrics)
+ else:
+ self.test_metrics.update(self.metrics)
+
+ # pyre-ignore[14]: *args, **kwargs are not torchscriptable.
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward defines the prediction/inference actions.
+ """
+ return self.model(x)
+
+ def _num_training_steps_per_epoch(self) -> int:
+ """training steps per epoch inferred from datamodule and devices."""
+ dataloader = self.trainer.datamodule.train_dataloader()
+ world_size = self.trainer.world_size
+
+ # TODO: Make sure other dataloaders have this property
+ dataset_size = self.trainer.limit_train_batches
+ dataset_size *= len(dataloader.dataset._labeled_videos)
+
+ # TODO: Make sure other dataloaders have this property
+ return dataset_size // world_size // dataloader.batch_size
+
+ def manual_update_lr(self) -> None:
+ """Utility function for manually updating the optimizer learning rate"""
+
+ opt = self.optimizers()
+
+ if self.lr_scheduler_conf is not None:
+ # pyre-ignore[6]
+ exact_epoch = float(self.cur_epoch_step) / float(
+ self._num_training_steps_per_epoch()
+ )
+ exact_epoch += self.trainer.current_epoch
+ lr = get_epoch_lr(exact_epoch, self.lr_scheduler_conf)
+ self.log("LR", lr, on_step=True, prog_bar=True)
+ self.log("ExactE", exact_epoch, on_step=True, prog_bar=True)
+
+ if isinstance(opt, list):
+ for op in opt:
+ set_lr(op, lr) # pyre-ignore[6]
+ else:
+ set_lr(opt, lr) # pyre-ignore[6]
+
+ def manual_zero_opt_grad(self) -> None:
+ """Utility function for zeroing optimzer gradients"""
+ opt = self.optimizers()
+ if isinstance(opt, list):
+ for op in opt:
+ op.zero_grad() # pyre-ignore[16]
+ else:
+ opt.zero_grad()
+
+ def manual_opt_step(self) -> None:
+ """Utility function for manually stepping the optimzer"""
+ opt = self.optimizers()
+ if isinstance(opt, list):
+ for op in opt:
+ op.step()
+ else:
+ opt.step()
+
+ def training_step(
+ self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
+ ) -> None:
+ """
+ The PyTorchVideo models and transforms expect the same input shapes and
+ dictionary structure making this function just a matter of unwrapping the
+ dict and feeding it through the model/loss.
+ """
+ self.cur_epoch_step += 1 # pyre-ignore[16]
+
+ if self.batch_transform is not None:
+ batch = self.batch_transform(batch)
+
+ self.manual_zero_opt_grad()
+ self.manual_update_lr()
+
+ # Forward/backward
+ loss = self._step(batch, batch_idx, "train")
+ self.manual_backward(loss) # pyre-ignore[6]
+ if self.clip_gradient_norm > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.clip_gradient_norm
+ )
+ self.manual_opt_step()
+
+ def validation_step(
+ self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """
+ Operates on a single batch of data from the validation set.
+ """
+ return self._step(batch, batch_idx, "val")
+
+ def test_step(
+ self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
+ ) -> Optional[Dict[str, Any]]:
+ """
+ Operates on a single batch of data from the test set.
+ """
+ if self.ensemble_method:
+ self._test_step_with_data_ensembling(batch, batch_idx)
+ else:
+ return self._step(batch, batch_idx, "test")
+
+ def _test_step_with_data_ensembling(self, batch: Batch, batch_idx: int) -> None:
+ """
+ Operates on a single batch of data from the test set.
+ """
+ assert (
+ isinstance(batch, dict)
+ and self.modality_key in batch
+ and "label" in batch
+ and "video_index" in batch
+ ), (
+ f"Returned batch [{batch}] is not a map with '{self.modality_key}' and"
+ + "'label' and 'video_index' keys"
+ )
+
+ y_hat = self(batch[self.modality_key])
+ preds = torch.nn.functional.softmax(y_hat, dim=-1)
+ labels = batch["label"]
+ video_ids = torch.tensor(batch["video_index"], device=self.device)
+
+ self._ensemble_at_video_level(preds, labels, video_ids)
+
+ def on_train_epoch_start(self) -> None:
+ self._reset_metrics("train")
+ self.cur_epoch_step = 0.0 # pyre-ignore[16]
+
+ def on_validation_epoch_start(self) -> None:
+ self._reset_metrics("val")
+
+ def on_test_epoch_start(self) -> None:
+ self._reset_metrics("test")
+
+ def on_test_epoch_end(self) -> None:
+ """Pytorch-Lightning's method for aggregating test metrics at the end of epoch"""
+ if self.ensemble_method:
+ for video_id in self.video_preds:
+ self.video_preds[video_id] = (
+ self.video_preds[video_id] / self.video_clips_cnts[video_id]
+ )
+ video_preds = torch.stack(list(self.video_preds.values()), dim=0)
+ video_labels = torch.tensor(
+ list(self.video_labels.values()),
+ device=self.device,
+ )
+ metrics_result = self._compute_metrics(video_preds, video_labels, "test")
+ self.log_dict(metrics_result)
+
+ def _ensemble_at_video_level(
+ self, preds: torch.Tensor, labels: torch.Tensor, video_ids: torch.Tensor
+ ) -> None:
+ """
+ Ensemble multiple predictions of the same view together. This relies on the
+ fact that the dataloader reads multiple clips of the same video at different
+ spatial crops.
+ """
+ for i in range(preds.shape[0]):
+ vid_id = int(video_ids[i])
+ self.video_labels[vid_id] = labels[i]
+ if vid_id not in self.video_preds:
+ self.video_preds[vid_id] = torch.zeros(
+ (self.num_classes), device=self.device, dtype=preds.dtype
+ )
+ self.video_clips_cnts[vid_id] = 0
+
+ if self.ensemble_method == "sum":
+ self.video_preds[vid_id] += preds[i]
+ elif self.ensemble_method == "max":
+ self.video_preds[vid_id] = torch.max(self.video_preds[vid_id], preds[i])
+ self.video_clips_cnts[vid_id] += 1
+
+ def configure_optimizers(
+ self,
+ ) -> Union[
+ torch.optim.Optimizer,
+ Tuple[Iterable[torch.optim.Optimizer], Iterable[_LRScheduler]],
+ ]:
+ """Pytorch-Lightning's method for configuring optimizer"""
+ return self.optim
+
+ def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Dict[str, Any]:
+ assert (
+ isinstance(batch, dict) and self.modality_key in batch and "label" in batch
+ ), (
+ f"Returned batch [{batch}] is not a map with '{self.modality_key}' and"
+ + "'label' keys"
+ )
+
+ y_hat = self(batch[self.modality_key])
+ if phase_type == "train":
+ loss = self.loss(y_hat, batch["label"])
+ self.log(
+ f"Losses/{phase_type}_loss",
+ loss,
+ on_step=True,
+ on_epoch=True,
+ prog_bar=True,
+ )
+ else:
+ loss = None
+
+ ## TODO: Move MixUp transform metrics to a separate method.
+ if (
+ phase_type == "train"
+ and self.batch_transform is not None
+ and isinstance(self.batch_transform, MixVideoBatchWrapper)
+ ):
+ _top_max_k_vals, top_max_k_inds = torch.topk(
+ batch["label"], 2, dim=1, largest=True, sorted=True
+ )
+ idx_top1 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 0]
+ idx_top2 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 1]
+ y_hat = y_hat.detach()
+ y_hat[idx_top1] += y_hat[idx_top2]
+ y_hat[idx_top2] = 0.0
+ batch["label"] = top_max_k_inds[:, 0]
+
+ pred = torch.nn.functional.softmax(y_hat, dim=-1)
+ metrics_result = self._compute_metrics(pred, batch["label"], phase_type)
+ self.log_dict(metrics_result, on_epoch=True)
+
+ return loss
+
+ def _compute_metrics(
+ self, pred: torch.Tensor, label: torch.Tensor, phase_type: str
+ ) -> Dict[str, torch.Tensor]:
+ metrics_dict = getattr(self, f"{phase_type}_metrics")
+ metrics_result = {}
+ for name, metric in metrics_dict.items():
+ metrics_result[f"Metrics/{phase_type}/{name}"] = metric(pred, label)
+ return metrics_result
+
+ def _reset_metrics(self, phase_type: str) -> None:
+ metrics_dict = getattr(self, f"{phase_type}_metrics")
+ for _, metric in metrics_dict.items():
+ metric.reset()
+
+ def _convert_to_sync_bn(self) -> None:
+ """
+ Converts BatchNorm into sync-batchnorm.
+ If pytorch lightning trainer's sync_batchnorm parameter is set to true,
+ performs global sync-batchnorm across all nodes and gpus. Else,
+ performs local sync-batchnorm across the specified number of gpus.
+ """
+ if (
+ hasattr(self.trainer.training_type_plugin, "sync_batchnorm")
+ and self.trainer.training_type_plugin.sync_batchnorm
+ ):
+ print("Using Global Synch BatchNorm.")
+ return None
+
+ if self.num_sync_devices > 1:
+ print(f"Using local Synch BatchNorm over {self.num_sync_devices} devices.")
+ pg = create_syncbn_process_group(self.num_sync_devices)
+ self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
+ self.model, process_group=pg
+ )
+
+ def on_fit_start(self) -> None:
+ """
+ Called at the very beginning of fit.
+ If on DDP it is called on every process.
+ """
+ self._convert_to_sync_bn()
+
+
+def create_syncbn_process_group(group_size: int) -> List[int]:
+ """
+ Creates process groups to be used for syncbn of a given ``group_size`` and returns
+ the process group that the current GPU participates in.
+
+ Args:
+ group_size (int): number of GPUs to collaborate for sync bn. group_size must
+ be >= 2; otherwise an assertion error is raised.
+ """
+ assert (
+ group_size > 1
+ ), f"Invalid group size {group_size} to convert to sync batchnorm."
+
+ world_size = torch.distributed.get_world_size()
+ assert world_size >= group_size
+ assert world_size % group_size == 0
+
+ group = None
+ for group_num in range(world_size // group_size):
+ group_ids = range(group_num * group_size, (group_num + 1) * group_size)
+ cur_group = torch.distributed.new_group(ranks=group_ids)
+ if torch.distributed.get_rank() // group_size == group_num:
+ group = cur_group
+ # can not drop out and return here,
+ # every process must go through creation of all subgroups
+
+ assert group is not None
+ return group
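+
+
+# Illustrative only (no torch.distributed initialization required): the rank
+# partitioning used by `create_syncbn_process_group`. For example, with
+# world_size=8 and group_size=4, rank 5 syncs BatchNorm with ranks [4, 5, 6, 7].
+def _example_syncbn_group_ranks(
+    rank: int, group_size: int, world_size: int
+) -> List[int]:
+    assert world_size % group_size == 0 and 0 <= rank < world_size
+    group_num = rank // group_size
+    return list(range(group_num * group_size, (group_num + 1) * group_size))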
+
+
+@dataclass
+class VideoClassificationModuleConf(ModuleConf):
+ _target_: str = get_class_name_str(VideoClassificationModule)
+ model: Any = MISSING # pyre-ignore[4]
+ loss: Any = MISSING # pyre-ignore[4]
+ optim: Any = MISSING # pyre-ignore[4]
+ metrics: List[Any] = MISSING # pyre-ignore[4]
+ lr_scheduler: Optional[Any] = None # pyre-ignore[4]
+ modality_key: str = "video"
+ ensemble_method: Optional[str] = None
+ num_classes: int = 400
+ num_sync_devices: Optional[int] = 1
+
+
+@dataclass
+class VideoClassificationModuleConfVisionTransformer(VideoClassificationModuleConf):
+
+ batch_transform: Optional[Any] = None # pyre-ignore[4]
+ clip_gradient_norm: float = 0.0
+
+
+cs = ConfigStore()
+cs.store(
+ group="schema/module",
+ name="video_classification_module_conf",
+ node=VideoClassificationModuleConf,
+ package="module",
+)
+
+cs.store(
+ group="schema/module",
+ name="video_classification_module_conf_vision_transformer",
+ node=VideoClassificationModuleConfVisionTransformer,
+ package="module",
+)
+
+
+def create_classification_model_from_modelzoo(
+ checkpoint_path: str,
+ model: nn.Module,
+) -> nn.Module:
+ """
+ Builds a model from PyTorchVideo's model zoo checkpoint.
+
+ Example config for building this method can be found at -
+ `pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml`
+
+ Args:
+ checkpoint_path (str): Path to the pretrained model weights.
+ model (nn.Module): Module to load the checkpoints into.
+ Returns:
+ model (nn.Module): Returns the model with pretrained weights loaded.
+ """
+
+ with g_pathmgr.open(checkpoint_path, "rb") as f:
+ checkpoint = torch.load(f, map_location="cpu")
+ state_dict = checkpoint["model_state"]
+ model.load_state_dict(state_dict)
+ return model
+
+
+def create_classification_model_from_lightning(
+ checkpoint_path: str,
+) -> nn.Module:
+ """
+ Builds a model from pytorchvideo_trainer's PytorchLightning checkpoint.
+
+ Example config for building this method can be found at -
+ `pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml`
+
+ Args:
+ checkpoint_path (str): Path to the pretrained model weights.
+ Returns:
+ model (nn.Module): Returns the model with pretrained weights loaded.
+ """
+ lightning_model = VideoClassificationModule.load_from_checkpoint(checkpoint_path)
+ return lightning_model.model
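+
+
+# Illustrative only: a pure-function version of the "sum" ensembling performed
+# by `_ensemble_at_video_level` and `on_test_epoch_end`: clip-level softmax
+# predictions that share a video id are summed and then averaged over the
+# number of clips of that video.
+def _example_video_level_ensemble(
+    preds: torch.Tensor, video_ids: torch.Tensor, num_classes: int
+) -> Dict[int, torch.Tensor]:
+    video_preds: Dict[int, torch.Tensor] = {}
+    video_clip_cnts: Dict[int, int] = {}
+    for i in range(preds.shape[0]):
+        vid_id = int(video_ids[i])
+        if vid_id not in video_preds:
+            video_preds[vid_id] = torch.zeros(num_classes, dtype=preds.dtype)
+            video_clip_cnts[vid_id] = 0
+        video_preds[vid_id] += preds[i]
+        video_clip_cnts[vid_id] += 1
+    return {v: p / video_clip_cnts[v] for v, p in video_preds.items()}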
diff --git a/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/train_app.py b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/train_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..a73be5e53a2e514c19f7623ce3a9e41be9c112ce
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/pytorchvideo_trainer/train_app.py
@@ -0,0 +1,301 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Union
+
+import hydra
+import numpy as np
+import submitit
+import torch
+from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING, OmegaConf
+from omegaconf.dictconfig import DictConfig
+from pytorch_lightning import LightningDataModule, LightningModule
+from pytorch_lightning.callbacks import Callback, LearningRateMonitor
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorchvideo_trainer.datamodule.datamodule import VideoClassificationDataModuleConf
+from pytorchvideo_trainer.module.video_classification import (
+ VideoClassificationModuleConf,
+)
+from torchrecipes.core.base_train_app import BaseTrainApp, TrainOutput
+from torchrecipes.core.conf import TrainAppConf, TrainerConf
+from torchrecipes.utils.config_utils import get_class_name_str
+
+
+class VideoClassificationTrainApp(BaseTrainApp):
+ """
+ This app is used to launch the video tasks (both classification and SSL).
+ Main point of entry for all training, validation and test phases.
+
+ The hydra/Omega conf schema used by the train app is as defined in
+ `VideoClassificationTrainAppConf`
+
+ Args:
+ module (OmegaConf): Hydra/Omega conf object associated with the initialization of the
+ pytorch-lightning module. Supported config schema's include,
+ 1. `pytorchvideo_trainer.module.video_classification.VideoClassificationModuleConf`
+ 2. `pytorchvideo_trainer.module.simclr.SimCLRModuleConf`
+ 3. `pytorchvideo_trainer.module.byol.BYOLModuleConf`
+ 4. `pytorchvideo_trainer.module.moco_v2.MOCOV2ModuleConf`
+ and more. Example definitions of the config can be found in
+ `pytorchvideo_trainer/conf/module`
+ trainer (OmegaConf): Hydra/Omega conf object associated with the initialization of the
+ pytorch-lightning Trainer object. Supported config schema can be found in
+ `github.com/facebookresearch/recipes/blob/main/torchrecipes/core/conf/__init__.py`
+ datamodule (OmegaConf): Hydra/Omega conf object associated with the initialization of
+ the pytorch-lightning DataModule object. Supported config schema can be found at,
+ `pytorchvideo_trainer.datamodule.datamodule.VideoClassificationDataModuleConf`
+ logger (OmegaConf): Hydra/Omega conf object associated with the initialization of the
+ pytorch-lightning TensorBoard logger object. Example config can be found at,
+ `pytorchvideo_trainer/conf/logger`
+ callbacks (List[OmegaConf]): Hydra/Omega conf object associated with the initialization
+ of a series of pytorch-lightning Callbacks that act upon the lightning module. Expects
+ a list or iterable config object wherein each element represents the hydra conf of
+ a single callback, and thus supports loading multiple callbacks at a time. Example
+ configs can be found at `pytorchvideo_trainer/conf/callbacks`
+ submitit_conf (OmegaConf): Hydra/Omega conf to be used by the `submitit_launcher` for
+ launching the train app. Example config file can be found at,
+ `pytorchvideo_trainer/conf/submitit_conf`
+ """
+
+ def __init__(
+ self,
+ module: VideoClassificationModuleConf,
+ trainer: TrainerConf,
+ datamodule: VideoClassificationDataModuleConf,
+ logger: Any, # pyre-ignore[2]
+ callbacks: Optional[Any] = None, # pyre-ignore[2]
+ submitit_conf: Optional[Any] = None, # pyre-ignore[2]
+ ) -> None:
+
+ self.logger_conf: DictConfig = logger
+ self.callbacks_conf: DictConfig = callbacks
+ self.submitit_conf: DictConfig = submitit_conf
+ # This has to happen last because it depends on the values set above.
+ super().__init__(module, trainer, datamodule)
+
+ def get_data_module(self) -> Optional[LightningDataModule]:
+ """
+ Instantiate a LightningDataModule.
+ """
+ return hydra.utils.instantiate(
+ self.datamodule_conf,
+ _recursive_=False,
+ )
+
+ def get_lightning_module(self) -> LightningModule:
+ """
+ Instantiate a LightningModule.
+ """
+ return hydra.utils.instantiate(
+ self.module_conf,
+ _recursive_=False,
+ )
+
+ def get_callbacks(self) -> List[Callback]:
+ """
+ Creates the list of callbacks that is fed into the trainer.
+ Additional callbacks, such as ModelCheckpoint, can be added here too.
+ """
+ callbacks = []
+ if self.trainer_conf.logger:
+ callbacks.extend(
+ [
+ LearningRateMonitor(),
+ ]
+ )
+ if self.callbacks_conf is None:
+ return callbacks
+
+ for cb_conf in self.callbacks_conf.values():
+ callbacks.append(
+ hydra.utils.instantiate(
+ cb_conf,
+ _recursive_=False,
+ ),
+ )
+
+ return callbacks
+
+ def _make_reproducible_conf(self) -> DictConfig:
+ conf = OmegaConf.create()
+ conf._target_ = "pytorchvideo_trainer.train_app.VideoClassificationTrainApp"
+ conf.module = self.module_conf
+ conf.trainer = self.trainer_conf
+ conf.datamodule = self.datamodule_conf
+ conf.logger = self.logger_conf
+ conf.callbacks = self.callbacks_conf
+ conf.submitit_conf = self.submitit_conf
+ return conf
+
+ def get_logger(self) -> TensorBoardLogger:
+ """
+ Creates the logger that is fed into the trainer.
+ Override this method to return a custom logger for the trainer.
+ """
+ logger = hydra.utils.instantiate(
+ self.logger_conf,
+ _recursive_=False,
+ )
+
+ @rank_zero_only
+ def log_params() -> None: # pyre-ignore[53]
+ if os.environ["PTV_TRAINER_ENV"] == "oss":
+ from iopath.common.file_io import g_pathmgr
+
+ conf_to_log = self._make_reproducible_conf()
+ conf_save_path = os.path.join(logger.log_dir, "train_app_conf.yaml")
+ g_pathmgr.mkdirs(logger.log_dir)
+ if not g_pathmgr.exists(conf_save_path):
+ with g_pathmgr.open(conf_save_path, mode="w") as f:
+ f.write(OmegaConf.to_yaml(conf_to_log))
+ else:
+ from stl.lightning.io import filesystem
+
+ fs = filesystem.get_filesystem(logger.log_dir)
+ conf_to_log = self._make_reproducible_conf()
+ fs.makedirs(logger.log_dir, exist_ok=True)
+ conf_save_path = os.path.join(logger.log_dir, "train_app_conf.yaml")
+ if not fs.exists(conf_save_path):
+ with fs.open(conf_save_path, mode="w") as f:
+ f.write(OmegaConf.to_yaml(conf_to_log))
+
+ log_params()
+ return logger
+
+ def test(self) -> TrainOutput: # pyre-ignore[15]
+ """
+ Triggers PyTorch-lightning's testing phase.
+ """
+ trainer, _ = self._get_trainer()
+ trainer.test(self.module, datamodule=self.datamodule)
+ return TrainOutput(tensorboard_log_dir=self.root_dir)
+
+ def predict(self) -> TrainOutput: # pyre-ignore[15]
+ """
+ Triggers PyTorch-lightning's prediction phase.
+ """
+ trainer, _ = self._get_trainer()
+ trainer.predict(self.module, datamodule=self.datamodule)
+ return TrainOutput(tensorboard_log_dir=self.root_dir)
+
+
+def run_app_in_certain_mode(
+ cfg: TrainAppConf, mode: str, env: str = "oss"
+) -> TrainOutput:
+
+ os.environ["PTV_TRAINER_ENV"] = env
+
+ rank_zero_info(OmegaConf.to_yaml(cfg))
+
+ # TODO: Move this to config and replace with `seed_everything`
+ np.random.seed(0)
+ torch.manual_seed(0)
+ app = hydra.utils.instantiate(cfg, _recursive_=False)
+
+ if mode == "train":
+ rank_zero_info("MODE set to train, run train only.")
+ return app.train()
+ elif mode == "test":
+ rank_zero_info("MODE set to test, run test only.")
+ return app.test()
+ elif mode == "predict":
+ rank_zero_info("MODE set to predict, run train and predict.")
+ app.train()
+ return app.predict()
+ else:
+ # By default, run train and test
+ app.train()
+ return app.test()
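+
+
+# Usage sketch (illustrative only): compose the registered config with Hydra and
+# run a quick local training pass. The config name matches the one registered
+# below via `ConfigStore`; the override picks the `cpu` trainer group listed in
+# `project_defaults` and may need to be adapted to your environment.
+#
+# from hydra import compose, initialize_config_module
+#
+# with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+#     cfg = compose(
+#         config_name="video_classification_train_app_conf",
+#         overrides=["trainer=cpu"],
+#     )
+# run_app_in_certain_mode(cfg, mode="train")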
+
+
+project_defaults: List[Union[str, Dict[str, str]]] = [
+ "_self_",
+ {"schema/module": "video_classification_module_conf"},
+ {"schema/module/optim": "optim_conf"},
+ {"schema/datamodule": "ptv_video_classification_data_module_conf"},
+ {"datamodule/dataloader": "kinetics_classification"},
+ {"logger": "ptl"},
+ {"datamodule/transforms": "kinetics_classification_slow"},
+ {"module/model": "slow_r50"},
+ {"module/loss": "cross_entropy"},
+ {"module/optim": "sgd"},
+ {"module/metrics": "accuracy"},
+ {"schema/trainer": "trainer"},
+ {"trainer": "cpu"},
+]
+
+
+@dataclass
+class VideoClassificationTrainAppConf(TrainAppConf):
+ _target_: str = get_class_name_str(VideoClassificationTrainApp)
+ datamodule: VideoClassificationDataModuleConf = MISSING
+ module: VideoClassificationModuleConf = MISSING
+ trainer: TrainerConf = MISSING
+
+ # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
+ logger: Any = MISSING
+
+ # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
+ callbacks: Optional[Any] = None
+
+ # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
+ defaults: List[Any] = field(default_factory=lambda: project_defaults)
+
+ # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
+ submitit_conf: Optional[Any] = None
+
+
+cs = ConfigStore()
+cs.store(
+ name="video_classification_train_app_conf",
+ node=VideoClassificationTrainAppConf,
+)
+
+
+@hydra.main(config_path="conf", config_name=None)
+# pyre-ignore[2]
+def submitit_launcher(cfg) -> None:
+
+ print("###################### Train App Config ####################")
+ print(OmegaConf.to_yaml(cfg))
+ print("############################################################")
+
+ submitit_conf = cfg.get("submitit_conf", None)
+ logger_conf = cfg.get("logger", None)
+ assert submitit_conf is not None, "Missing submitit config"
+
+ if logger_conf is not None:
+ assert (
+ logger_conf.save_dir is not None
+ ), "set save_dir in logger conf to a valid path"
+ submitit_dir = os.path.join(logger_conf.save_dir, logger_conf.name)
+ else:
+ assert submitit_conf.log_save_dir is not None
+ submitit_dir = submitit_conf.log_save_dir
+
+ submitit_dir = os.path.join(submitit_dir, "submitit_logs")
+ executor = submitit.AutoExecutor(folder=submitit_dir)
+ job_kwargs = {
+ "slurm_time": submitit_conf.time,
+ "name": cfg.logger.name if logger_conf is not None else submitit_conf.name,
+ "slurm_partition": submitit_conf.partition,
+ "gpus_per_node": cfg.trainer.gpus,
+ "tasks_per_node": cfg.trainer.gpus, # one task per GPU
+ "cpus_per_task": submitit_conf.cpus_per_task,
+ "nodes": cfg.trainer.num_nodes,
+ }
+ if submitit_conf.get("mem", None) is not None:
+ job_kwargs["slurm_mem"] = submitit_conf.mem
+ if submitit_conf.get("constraints", None) is not None:
+ job_kwargs["constraints"] = submitit_conf.constraints
+
+ executor.update_parameters(**job_kwargs)
+ job = executor.submit(run_app_in_certain_mode, cfg, submitit_conf.mode)
+ print("Submitit Job ID:", job.job_id)
+
+
+if __name__ == "__main__":
+ submitit_launcher()
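+
+# Launch sketch (hypothetical command line, for illustration only; valid group
+# names and values depend on the config files under `pytorchvideo_trainer/conf`
+# and on your cluster setup):
+#
+#   python train_app.py \
+#       --config-name=video_classification_train_app_conf \
+#       trainer.gpus=8 trainer.num_nodes=1 \
+#       submitit_conf.mode=train submitit_conf.partition=<partition_name>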
diff --git a/code/pytorchvideo/pytorchvideo_trainer/setup.py b/code/pytorchvideo/pytorchvideo_trainer/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb365b399fe5392c9a4a3bdaa5ecac1f7ee4a51c
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/setup.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+from setuptools import find_packages, setup
+
+
+setup(
+ name="pytorchvideo_trainer",
+ version="0.0.1",
+ license="Apache 2.0",
+ author="Facebook AI",
+ url="https://github.com/facebookresearch/pytorchvideo",
+ description="PyTorch-Lightning trainer powering PyTorchVideo models.",
+ python_requires=">=3.8",
+ install_requires=[
+ "submitit",
+ "pytorchvideo>=0.1.5",
+ ],
+ extras_require={
+ "test": ["coverage", "pytest", "opencv-python"],
+ "dev": [
+ "opencv-python",
+ "black==20.8b1",
+ "sphinx",
+ "isort==4.3.21",
+ "flake8==3.8.1",
+ "flake8-bugbear",
+ "flake8-comprehensions",
+ "pre-commit",
+ "nbconvert",
+ "bs4",
+ "autoflake==1.4",
+ ],
+ "opencv-python": [
+ "opencv-python",
+ ],
+ },
+ packages=find_packages(exclude=("scripts", "tests")),
+)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/__init__.py b/code/pytorchvideo/pytorchvideo_trainer/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/test_conf_datamodule.py b/code/pytorchvideo/pytorchvideo_trainer/tests/test_conf_datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..153be1ca8d92db386b4a53f537d557a77c70d220
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/test_conf_datamodule.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+from hydra.experimental import compose, initialize_config_module
+from hydra.utils import instantiate # @manual
+from pytorchvideo_trainer.datamodule.datamodule import PyTorchVideoDataModule
+
+
+class TestKineticsDataModuleConf(unittest.TestCase):
+ def test_init_with_hydra(self) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ test_conf = compose(
+ config_name="video_classification_train_app_conf",
+ overrides=[
+ "datamodule/dataloader=kinetics_classification",
+ "datamodule/transforms=kinetics_classification_slow",
+ ],
+ )
+ print(test_conf)
+ kinetics_data_module = instantiate(
+ test_conf.datamodule,
+ _recursive_=False,
+ )
+ self.assertIsInstance(kinetics_data_module, PyTorchVideoDataModule)
+ self.assertIsNotNone(kinetics_data_module.transforms["train"])
+ self.assertIsNotNone(kinetics_data_module.transforms["val"])
+ self.assertIsNotNone(kinetics_data_module.transforms["test"])
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/test_conf_module.py b/code/pytorchvideo/pytorchvideo_trainer/tests/test_conf_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b5d7937e03e500bd99fb5dc47917b4787b88327
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/test_conf_module.py
@@ -0,0 +1,56 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import hydra
+from hydra.experimental import compose, initialize_config_module
+from pytorchvideo_trainer.module.byol import BYOLModule
+from pytorchvideo_trainer.module.moco_v2 import MOCOV2Module
+from pytorchvideo_trainer.module.simclr import SimCLRModule
+from pytorchvideo_trainer.module.video_classification import VideoClassificationModule
+
+
+class TestVideoClassificationModuleConf(unittest.TestCase):
+ def test_init_with_hydra(self) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ test_conf = compose(
+ config_name="video_classification_train_app_conf",
+ overrides=["module/model=slow_r50"],
+ )
+ test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False)
+ self.assertIsInstance(test_module, VideoClassificationModule)
+ self.assertIsNotNone(test_module.model)
+
+
+class TestVideoSimCLRModuleConf(unittest.TestCase):
+ def test_init_with_hydra(self) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ test_conf = compose(
+ config_name="simclr_train_app_conf",
+ )
+ test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False)
+ self.assertIsInstance(test_module, SimCLRModule)
+ self.assertIsNotNone(test_module.model)
+
+
+class TestVideoBYOLModuleConf(unittest.TestCase):
+ def test_init_with_hydra(self) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ test_conf = compose(
+ config_name="byol_train_app_conf",
+ )
+ test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False)
+ self.assertIsInstance(test_module, BYOLModule)
+ self.assertIsNotNone(test_module.model)
+
+
+class TestVideoMOCOV2ModuleConf(unittest.TestCase):
+ def test_init_with_hydra(self) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ test_conf = compose(
+ config_name="moco_v2_train_app_conf",
+ # overrides=["module/model=resnet"],
+ )
+ test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False)
+ self.assertIsInstance(test_module, MOCOV2Module)
+ self.assertIsNotNone(test_module.model)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_byol.py b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_byol.py
new file mode 100644
index 0000000000000000000000000000000000000000..b879ef1d49b9144754ace2dc34cafa41959052eb
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_byol.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+# pyre-strict
+from torchrecipes.core.base_train_app import BaseTrainApp
+from util import (
+ BaseTrainAppTestCase,
+ create_small_kinetics_dataset,
+ run_locally,
+ tempdir,
+)
+
+
+class TestBYOLTrainApp(BaseTrainAppTestCase):
+ def get_train_app(
+ self,
+ root_dir: str,
+ fast_dev_run: bool = True,
+ logger: bool = False,
+ ) -> BaseTrainApp:
+ create_small_kinetics_dataset(root_dir)
+ overrides = [
+ f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
+ f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
+ "datamodule.dataloader.train.num_workers=0",
+ "datamodule.dataloader.val.num_workers=0",
+ "datamodule.dataloader.test.num_workers=0",
+ "module.knn_memory.length=50",
+ "module.knn_memory.knn_k=2",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ "trainer.logger=false",
+ ]
+ app = self.create_app_from_hydra(
+ config_module="pytorchvideo_trainer.conf",
+ config_name="byol_train_app_conf",
+ overrides=overrides,
+ )
+ trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger}
+ self.mock_trainer_params(app, trainer_overrides)
+ return app
+
+ @run_locally
+ @tempdir
+ def test_byol_app_train_test_30_views(self, root_dir: str) -> None:
+ train_app = self.get_train_app(
+ root_dir=root_dir, fast_dev_run=False, logger=False
+ )
+ output = train_app.train()
+ self.assertIsNotNone(output)
+ output = train_app.test()
+ self.assertIsNotNone(output)
+
+ video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None)
+ num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10)
+ num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3)
+ self.assertIsNotNone(video_clips_cnts)
+ for _, sample_cnts in video_clips_cnts.items():
+ self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_moco_v2.py b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_moco_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..adfbfed52ee149e799b4d4587c8243a356b3fd90
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_moco_v2.py
@@ -0,0 +1,64 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+# pyre-strict
+from torchrecipes.core.base_train_app import BaseTrainApp
+from util import (
+ BaseTrainAppTestCase,
+ create_small_kinetics_dataset,
+ run_locally,
+ tempdir,
+)
+
+
+class TestMOCOV2TrainApp(BaseTrainAppTestCase):
+ def get_train_app(
+ self,
+ root_dir: str,
+ fast_dev_run: bool = True,
+ logger: bool = False,
+ ) -> BaseTrainApp:
+ create_small_kinetics_dataset(root_dir)
+ overrides = [
+ f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
+ f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
+ "datamodule.dataloader.train.num_workers=0",
+ "datamodule.dataloader.val.num_workers=0",
+ "datamodule.dataloader.test.num_workers=0",
+ "module.knn_memory.length=50",
+ "module.knn_memory.knn_k=2",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ "trainer.logger=false",
+ ]
+
+ app = self.create_app_from_hydra(
+ config_module="pytorchvideo_trainer.conf",
+ config_name="moco_v2_train_app_conf",
+ overrides=overrides,
+ )
+ trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger}
+ self.mock_trainer_params(app, trainer_overrides)
+ return app
+
+ @run_locally
+ @tempdir
+ def test_moco_v2_app_train_test_30_views(self, root_dir: str) -> None:
+ train_app = self.get_train_app(
+ root_dir=root_dir, fast_dev_run=False, logger=False
+ )
+ output = train_app.train()
+ self.assertIsNotNone(output)
+ output = train_app.test()
+ self.assertIsNotNone(output)
+
+ video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None)
+ num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10)
+ num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3)
+ self.assertIsNotNone(video_clips_cnts)
+ for _, sample_cnts in video_clips_cnts.items():
+ self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_module_all.py b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_module_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..95f8c1a571b722e6b6f8612c34d1bb9d6232617a
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_module_all.py
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+from typing import Any
+
+import hydra
+from hydra import compose, initialize_config_module
+from hydra.utils import instantiate # @manual
+from omegaconf import OmegaConf
+from pytorch_lightning import Trainer
+from pytorchvideo_trainer.datamodule.datamodule import VideoClassificationDataModuleConf
+from pytorchvideo_trainer.train_app import VideoClassificationTrainAppConf
+from util import create_small_kinetics_dataset, run_locally, tempdir
+
+
+class TestMain(unittest.TestCase):
+ # pyre-fixme[3]: Return annotation cannot be `Any`.
+ def get_datamodule(self, cfg: VideoClassificationDataModuleConf) -> Any:
+ test_data_module = instantiate(
+ cfg,
+ _recursive_=False,
+ )
+ return test_data_module
+
+ def train(self, cfg: VideoClassificationTrainAppConf) -> None:
+ print(OmegaConf.to_yaml(cfg))
+ test_module = hydra.utils.instantiate(cfg.module, _recursive_=False)
+ test_data_module = self.get_datamodule(cfg.datamodule)
+ # pyre-fixme[6]: Expected `SupportsKeysAndGetItem[Variable[_KT],
+ # Variable[_VT]]` for 1st param but got `TrainerConf`.
+ trainer_params = dict(cfg.trainer)
+ trainer_params["logger"] = True
+ trainer_params["checkpoint_callback"] = False
+ trainer_params["fast_dev_run"] = True
+ pl_trainer = Trainer(**trainer_params)
+ pl_trainer.fit(model=test_module, datamodule=test_data_module)
+
+ @run_locally
+ @tempdir
+ def test_train_video_model(self, root_dir: str) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ create_small_kinetics_dataset(root_dir)
+ # Config is relative to a module
+ cfg = compose(
+ config_name="video_classification_train_app_conf",
+ overrides=[
+ f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
+ f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
+ "datamodule.dataloader.train.num_workers=0",
+ "datamodule.dataloader.val.num_workers=0",
+ "datamodule.dataloader.test.num_workers=0",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ "+module/lr_scheduler=cosine_with_warmup",
+ "trainer.logger=true",
+ ],
+ )
+ self.assertEqual(cfg.trainer.max_epochs, 1)
+
+ self.train(cfg)
+
+ @run_locally
+ @tempdir
+ def test_train_video_model_simclr(self, root_dir: str) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ create_small_kinetics_dataset(root_dir)
+ # Config is relative to a module
+ cfg = compose(
+ config_name="simclr_train_app_conf",
+ overrides=[
+ f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
+ f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
+ "datamodule.dataloader.train.num_workers=0",
+ "datamodule.dataloader.val.num_workers=0",
+ "datamodule.dataloader.test.num_workers=0",
+ "module.knn_memory.length=50",
+ "module.knn_memory.knn_k=2",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ "trainer.logger=true",
+ ],
+ )
+ self.assertEqual(cfg.trainer.max_epochs, 1)
+
+ self.train(cfg)
+
+ @run_locally
+ @tempdir
+ def test_train_video_model_byol(self, root_dir: str) -> None:
+ with initialize_config_module(config_module="pytorchvideo_trainer.conf"):
+ create_small_kinetics_dataset(root_dir)
+ # Config is relative to a module
+ cfg = compose(
+ config_name="byol_train_app_conf",
+ overrides=[
+ f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
+ f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
+ "datamodule.dataloader.train.num_workers=0",
+ "datamodule.dataloader.val.num_workers=0",
+ "datamodule.dataloader.test.num_workers=0",
+ "module.knn_memory.length=50",
+ "module.knn_memory.knn_k=2",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ "trainer.logger=true",
+ ],
+ )
+ self.assertEqual(cfg.trainer.max_epochs, 1)
+
+ self.train(cfg)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_simclr.py b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac47880aabd3940e30453573a2c4481a2b2ec7d
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_simclr.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+# pyre-strict
+from torchrecipes.core.base_train_app import BaseTrainApp
+from util import (
+ BaseTrainAppTestCase,
+ create_small_kinetics_dataset,
+ run_locally,
+ tempdir,
+)
+
+
+class TestSimCLRTrainApp(BaseTrainAppTestCase):
+ def get_train_app(
+ self,
+ root_dir: str,
+ fast_dev_run: bool = True,
+ logger: bool = False,
+ ) -> BaseTrainApp:
+ create_small_kinetics_dataset(root_dir)
+ overrides = [
+ f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
+ f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
+ "datamodule.dataloader.train.num_workers=0",
+ "datamodule.dataloader.val.num_workers=0",
+ "datamodule.dataloader.test.num_workers=0",
+ "module.knn_memory.length=50",
+ "module.knn_memory.knn_k=2",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ "trainer.logger=false",
+ ]
+ app = self.create_app_from_hydra(
+ config_module="pytorchvideo_trainer.conf",
+ config_name="simclr_train_app_conf",
+ overrides=overrides,
+ )
+ trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger}
+ self.mock_trainer_params(app, trainer_overrides)
+ return app
+
+ @run_locally
+ @tempdir
+ def test_simclr_app_train_test_30_views(self, root_dir: str) -> None:
+ train_app = self.get_train_app(
+ root_dir=root_dir, fast_dev_run=False, logger=False
+ )
+ output = train_app.train()
+ self.assertIsNotNone(output)
+ output = train_app.test()
+ self.assertIsNotNone(output)
+
+ video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None)
+ num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10)
+ num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3)
+ self.assertIsNotNone(video_clips_cnts)
+ for _, sample_cnts in video_clips_cnts.items():
+ self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_video_classification.py b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_video_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..685c6e1bd90f545e9ccd0c3dc68d220d9dc8d6bd
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/test_task_video_classification.py
@@ -0,0 +1,92 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+# pyre-strict
+from torchrecipes.core.base_train_app import BaseTrainApp
+from util import (
+ BaseTrainAppTestCase,
+ create_small_kinetics_dataset,
+ run_locally,
+ tempdir,
+)
+
+
+class TestVideoClassificationTrainApp(BaseTrainAppTestCase):
+ def get_train_app(
+ self,
+ root_dir: str,
+ precise_bn_num_batches: int = 0,
+ fast_dev_run: bool = True,
+ logger: bool = False,
+ ) -> BaseTrainApp:
+ create_small_kinetics_dataset(root_dir)
+ overrides = [
+ f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
+ f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
+ f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
+ f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
+ "datamodule.dataloader.train.num_workers=0",
+ "datamodule.dataloader.val.num_workers=0",
+ "datamodule.dataloader.test.num_workers=0",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ "+module/lr_scheduler=cosine_with_warmup",
+ "trainer.logger=false",
+ ]
+ if precise_bn_num_batches > 0:
+ overrides.extend(
+ [
+ "+callbacks=precise_bn",
+ f"callbacks.precise_bn.num_batches={precise_bn_num_batches}",
+ "datamodule.dataloader.train.batch_size=2",
+ "datamodule.dataloader.val.batch_size=2",
+ "datamodule.dataloader.test.batch_size=2",
+ ]
+ )
+ app = self.create_app_from_hydra(
+ config_module="pytorchvideo_trainer.conf",
+ config_name="video_classification_train_app_conf",
+ overrides=overrides,
+ )
+ trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger}
+ self.mock_trainer_params(app, trainer_overrides)
+ return app
+
+ @run_locally
+ @tempdir
+ def test_video_classification_app_train(self, root_dir: str) -> None:
+ train_app = self.get_train_app(root_dir=root_dir, logger=False)
+ output = train_app.train()
+ self.assertIsNotNone(output)
+
+ @run_locally
+ @tempdir
+ def test_video_classification_app_train_with_precise_bn(
+ self, root_dir: str
+ ) -> None:
+ train_app = self.get_train_app(
+ root_dir=root_dir, precise_bn_num_batches=2, logger=False
+ )
+ output = train_app.train()
+ self.assertIsNotNone(output)
+
+ @run_locally
+ @tempdir
+ def test_video_classification_app_test(self, root_dir: str) -> None:
+ train_app = self.get_train_app(root_dir=root_dir)
+ output = train_app.test()
+ self.assertIsNotNone(output)
+
+ @run_locally
+ @tempdir
+ def test_video_classification_app_test_30_views(self, root_dir: str) -> None:
+ train_app = self.get_train_app(root_dir=root_dir, fast_dev_run=False)
+ train_app.test()
+ video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None)
+ num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10)
+ num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3)
+ self.assertIsNotNone(video_clips_cnts)
+ for _, sample_cnts in video_clips_cnts.items():
+ self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts)
diff --git a/code/pytorchvideo/pytorchvideo_trainer/tests/util.py b/code/pytorchvideo/pytorchvideo_trainer/tests/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..8796154e40edddd3bfc23f49ba6bfe5b1bb5f329
--- /dev/null
+++ b/code/pytorchvideo/pytorchvideo_trainer/tests/util.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import csv
+import os
+from functools import wraps
+from tempfile import TemporaryDirectory
+from typing import Any, Callable, Dict, List, Optional
+from unittest import mock
+
+import testslide
+import torch
+import torchvision.io as io
+from hydra import compose, initialize_config_module
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+from torchrecipes.core.base_train_app import BaseTrainApp, TrainOutput
+
+
+def create_small_kinetics_dataset(root_dir: str) -> None:
+ """
+ A test utility function to create a small Kinetics-like dataset.
+
+ Args:
+ root_dir (str): The directory to create the dataset in.
+ Typically, a temporary directory is used.
+ """
+ video_codec = "libx264rgb"
+ options = {"crf": "0"}
+ height: int = 250
+ width: int = 250
+ num_frames = 20
+ fps = 5
+ data = create_dummy_video_frames(num_frames, height, width)
+
+ train_data = [
+ ["a.mp4", "308"],
+ ["b.mp4", "298"],
+ ["c.mp4", "240"],
+ ["d.mp4", "363"],
+ ]
+
+ val_data = [
+ ["a.mp4", "151"],
+ ]
+
+ for i in range(4):
+ io.write_video(
+ os.path.join(root_dir, train_data[i][0]),
+ data,
+ fps=fps,
+ video_codec=video_codec,
+ options=options,
+ )
+
+ train_file = os.path.join(root_dir, "train.csv")
+ write_single_csv_file(train_file, train_data)
+
+ val_file = os.path.join(root_dir, "val.csv")
+ write_single_csv_file(val_file, val_data)
+
+
+# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
+def write_single_csv_file(file_name: str, data: List[Any]) -> None:
+ with open(file_name, "w+", newline="") as csvfile:
+ data_writer = csv.writer(
+ # pyre-fixme[6]: Expected `_Writer` for 1st param but got `TextIOWrapper`.
+ csvfile,
+ delimiter=" ",
+ )
+ for row in data:
+ data_writer.writerow(row)
+
+
+# pyre-fixme[3]
+def create_dummy_video_frames(num_frames: int, height: int, width: int):
+ y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
+ data = []
+ for i in range(num_frames):
+ xc = float(i) / num_frames
+ yc = 1 - float(i) / (2 * num_frames)
+ d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
+ data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
+ return torch.stack(data, 0)
+
+
+def run_locally(func: Callable) -> Callable: # pyre-ignore[24]
+ """A decorator to run unittest locally."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs): # pyre-ignore[2,3]
+ with mock.patch(
+ "torch.distributed.is_available",
+ return_value=False,
+ ):
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def tempdir(func: Callable) -> Callable: # pyre-ignore[24]
+ """A decorator for creating a tempory directory that
+ is cleaned up after function execution."""
+
+ @wraps(func)
+ def wrapper(self, *args, **kwargs): # pyre-ignore[2,3]
+ with TemporaryDirectory() as temp:
+ return func(self, temp, *args, **kwargs)
+
+ return wrapper
+
+
+def get_mock_init_trainer_params(
+ overrides: Optional[Dict[str, Any]] = None,
+) -> Callable[..., Dict[str, Any]]:
+ """
+ Order of trainer_params setting in unit test:
+ - First call original function, which sets params from config
+ - Then override some params to disable logger and checkpoint
+ - Apply any test-specific overrides.
+ """
+
+ def mock_init_trainer_params(
+ original: Callable[..., Dict[str, Any]],
+ ) -> Dict[str, Any]:
+ trainer_params = original()
+
+ trainer_params["logger"] = False
+ trainer_params["enable_checkpointing"] = False
+ trainer_params["fast_dev_run"] = True
+
+ if overrides:
+ trainer_params.update(overrides)
+
+ return trainer_params
+
+ return mock_init_trainer_params
+
+
+class BaseTrainAppTestCase(testslide.TestCase):
+ """All Standard TrainApp unit tests should inherit from this class."""
+
+ def mock_trainer_params(
+ self, app: BaseTrainApp, overrides: Optional[Dict[str, Any]] = None
+ ) -> None:
+ self.mock_callable(
+ app, "_init_trainer_params", allow_private=True
+ ).with_wrapper(get_mock_init_trainer_params(overrides))
+
+ def create_app_from_hydra(
+ self,
+ config_module: str,
+ config_name: str,
+ overrides: Optional[List[str]] = None,
+ ) -> BaseTrainApp:
+ with initialize_config_module(config_module=config_module):
+ cfg = compose(config_name=config_name, overrides=overrides or [])
+ print(OmegaConf.to_yaml(cfg))
+ return instantiate(cfg, _recursive_=False)
+
+ def assert_train_output(self, output: TrainOutput) -> None:
+ self.assertIsNotNone(output)
+ # Ensure logger is set to False in test to avoid dependency on Manifold
+ self.assertIsNone(output.tensorboard_log_dir)
diff --git a/code/pytorchvideo/setup.cfg b/code/pytorchvideo/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..0f73488f2c148502c88d9fecfddb3c96954a3b3c
--- /dev/null
+++ b/code/pytorchvideo/setup.cfg
@@ -0,0 +1,8 @@
+[isort]
+line_length = 88
+multi_line_output = 3
+include_trailing_comma = True
+force_grid_wrap = 0
+default_section = THIRDPARTY
+lines_after_imports = 2
+combine_as_imports = True
diff --git a/code/pytorchvideo/setup.py b/code/pytorchvideo/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..abe7eefccfac7887155d5d094de7a60a5e9fc064
--- /dev/null
+++ b/code/pytorchvideo/setup.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+
+from setuptools import find_packages, setup
+
+
+def get_version():
+ init_py_path = os.path.join(
+ os.path.abspath(os.path.dirname(__file__)), "pytorchvideo", "__init__.py"
+ )
+ init_py = open(init_py_path, "r").readlines()
+ version_line = [
+ lines.strip() for lines in init_py if lines.startswith("__version__")
+ ][0]
+ version = version_line.split("=")[-1].strip().strip("'\"")
+
+ # Used by CI to build nightly packages. Users should never use it.
+ # To build a nightly wheel, run:
+ # BUILD_NIGHTLY=1 python setup.py sdist
+ if os.getenv("BUILD_NIGHTLY", "0") == "1":
+ from datetime import datetime
+
+ date_str = datetime.today().strftime("%Y%m%d")
+ # pip can perform proper comparison for ".post" suffix,
+ # i.e., "1.1.post1234" >= "1.1"
+ version = version + ".post" + date_str
+
+ new_init_py = [l for l in init_py if not l.startswith("__version__")]
+ new_init_py.append('__version__ = "{}"\n'.format(version))
+ with open(init_py_path, "w") as f:
+ f.write("".join(new_init_py))
+
+ return version
+
+
+def get_name():
+ name = "pytorchvideo"
+ if os.getenv("BUILD_NIGHTLY", "0") == "1":
+ name += "-nightly"
+ return name
+
+
+setup(
+ name=get_name(),
+ version=get_version(),
+ license="Apache 2.0",
+ author="Facebook AI",
+ url="https://github.com/facebookresearch/pytorchvideo",
+ description="A video understanding deep learning library.",
+ python_requires=">=3.7",
+ install_requires=[
+ "fvcore",
+ "av",
+ "parameterized",
+ "iopath",
+ "networkx",
+ ],
+ extras_require={
+ "test": ["coverage", "pytest", "opencv-python", "decord"],
+ "dev": [
+ "opencv-python",
+ "decord",
+ "black==20.8b1",
+ "sphinx",
+ "isort==4.3.21",
+ "flake8==3.8.1",
+ "flake8-bugbear",
+ "flake8-comprehensions",
+ "pre-commit",
+ "nbconvert",
+ "bs4",
+ "autoflake==1.4",
+ ],
+ "opencv-python": [
+ "opencv-python",
+ ],
+ },
+ packages=find_packages(exclude=("scripts", "tests")),
+)
diff --git a/code/pytorchvideo/tests/README.md b/code/pytorchvideo/tests/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d15b7c1e273e2e9a51a41c99474084e25dc0899
--- /dev/null
+++ b/code/pytorchvideo/tests/README.md
@@ -0,0 +1,21 @@
+## Unit Tests
+
+
+Before running the tests, please ensure that you have installed the necessary additional test dependencies.
+If not installed, check the [install-README](https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md) on how to do it.
+
+Use the following command to run the tests:
+```
+# From root of the project
+python -m unittest discover -v -s ./tests
+```
+
+To generate the coverage reports, please run the following command:
+```
+# Install coverage with
+pip install coverage
+
+# From root of the project
+coverage run -m unittest discover -v -s tests
+```
+
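+To view the collected coverage afterwards, use coverage.py's standard reporting
+command:
+```
+# Print a per-file coverage summary in the terminal
+coverage report -m
+```
+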
diff --git a/code/pytorchvideo/tests/__init__.py b/code/pytorchvideo/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7f19c6c00a4ac3f2f2bc66f892e44bcbd72612
--- /dev/null
+++ b/code/pytorchvideo/tests/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/code/pytorchvideo/tests/benchmark_accelerator_efficient_blocks.py b/code/pytorchvideo/tests/benchmark_accelerator_efficient_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b348da7eef66f3258b7a3d63f64bd54a9f4ef05
--- /dev/null
+++ b/code/pytorchvideo/tests/benchmark_accelerator_efficient_blocks.py
@@ -0,0 +1,446 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import unittest
+from typing import Callable, Tuple
+
+import torch
+import torch.nn as nn
+from fvcore.common.benchmark import benchmark
+from pytorchvideo.layers.accelerator.mobile_cpu.convolutions import (
+ Conv3d3x3x3DwBnAct,
+ Conv3dPwBnAct,
+)
+from pytorchvideo.models.accelerator.mobile_cpu.residual_blocks import (
+ X3dBottleneckBlock,
+)
+from torch.utils.mobile_optimizer import optimize_for_mobile
+
+
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 11):
+ from torch.ao.quantization import (
+ convert,
+ DeQuantStub,
+ fuse_modules,
+ get_default_qconfig,
+ prepare,
+ QuantStub,
+ # quantize_fx
+ )
+else:
+ from torch.quantization import (
+ convert,
+ DeQuantStub,
+ fuse_modules,
+ get_default_qconfig,
+ prepare,
+ QuantStub,
+ # quantize_fx
+ )
+
+
+class TestBenchmarkEfficientBlocks(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_benchmark_conv3d_pw_bn_relu(self, num_iters: int = 20) -> None:
+ """
+ Benchmark Conv3dPwBnAct with ReLU activation.
+ Note that the efficient block Conv3dPwBnAct is designed for mobile CPU with the
+ qnnpack backend; benchmarking on a server with another backend (e.g., fbgemm)
+ may give different latency results than running on a mobile CPU with qnnpack.
+ Running on an x86-based server CPU with qnnpack may also give different latency
+ than running on a mobile CPU, because qnnpack is optimized for ARM-based
+ mobile CPUs.
+ Args:
+ num_iters (int): number of iterations to perform benchmarking.
+ """
+
+ torch.backends.quantized.engine = "qnnpack"
+ kwargs_list = [
+ {
+ "mode": "original",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "out_channels": 108,
+ "quantize": False,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "out_channels": 108,
+ "quantize": False,
+ },
+ {
+ "mode": "original",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "out_channels": 108,
+ "quantize": True,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "out_channels": 108,
+ "quantize": True,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "out_channels": 108,
+ "quantize": True,
+ "native_conv3d_op_qnnpack": True,
+ },
+ ]
+
+ def _benchmark_conv3d_pw_bn_relu_forward(**kwargs) -> Callable:
+ assert kwargs["mode"] in ("original", "deployable"), (
+ "kwargs['mode'] must be either 'original' or 'deployable',"
+ "but got {}.".format(kwargs["mode"])
+ )
+ input_tensor = torch.randn((kwargs["input_blob_size"]))
+ conv_block = Conv3dPwBnAct(
+ kwargs["in_channels"],
+ kwargs["out_channels"],
+ use_bn=False, # assume BN has already been fused for forward
+ )
+
+ if kwargs["mode"] == "deployable":
+ native_conv3d_op_qnnpack = kwargs.get("native_conv3d_op_qnnpack", False)
+ conv_block.convert(
+ kwargs["input_blob_size"],
+ convert_for_quantize=kwargs["quantize"],
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ conv_block.eval()
+
+ def func_to_benchmark_dummy() -> None:
+ return
+
+ if kwargs["quantize"] is True:
+ if kwargs["mode"] == "original": # manually fuse conv and relu
+ conv_block.kernel = fuse_modules(
+ conv_block.kernel, ["conv", "act.act"]
+ )
+ conv_block = nn.Sequential(
+ QuantStub(),
+ conv_block,
+ DeQuantStub(),
+ )
+
+ conv_block.qconfig = get_default_qconfig("qnnpack")
+ conv_block = prepare(conv_block)
+ try:
+ conv_block = convert(conv_block)
+
+ except Exception as e:
+ logging.info(
+ "benchmark_conv3d_pw_bn_relu: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+
+ return func_to_benchmark_dummy
+ try:
+ traced_model = torch.jit.trace(conv_block, input_tensor, strict=False)
+ except Exception as e:
+ logging.info(
+ "benchmark_conv3d_pw_bn_relu: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+
+ return func_to_benchmark_dummy
+
+ if kwargs["quantize"] is False:
+ traced_model = optimize_for_mobile(traced_model)
+
+ logging.info(f"model arch: {traced_model}")
+
+ def func_to_benchmark() -> None:
+ try:
+ _ = traced_model(input_tensor)
+ except Exception as e:
+ logging.info(
+ "benchmark_conv3d_pw_bn_relu: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+
+ return
+
+ return func_to_benchmark
+
+ benchmark(
+ _benchmark_conv3d_pw_bn_relu_forward,
+ "benchmark_conv3d_pw_bn_relu",
+ kwargs_list,
+ num_iters=num_iters,
+ warmup_iters=2,
+ )
+
+ self.assertTrue(True)
+
+ def test_benchmark_conv3d_3x3x3_dw_bn_relu(self, num_iters: int = 20) -> None:
+ """
+ Benchmark Conv3d3x3x3DwBnAct with ReLU activation.
+ Note that the efficient block Conv3d3x3x3DwBnAct is designed for mobile CPU with
+ the qnnpack backend; benchmarking on a server with another backend (e.g., fbgemm)
+ may give different latency results than running on a mobile CPU.
+ Args:
+ num_iters (int): number of iterations to perform benchmarking.
+ """
+ torch.backends.quantized.engine = "qnnpack"
+ kwargs_list = [
+ {
+ "mode": "original",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "quantize": False,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "quantize": False,
+ },
+ {
+ "mode": "original",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "quantize": True,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "quantize": True,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 40, 40),
+ "in_channels": 48,
+ "quantize": True,
+ "native_conv3d_op_qnnpack": True,
+ },
+ ]
+
+ def _benchmark_conv3d_3x3x3_dw_bn_relu_forward(**kwargs) -> Callable:
+ assert kwargs["mode"] in ("original", "deployable"), (
+ "kwargs['mode'] must be either 'original' or 'deployable',"
+ "but got {}.".format(kwargs["mode"])
+ )
+ input_tensor = torch.randn((kwargs["input_blob_size"]))
+ conv_block = Conv3d3x3x3DwBnAct(
+ kwargs["in_channels"],
+ use_bn=False, # assume BN has already been fused for forward
+ )
+
+ def func_to_benchmark_dummy() -> None:
+ return
+
+ if kwargs["mode"] == "deployable":
+ native_conv3d_op_qnnpack = kwargs.get("native_conv3d_op_qnnpack", False)
+ conv_block.convert(
+ kwargs["input_blob_size"],
+ convert_for_quantize=kwargs["quantize"],
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ conv_block.eval()
+ if kwargs["quantize"] is True:
+ if kwargs["mode"] == "original": # manually fuse conv and relu
+ conv_block.kernel = fuse_modules(
+ conv_block.kernel, ["conv", "act.act"]
+ )
+ conv_block = nn.Sequential(
+ QuantStub(),
+ conv_block,
+ DeQuantStub(),
+ )
+
+ conv_block.qconfig = get_default_qconfig("qnnpack")
+ conv_block = prepare(conv_block)
+ try:
+ conv_block = convert(conv_block)
+ except Exception as e:
+ logging.info(
+ "benchmark_conv3d_3x3x3_dw_bn_relu: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+
+ return func_to_benchmark_dummy
+ try:
+ traced_model = torch.jit.trace(conv_block, input_tensor, strict=False)
+ except Exception as e:
+ logging.info(
+ "benchmark_conv3d_3x3x3_dw_bn_relu: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+
+ return func_to_benchmark_dummy
+ if kwargs["quantize"] is False:
+ traced_model = optimize_for_mobile(traced_model)
+
+ logging.info(f"model arch: {traced_model}")
+
+ def func_to_benchmark() -> None:
+ try:
+ _ = traced_model(input_tensor)
+ except Exception as e:
+ logging.info(
+ "benchmark_conv3d_3x3x3_dw_bn_relu: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+ return
+
+ return func_to_benchmark
+
+ benchmark(
+ _benchmark_conv3d_3x3x3_dw_bn_relu_forward,
+ "benchmark_conv3d_3x3x3_dw_bn_relu",
+ kwargs_list,
+ num_iters=num_iters,
+ warmup_iters=2,
+ )
+
+ self.assertTrue(True)
+
+ def test_benchmark_x3d_bottleneck_block(self, num_iters: int = 20) -> None:
+ """
+ Benchmark X3dBottleneckBlock.
+ Note that the efficient block X3dBottleneckBlock is designed for mobile CPU with
+ the qnnpack backend; benchmarking on a server or laptop may give different
+ latency results than running on a mobile CPU.
+ Args:
+ num_iters (int): number of iterations to perform benchmarking.
+ """
+ torch.backends.quantized.engine = "qnnpack"
+ kwargs_list = [
+ {
+ "mode": "original",
+ "input_blob_size": (1, 48, 4, 20, 20),
+ "in_channels": 48,
+ "mid_channels": 108,
+ "out_channels": 48,
+ "quantize": False,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 20, 20),
+ "in_channels": 48,
+ "mid_channels": 108,
+ "out_channels": 48,
+ "quantize": False,
+ },
+ {
+ "mode": "original",
+ "input_blob_size": (1, 48, 4, 20, 20),
+ "in_channels": 48,
+ "mid_channels": 108,
+ "out_channels": 48,
+ "quantize": True,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 20, 20),
+ "in_channels": 48,
+ "mid_channels": 108,
+ "out_channels": 48,
+ "quantize": True,
+ },
+ {
+ "mode": "deployable",
+ "input_blob_size": (1, 48, 4, 20, 20),
+ "in_channels": 48,
+ "mid_channels": 108,
+ "out_channels": 48,
+ "quantize": True,
+ "native_conv3d_op_qnnpack": True,
+ },
+ ]
+
+ def _benchmark_x3d_bottleneck_forward(**kwargs) -> Callable:
+ assert kwargs["mode"] in ("original", "deployable"), (
+ "kwargs['mode'] must be either 'original' or 'deployable',"
+ "but got {}.".format(kwargs["mode"])
+ )
+ input_tensor = torch.randn((kwargs["input_blob_size"]))
+ conv_block = X3dBottleneckBlock(
+ kwargs["in_channels"],
+ kwargs["mid_channels"],
+ kwargs["out_channels"],
+ use_bn=(False, False, False), # Assume BN has been fused for forward
+ )
+
+ if kwargs["mode"] == "deployable":
+ native_conv3d_op_qnnpack = kwargs.get("native_conv3d_op_qnnpack", False)
+ conv_block.convert(
+ kwargs["input_blob_size"],
+ convert_for_quantize=kwargs["quantize"],
+ native_conv3d_op_qnnpack=native_conv3d_op_qnnpack,
+ )
+ conv_block.eval()
+
+ def func_to_benchmark_dummy() -> None:
+ return
+
+ if kwargs["quantize"] is True:
+ conv_block = nn.Sequential(
+ QuantStub(),
+ conv_block,
+ DeQuantStub(),
+ )
+
+ conv_block.qconfig = get_default_qconfig("qnnpack")
+ conv_block = prepare(conv_block)
+ try:
+ conv_block = convert(conv_block)
+ traced_model = torch.jit.trace(
+ conv_block, input_tensor, strict=False
+ )
+ except Exception as e:
+ logging.info(
+ "benchmark_x3d_bottleneck_forward: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+
+ return func_to_benchmark_dummy
+
+ try:
+ traced_model = torch.jit.trace(conv_block, input_tensor, strict=False)
+ except Exception as e:
+ logging.info(
+ "benchmark_x3d_bottleneck_forward: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+
+ return func_to_benchmark_dummy
+
+ if kwargs["quantize"] is False:
+ traced_model = optimize_for_mobile(traced_model)
+
+ logging.info(f"model arch: {traced_model}")
+
+ def func_to_benchmark() -> None:
+ try:
+ _ = traced_model(input_tensor)
+ except Exception as e:
+ logging.info(
+ "benchmark_x3d_bottleneck_forward: "
+ "catch exception '{}' with kwargs of {}".format(e, kwargs)
+ )
+ return
+
+ return func_to_benchmark
+
+ benchmark(
+ _benchmark_x3d_bottleneck_forward,
+ "benchmark_x3d_bottleneck_forward",
+ kwargs_list,
+ num_iters=num_iters,
+ warmup_iters=2,
+ )
+
+ self.assertTrue(True)
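+
+
+# Running these benchmarks (illustrative; run from the repository root with the
+# test dependencies installed):
+#
+#   python -m unittest tests.benchmark_accelerator_efficient_blocks -v
+#
+# fvcore's `benchmark` utility prints latency statistics for each kwargs entry.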
diff --git a/code/pytorchvideo/tests/benchmark_transforms.py b/code/pytorchvideo/tests/benchmark_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..f94aa493a96039a29a29341591540f7cb0f90df0
--- /dev/null
+++ b/code/pytorchvideo/tests/benchmark_transforms.py
@@ -0,0 +1,82 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+from typing import Callable
+
+import torch
+from fvcore.common.benchmark import benchmark
+from pytorchvideo.data.utils import thwc_to_cthw
+from pytorchvideo.transforms.functional import short_side_scale
+from utils import create_dummy_video_frames
+
+
+class TestBenchmarkTransforms(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_benchmark_short_side_scale_pytorch(self, num_iters: int = 10) -> None:
+ """
+ Benchmark scale operation with pytorch backend.
+ Args:
+ num_iters (int): number of iterations to perform benchmarking.
+ """
+ kwargs_list = [
+ {"temporal_size": 8, "ori_spatial_size": (128, 128), "dst_short_size": 112},
+ {
+ "temporal_size": 16,
+ "ori_spatial_size": (128, 128),
+ "dst_short_size": 112,
+ },
+ {
+ "temporal_size": 32,
+ "ori_spatial_size": (128, 128),
+ "dst_short_size": 112,
+ },
+ {"temporal_size": 8, "ori_spatial_size": (256, 256), "dst_short_size": 224},
+ {
+ "temporal_size": 16,
+ "ori_spatial_size": (256, 256),
+ "dst_short_size": 224,
+ },
+ {
+ "temporal_size": 32,
+ "ori_spatial_size": (256, 256),
+ "dst_short_size": 224,
+ },
+ {"temporal_size": 8, "ori_spatial_size": (320, 320), "dst_short_size": 224},
+ {
+ "temporal_size": 16,
+ "ori_spatial_size": (320, 320),
+ "dst_short_size": 224,
+ },
+ {
+ "temporal_size": 32,
+ "ori_spatial_size": (320, 320),
+ "dst_short_size": 224,
+ },
+ ]
+
+ def _init_benchmark_short_side_scale(**kwargs) -> Callable:
+ x = thwc_to_cthw(
+ create_dummy_video_frames(
+ kwargs["temporal_size"],
+ kwargs["ori_spatial_size"][0],
+ kwargs["ori_spatial_size"][1],
+ )
+ ).to(dtype=torch.float32)
+
+ def func_to_benchmark() -> None:
+ _ = short_side_scale(x, kwargs["dst_short_size"])
+ return
+
+ return func_to_benchmark
+
+ benchmark(
+ _init_benchmark_short_side_scale,
+ "benchmark_short_side_scale_pytorch",
+ kwargs_list,
+ num_iters=num_iters,
+ warmup_iters=2,
+ )
+ self.assertTrue(True)
diff --git a/code/pytorchvideo/tests/test_accelerator_deployment_mobile_cpu_model_conversion.py b/code/pytorchvideo/tests/test_accelerator_deployment_mobile_cpu_model_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d46a932e3e82a503de6c890a06a849053420858
--- /dev/null
+++ b/code/pytorchvideo/tests/test_accelerator_deployment_mobile_cpu_model_conversion.py
@@ -0,0 +1,91 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import unittest
+from collections import OrderedDict
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (
+ convert_to_deployable_form,
+)
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+from pytorchvideo.models.accelerator.mobile_cpu.residual_blocks import (
+ X3dBottleneckBlock,
+)
+
+
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 11):
+ from torch.ao.quantization import DeQuantStub, QuantStub
+else:
+ from torch.quantization import DeQuantStub, QuantStub
+
+
+class TestDeploymentModelConversion(unittest.TestCase):
+ def test_X3dBottleneckBlock_model_conversion(self):
+ # Input tensor
+ input_blob_size = (1, 3, 4, 6, 6)
+ input_tensor = torch.randn(input_blob_size)
+
+ # Helper class to emulate a mix of efficient and non-efficient blocks
+ class _quant_wrapper(nn.Module):
+ # A common config where the user model is wrapped by QuantStub/DeQuantStub
+ def __init__(self):
+ super().__init__()
+ self.quant = QuantStub() # Non efficient block
+ # X3dBottleneckBlock is an efficient block composed of multiple efficient blocks
+ self.model = X3dBottleneckBlock(
+ 3,
+ 12,
+ 3,
+ )
+ self.dequant = DeQuantStub() # Non efficient block
+
+ def forward(self, x):
+ x = self.quant(x)
+ x = self.model(x)
+ x = self.dequant(x)
+ return x
+
+ x3d_block_model_ref = _quant_wrapper()
+
+ # Get ref output
+ x3d_block_model_ref.eval()
+ out_ref = x3d_block_model_ref(input_tensor)
+ # Convert into deployment mode
+ x3d_block_model_converted = convert_to_deployable_form(
+ x3d_block_model_ref, input_tensor
+ )
+ out = x3d_block_model_converted(input_tensor)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out_ref - out)))
+ rel_err = torch.abs((out_ref - out) / out_ref)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_X3dBottleneckBlock_model_conversion: "
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
+ # Check all sub-modules converted
+ for iter_module in x3d_block_model_converted.modules():
+ if isinstance(iter_module, EfficientBlockBase) and (
+ hasattr(iter_module, "convert_flag")
+ ):
+ self.assertTrue(iter_module.convert_flag)
+ # Check all hooks removed
+ for iter_module in x3d_block_model_ref.modules():
+ assert iter_module._forward_hooks == OrderedDict(), (
+ f"{iter_module} in x3d_block_model_ref has non-empty _forward_hooks "
+ f"{iter_module._forward_hooks}"
+ )
+ for iter_module in x3d_block_model_converted.modules():
+ assert iter_module._forward_hooks == OrderedDict(), (
+ f"{iter_module} in x3d_block_model_converted has non-empty _forward_hooks "
+ f"{iter_module._forward_hooks}"
+ )
diff --git a/code/pytorchvideo/tests/test_accelerator_deployment_model_transmuter.py b/code/pytorchvideo/tests/test_accelerator_deployment_model_transmuter.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fa824c067080996c66b7f20fab72d4d08c2c0bf
--- /dev/null
+++ b/code/pytorchvideo/tests/test_accelerator_deployment_model_transmuter.py
@@ -0,0 +1,87 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import unittest
+from copy import deepcopy
+
+# Registers mobile_cpu transmuter functions
+import pytorchvideo.accelerator.deployment.mobile_cpu.transmuter # noqa: F401
+import torch
+import torch.nn as nn
+from pytorchvideo.accelerator.deployment.common.model_transmuter import transmute_model
+from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (
+ convert_to_deployable_form,
+)
+from pytorchvideo.accelerator.efficient_blocks.efficient_block_base import (
+ EfficientBlockBase,
+)
+
+
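+# transmute_model rewrites supported layers of a plain torch.nn model in place
+# into their mobile_cpu efficient-block equivalents. The test below expects the
+# pointwise and depthwise convolutions to become EfficientBlockBase subclasses
+# while the model output stays numerically equivalent after conversion.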
+class TestModelTransmuter(unittest.TestCase):
+ def test_mobile_cpu_transmuter(self):
+ # Input tensor
+ input_blob_size = (1, 3, 2, 6, 6)
+ input_tensor = torch.randn(input_blob_size)
+
+ # Helper class to emulate user input model
+ class _residual_block(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.stem0 = nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
+ self.stem1 = nn.Conv3d(3, 3, kernel_size=(5, 1, 1), padding=(2, 0, 0))
+ self.pw = nn.Conv3d(3, 6, kernel_size=1)
+ self.relu = nn.ReLU()
+ self.dw = nn.Conv3d(6, 6, kernel_size=3, padding=1, groups=6)
+ self.relu1 = nn.ReLU()
+ self.pwl = nn.Conv3d(6, 3, kernel_size=1)
+ self.relu2 = nn.ReLU()
+
+ def forward(self, x):
+ out = self.stem0(x)
+ out = self.stem1(out)
+ out = self.pw(out)
+ out = self.relu(out)
+ out = self.dw(out)
+ out = self.relu1(out)
+ out = self.pwl(out)
+ return self.relu2(out + x)
+
+ user_model_ref = _residual_block()
+
+ user_model_ref.eval()
+ out_ref = user_model_ref(input_tensor)
+
+ user_model_efficient = deepcopy(user_model_ref)
+ transmute_model(
+ user_model_efficient,
+ target_device="mobile_cpu",
+ )
+ logging.info(f"after convert_model {user_model_efficient}")
+ # Check whether the blocks have been replaced by efficient blocks
+ assert isinstance(user_model_efficient.pw, EfficientBlockBase), (
+ f"user_model_efficient.pw {user_model_efficient.pw.__class__.__name__} "
+ "is not converted!"
+ )
+ assert isinstance(user_model_efficient.dw, EfficientBlockBase), (
+ f"user_model_efficient.dw {user_model_efficient.dw.__class__.__name__} "
+ "is not converted!"
+ )
+ assert isinstance(user_model_efficient.pwl, EfficientBlockBase), (
+ f"user_model_efficient.pwl {user_model_efficient.pwl.__class__.__name__} "
+ "is not converted!"
+ )
+ user_model_efficient_converted = convert_to_deployable_form(
+ user_model_efficient, input_tensor
+ )
+ out = user_model_efficient_converted(input_tensor)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out_ref - out)))
+ rel_err = torch.abs((out_ref - out) / out_ref)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_mobile_cpu_transmuter: "
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
diff --git a/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_activation_attention.py b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_activation_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f6e4a6341d28df5406cd131c8090cb7654cddcf
--- /dev/null
+++ b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_activation_attention.py
@@ -0,0 +1,55 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import unittest
+from copy import deepcopy
+
+import torch
+from pytorchvideo.layers.accelerator.mobile_cpu.activation_functions import (
+ supported_act_functions,
+)
+from pytorchvideo.layers.accelerator.mobile_cpu.attention import SqueezeExcitation
+
+
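+# Every efficient block exposes a convert() method that specializes it for
+# deployment; these tests check that conversion leaves the numerical output of
+# the activation functions and the squeeze-excitation layer unchanged.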
+class TestActivationAttentionEquivalency(unittest.TestCase):
+ def test_activation_equivalency(self):
+ # Input tensor
+ input_tensor = torch.randn(1, 3, 4, 6, 6)
+ for iter_activation_name in supported_act_functions:
+ act_func_ref = supported_act_functions[iter_activation_name]()
+ act_func_convert = deepcopy(act_func_ref)
+ act_func_convert.convert()
+ # Get output of both activations
+ out0 = act_func_ref(input_tensor)
+ out1 = act_func_convert(input_tensor)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out0 - out1)))
+
+ logging.info(
+ f"test_activation_equivalency: {iter_activation_name} max_err {max_err}"
+ )
+ self.assertTrue(max_err < 1e-3)
+
+ def test_squeeze_excite_equivalency(self):
+ # Input tensor
+ input_tensor = torch.randn(1, 16, 4, 6, 6)
+ # Instantiate ref and convert se modules.
+ se_ref = SqueezeExcitation(16, num_channels_reduced=2, is_3d=True)
+ se_ref.eval()
+ se_convert = deepcopy(se_ref)
+ se_convert.convert((1, 16, 4, 6, 6))
+ # Get output of both modules
+ out0 = se_ref(input_tensor)
+ out1 = se_convert(input_tensor)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out0 - out1)))
+ rel_err = torch.abs((out0 - out1) / out0)
+ max_rel_err = float(torch.max(rel_err))
+
+ logging.info(
+ (
+ "test_squeeze_excite_equivalency: "
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
diff --git a/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_conv3d.py b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_conv3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ebbcb9332de5e5428e806afef77efe49c20ac89
--- /dev/null
+++ b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_conv3d.py
@@ -0,0 +1,144 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import unittest
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.accelerator.mobile_cpu.convolutions import (
+ Conv3d3x1x1BnAct,
+ Conv3d3x3x3DwBnAct,
+ Conv3d5x1x1BnAct,
+ Conv3dPwBnAct,
+)
+
+
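+# Conv-BN-activation efficient blocks are converted with the exact input blob
+# size they will see at deployment time, and the converted output must match
+# the reference within 1e-3. A minimal usage sketch mirroring the tests below:
+#   blk = Conv3dPwBnAct(3, 12)
+#   blk.convert((1, 3, 4, 6, 6))  # deployment-time input size
+#   out = blk(torch.randn(1, 3, 4, 6, 6))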
+class TestConv3dBlockEquivalency(unittest.TestCase):
+ def test_Conv3dPwBnAct_equivalency(self):
+ # Input tensor
+ input_tensor = torch.randn(1, 3, 4, 6, 6)
+ # A conv block
+ l0 = Conv3dPwBnAct(3, 12)
+ l1 = Conv3dPwBnAct(
+ 12, 3, bias=True, activation="identity"
+ ) # Skip relu to avoid NaN for rel error
+ seq0 = nn.Sequential(l0, l1)
+ seq0.eval()
+ out0 = seq0(input_tensor)
+ # Replicate the conv block
+ l0_1 = deepcopy(l0)
+ l1_1 = deepcopy(l1)
+ # Convert into deployment mode
+ l0_1.convert((1, 3, 4, 6, 6)) # Input tensor size is (1,3,4,6,6)
+ l1_1.convert((1, 12, 4, 6, 6)) # Input tensor size is (1,12,4,6,6)
+ seq1 = nn.Sequential(l0_1, l1_1)
+ out1 = seq1(input_tensor)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out0 - out1)))
+ rel_err = torch.abs((out0 - out1) / out0)
+ max_rel_err = float(torch.max(rel_err))
+
+ logging.info(
+ (
+ "test_Conv3dPwBnAct_equivalency: "
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
+
+ def test_Conv3d3x3x3DwBnAct_equivalency(self):
+ # Input tensor
+ input_tensor = torch.randn(1, 3, 4, 6, 6)
+ # A conv block
+ l0 = Conv3dPwBnAct(3, 12)
+ l1 = Conv3d3x3x3DwBnAct(12)
+ l2 = Conv3dPwBnAct(
+ 12, 3, bias=True, activation="identity"
+ ) # Skip relu to avoid NaN for relative error
+ seq0 = nn.Sequential(l0, l1, l2)
+ seq0.eval()
+ out0 = seq0(input_tensor)
+ # Replicate the conv block
+ l0_1 = deepcopy(l0)
+ l1_1 = deepcopy(l1)
+ l2_1 = deepcopy(l2)
+ # Convert into deployment mode
+ l0_1.convert((1, 3, 4, 6, 6)) # Input tensor size is (1,3,4,6,6)
+ l1_1.convert((1, 12, 4, 6, 6)) # Input tensor size is (1,12,4,6,6)
+ l2_1.convert((1, 12, 4, 6, 6)) # Input tensor size is (1,12,4,6,6)
+ seq1 = nn.Sequential(l0_1, l1_1, l2_1)
+ out1 = seq1(input_tensor)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out0 - out1)))
+ rel_err = torch.abs((out0 - out1) / out0)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_Conv3d3x3x3DwBnAct_equivalency: "
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
+
+ def test_Conv3d3x1x1BnAct_equivalency(self):
+ for input_temporal in range(3):
+ input_size = (1, 3, input_temporal + 1, 6, 6)
+ # Input tensor
+ input_tensor = torch.randn(input_size)
+ # A conv block
+ l0 = Conv3d3x1x1BnAct(3, 6)
+ l0.eval()
+ out0 = l0(input_tensor)
+ # Replicate the conv block
+ l0_1 = deepcopy(l0)
+ # Convert into deployment mode
+ l0_1.convert(input_size) # Convert for the current input tensor size
+ out1 = l0_1(input_tensor)
+ # Check output size
+ assert (
+ out0.size() == out1.size()
+ ), f"Sizes of out0 {out0.size()} and out1 {out1.size()} are different."
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out0 - out1)))
+ rel_err = torch.abs((out0 - out1) / out0)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_Conv3d3x1x1BnAct_equivalency: "
+ f"input tensor size: {input_size}"
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
+
+ def test_Conv3d5x1x1BnAct_equivalency(self):
+ for input_temporal in range(5):
+ input_size = (1, 3, input_temporal + 1, 6, 6)
+ # Input tensor
+ input_tensor = torch.randn(input_size)
+ # A conv block
+ l0 = Conv3d5x1x1BnAct(3, 6)
+ l0.eval()
+ out0 = l0(input_tensor)
+ # Replicate the conv block
+ l0_1 = deepcopy(l0)
+ # Convert into deployment mode
+ l0_1.convert(input_size) # Convert for the current input tensor size
+ out1 = l0_1(input_tensor)
+ # Check output size
+ assert (
+ out0.size() == out1.size()
+ ), f"Sizes of out0 {out0.size()} and out1 {out1.size()} are different."
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out0 - out1)))
+ rel_err = torch.abs((out0 - out1) / out0)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_Conv3d5x1x1BnAct_equivalency: "
+ f"input tensor size: {input_size}"
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
diff --git a/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_head_layer.py b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_head_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..79fd2f3d7f0ff4bd3b7620f0752c95a4ead2b63d
--- /dev/null
+++ b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_head_layer.py
@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import unittest
+from copy import deepcopy
+
+import torch
+from pytorchvideo.layers.accelerator.mobile_cpu.fully_connected import FullyConnected
+from pytorchvideo.layers.accelerator.mobile_cpu.pool import (
+ AdaptiveAvgPool2d,
+ AdaptiveAvgPool2dOutSize1,
+ AdaptiveAvgPool3d,
+ AdaptiveAvgPool3dOutSize1,
+)
+
+
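+# The head layer under test is pool + linear: pooled features are permuted to
+# put the channel dimension last before FullyConnected, and the specialized
+# AdaptiveAvgPool*OutSize1 blocks must match the generic AdaptiveAvgPool*(1).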
+class TestHeadLayerEquivalency(unittest.TestCase):
+ def test_head_layer_equivalency(self):
+ for input_dim in (4, 5): # 4 for BCHW, 5 for BCTHW
+ input_tensor_size = (1, 3, 4, 6, 6) if input_dim == 5 else (1, 3, 6, 6)
+ input_tensor = torch.randn(input_tensor_size)
+ # Build up common head layer: pool + linear
+ if input_dim == 5:
+ pool_efficient_block_ref = AdaptiveAvgPool3d(1)
+ pool_efficient_block_1 = AdaptiveAvgPool3d(1)
+ pool_efficient_block_2 = AdaptiveAvgPool3dOutSize1()
+
+ else:
+ pool_efficient_block_ref = AdaptiveAvgPool2d(1)
+ pool_efficient_block_1 = AdaptiveAvgPool2d(1)
+ pool_efficient_block_2 = AdaptiveAvgPool2dOutSize1()
+ pool_efficient_block_1.convert()
+ pool_efficient_block_2.convert(input_tensor_size)
+ linear_ref = FullyConnected(3, 8)
+ linear_1 = deepcopy(linear_ref)
+ linear_1.convert()
+
+ ref_out = pool_efficient_block_ref(input_tensor)
+ if input_dim == 5:
+ ref_out = ref_out.permute((0, 2, 3, 4, 1))
+ else:
+ ref_out = ref_out.permute((0, 2, 3, 1))
+ ref_out = linear_ref(ref_out)
+
+ head_out_1 = pool_efficient_block_1(input_tensor)
+ if input_dim == 5:
+ head_out_1 = head_out_1.permute((0, 2, 3, 4, 1))
+ else:
+ head_out_1 = head_out_1.permute((0, 2, 3, 1))
+ head_out_1 = linear_1(head_out_1)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(ref_out - head_out_1)))
+ rel_err = torch.abs((ref_out - head_out_1) / ref_out)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_head_layer_equivalency: AdaptiveAvgPool + Linear"
+ f"input tensor size: {input_tensor_size}"
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
+
+ head_out_2 = pool_efficient_block_2(input_tensor)
+ if input_dim == 5:
+ head_out_2 = head_out_2.permute((0, 2, 3, 4, 1))
+ else:
+ head_out_2 = head_out_2.permute((0, 2, 3, 1))
+ head_out_2 = linear_1(head_out_2)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(ref_out - head_out_2)))
+ rel_err = torch.abs((ref_out - head_out_2) / ref_out)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_head_layer_equivalency: AdaptiveAvgPoolOutSize1 + Linear"
+ f"input tensor size: {input_tensor_size}"
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
diff --git a/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_residual_block.py b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_residual_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1a6eca3efbb0653c58135293d9c8001ec2b64db
--- /dev/null
+++ b/code/pytorchvideo/tests/test_accelerator_efficient_blocks_mobile_cpu_residual_block.py
@@ -0,0 +1,56 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import logging
+import unittest
+from copy import deepcopy
+
+import torch
+from pytorchvideo.models.accelerator.mobile_cpu.residual_blocks import (
+ X3dBottleneckBlock,
+)
+
+
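+# Sweeps the residual/stride/SE-ratio/activation configurations of
+# X3dBottleneckBlock and checks that convert() preserves the output within
+# 1e-3 for every combination.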
+class TestX3dBottleneckBlockEquivalency(unittest.TestCase):
+ def test_X3dBottleneckBlock_equivalency(self):
+ # Input tensor
+ input_blob_size = (1, 3, 4, 6, 6)
+ input_tensor = torch.randn(input_blob_size)
+ for use_residual in (True, False):
+ for spatial_stride in (1, 2):
+ for se_ratio in (0, 0.5):
+ for act_func_0 in ("relu", "swish", "hswish", "identity"):
+ for act_func_1 in ("relu", "swish", "hswish", "identity"):
+ for act_func_2 in ("relu", "swish", "hswish", "identity"):
+ act_func_tuple = (act_func_0, act_func_1, act_func_2)
+ # X3dBottleneckBlock
+ x3d_block_ref = X3dBottleneckBlock(
+ 3,
+ 16,
+ 3,
+ use_residual=use_residual,
+ spatial_stride=spatial_stride,
+ se_ratio=se_ratio,
+ act_functions=act_func_tuple,
+ )
+ x3d_block = deepcopy(x3d_block_ref)
+ # Get ref output
+ x3d_block_ref.eval()
+ out_ref = x3d_block_ref(input_tensor)
+ # Convert into deployment mode
+ x3d_block.convert(input_blob_size)
+ out = x3d_block(input_tensor)
+ # Check arithmetic equivalency
+ max_err = float(torch.max(torch.abs(out_ref - out)))
+ rel_err = torch.abs((out_ref - out) / out_ref)
+ max_rel_err = float(torch.max(rel_err))
+ logging.info(
+ (
+ "test_X3dBottleneckBlock_equivalency: "
+ f"current setting: use_residual {use_residual}, "
+ f"spatial_stride {spatial_stride}, "
+ f"se_ratio {se_ratio}, "
+ f"act_func_tuple {act_func_tuple}, "
+ f"max_err {max_err}, max_rel_err {max_rel_err}"
+ )
+ )
+ self.assertTrue(max_err < 1e-3)
diff --git a/code/pytorchvideo/tests/test_accelerator_models_efficient_x3d.py b/code/pytorchvideo/tests/test_accelerator_models_efficient_x3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..dffca3d5e7a7b56fe0eafe2b9a8654b744f7bdf5
--- /dev/null
+++ b/code/pytorchvideo/tests/test_accelerator_models_efficient_x3d.py
@@ -0,0 +1,102 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import os
+import unittest
+
+import torch
+from pytorchvideo.models.accelerator.mobile_cpu.efficient_x3d import create_x3d
+
+
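+# create_x3d builds the EfficientX3d classifier; only the "XS" variant is
+# exercised in the loop below, the output is expected to be (batch_size, 400)
+# logits, and inputs without 3 channels are expected to raise a RuntimeError.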
+class TestEfficientX3d(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_x3d(self):
+ """
+ To test different versions, set the (expansion, clip_length, crop_size) to:
+ X3D-XS: ("XS", 4, 160)
+ X3D-S: ("S", 13, 160)
+ X3D-M: ("M", 16, 224)
+ X3D-L: ("L", 16, 312)
+ """
+ for (expansion, input_clip_length, input_crop_size,) in [
+ ("XS", 4, 160),
+ ]:
+ model = create_x3d(expansion=expansion)
+
+ # Test forwarding.
+ for tensor in TestEfficientX3d._get_inputs(
+ input_clip_length, input_crop_size
+ ):
+ if tensor.shape[1] != 3:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ output_shape = out.shape
+ output_shape_gt = (tensor.shape[0], 400)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_load_hubconf(self):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ for (input_clip_length, input_crop_size, model_name) in [
+ (4, 160, "efficient_x3d_xs"),
+ (13, 160, "efficient_x3d_s"),
+ ]:
+ model = torch.hub.load(
+ repo_or_dir=path,
+ source="local",
+ model=model_name,
+ pretrained=False,
+ )
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ for tensor in TestEfficientX3d._get_inputs(
+ input_clip_length, input_crop_size
+ ):
+ if tensor.shape[1] != 3:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ output_shape = out.shape
+ output_shape_gt = (tensor.shape[0], 400)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ @staticmethod
+ def _get_inputs(clip_length: int = 4, crop_size: int = 160) -> torch.Tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yields:
+ (torch.Tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ (1, 3, clip_length, crop_size, crop_size),
+ (2, 3, clip_length, crop_size, crop_size),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_data_ava_dataset.py b/code/pytorchvideo/tests/test_data_ava_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..685339e5f2b1d732af7dc9d04234cfe10c908e15
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_ava_dataset.py
@@ -0,0 +1,240 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import contextlib
+import pathlib
+import random
+import tempfile
+import unittest
+
+import torch
+from pytorchvideo.data import Ava
+from pytorchvideo.data.clip_sampling import make_clip_sampler
+from utils import temp_frame_video
+
+
+AVA_FPS = 30
+
+
+@contextlib.contextmanager
+def temp_ava_dataset_2_videos():
+ frame_names = [f"{str(i)}.png" for i in range(90)]
+ # Create csv containing 2 test frame videos.
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as frames_file:
+ frames_file.write("original_vido_id video_id frame_id path labels\n".encode())
+ # Frame video 1
+ with temp_frame_video(frame_names) as (frame_1_video_dir, data_1):
+ for i, frame_name in enumerate(frame_names):
+ original_video_id_1 = str(frame_1_video_dir)
+ video_id = "1"
+ frame_id = str(i)
+ path = pathlib.Path(frame_1_video_dir) / frame_name
+ label = "0"
+ frames_file.write(
+ f"{original_video_id_1} {video_id} {frame_id} {path} {label}\n".encode()
+ )
+
+ # Frame video 2
+ with temp_frame_video(frame_names, height=5, width=5) as (
+ frame_2_video_dir,
+ data_2,
+ ):
+ for i, frame_name in enumerate(frame_names):
+ original_video_id_2 = str(frame_2_video_dir)
+ video_id = "2"
+ frame_id = str(i)
+ path = pathlib.Path(frame_2_video_dir) / frame_name
+ label = "1"
+ frames_file.write(
+ f"{original_video_id_2} {video_id} {frame_id} {path} {label}\n".encode()
+ )
+
+ frames_file.close()
+ yield frames_file.name, data_1, data_2, original_video_id_1, original_video_id_2
+
+
+def get_random_bbox():
+ bb_list = [round(random.random(), 3) for _ in range(4)]
+ converted_list = [str(element) for element in bb_list]
+ return bb_list, ",".join(converted_list)
+
+
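+# Label rows written below follow the AVA csv layout:
+#   video_id, keyframe_sec, x1, y1, x2, y2, action_id, iou
+# 902 is the first annotated keyframe second in AVA; for these 90-frame test
+# videos it corresponds to the clip around t = 2 s (frames 45-75 at 30 fps).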
+class TestAvaDataset(unittest.TestCase):
+ def test_multiple_videos(self):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as data_file:
+ with temp_ava_dataset_2_videos() as (
+ frame_paths_file,
+ video_1,
+ video_2,
+ video_1_name,
+ video_2_name,
+ ):
+ # add bounding boxes
+ # video 1
+ bb_1_a, bb_1_a_string = get_random_bbox()
+ action_1_a, iou_1_a = 1, 0.85
+ bb_1_b, bb_1_b_string = get_random_bbox()
+ action_1_b, iou_1_b = 2, 0.4
+
+ data_file.write(
+ (
+ f"{video_1_name},902,{bb_1_a_string},"
+ + f"{str(action_1_a)},{str(iou_1_a)}\n"
+ ).encode()
+ )
+ data_file.write(
+ (
+ f"{video_1_name},902,{bb_1_b_string},"
+ + f"{str(action_1_b)},{str(iou_1_b)}\n"
+ ).encode()
+ )
+ # video 2
+ bb_2_a, bb_2_a_string = get_random_bbox()
+ action_2_a, iou_2_a = 3, 0.95
+ bb_2_b, bb_2_b_string = get_random_bbox()
+ action_2_b, iou_2_b = 4, 0.9
+
+ data_file.write(
+ (
+ f"{video_2_name},902,{bb_2_a_string},"
+ + f"{str(action_2_a)},{str(iou_2_a)}\n"
+ ).encode()
+ )
+ data_file.write(
+ (
+ f"{video_2_name},902,{bb_2_b_string},"
+ + f"{str(action_2_b)},{str(iou_2_b)}\n"
+ ).encode()
+ )
+
+ data_file.close()
+
+ dataset = Ava(
+ frame_paths_file=frame_paths_file,
+ frame_labels_file=data_file.name,
+ clip_sampler=make_clip_sampler("random", 1.0),
+ )
+
+ # All videos are in CTHW format at 30 fps.
+ # The clip is sampled at time step = 2 secs into the video.
+ sample_1 = next(dataset)
+ self.assertTrue(sample_1["video"].equal(video_1[:, 45:75, :, :]))
+ self.assertTrue(
+ torch.tensor(sample_1["boxes"]).equal(
+ torch.tensor([bb_1_a, bb_1_b])
+ )
+ )
+ self.assertTrue(
+ torch.tensor(sample_1["labels"]).equal(
+ torch.tensor([[action_1_a], [action_1_b]])
+ )
+ )
+ sample_2 = next(dataset)
+ self.assertTrue(sample_2["video"].equal(video_2[:, 45:75, :, :]))
+ self.assertTrue(
+ torch.tensor(sample_2["boxes"]).equal(
+ torch.tensor([bb_2_a, bb_2_b])
+ )
+ )
+ self.assertTrue(
+ torch.tensor(sample_2["labels"]).equal(
+ torch.tensor([[action_2_a], [action_2_b]])
+ )
+ )
+
+ def test_multiple_videos_with_label_map(self):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as label_map_file:
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as data_file:
+ with temp_ava_dataset_2_videos() as (
+ frame_paths_file,
+ video_1,
+ video_2,
+ video_1_name,
+ video_2_name,
+ ):
+ # Create labelmap file
+ label_map = """item {
+ name: "bend/bow (at the waist)"
+ id: 1
+}
+item {
+ name: "crouch/kneel"
+ id: 3
+}
+item {
+ name: "dance"
+ id: 4
+}"""
+ label_map_file.write(label_map.encode())
+ label_map_file.close()
+
+ # add bounding boxes
+ # video 1
+ bb_1_a, bb_1_a_string = get_random_bbox()
+ action_1_a, iou_1_a = 1, 0.85
+ bb_1_b, bb_1_b_string = get_random_bbox()
+ action_1_b, iou_1_b = 2, 0.4
+
+ data_file.write(
+ (
+ f"{video_1_name},902,{bb_1_a_string},"
+ + f"{str(action_1_a)},{str(iou_1_a)}\n"
+ ).encode()
+ )
+ data_file.write(
+ (
+ f"{video_1_name},902,{bb_1_b_string},"
+ + f"{str(action_1_b)},{str(iou_1_b)}\n"
+ ).encode()
+ )
+ # video 2
+ bb_2_a, bb_2_a_string = get_random_bbox()
+ action_2_a, iou_2_a = 3, 0.95
+ bb_2_b, bb_2_b_string = get_random_bbox()
+ action_2_b, iou_2_b = 4, 0.9
+
+ data_file.write(
+ (
+ f"{video_2_name},902,{bb_2_a_string},"
+ + f"{str(action_2_a)},{str(iou_2_a)}\n"
+ ).encode()
+ )
+ data_file.write(
+ (
+ f"{video_2_name},902,{bb_2_b_string},"
+ + f"{str(action_2_b)},{str(iou_2_b)}\n"
+ ).encode()
+ )
+
+ data_file.close()
+
+ dataset = Ava(
+ frame_paths_file=frame_paths_file,
+ frame_labels_file=data_file.name,
+ clip_sampler=make_clip_sampler("random", 1.0),
+ label_map_file=label_map_file.name,
+ )
+
+ # All videos are in CTHW format at 30 fps.
+ # The clip is sampled at time step = 2 secs into the video.
+ sample_1 = next(dataset)
+ self.assertTrue(sample_1["video"].equal(video_1[:, 45:75, :, :]))
+ self.assertTrue(
+ torch.tensor(sample_1["boxes"]).equal(torch.tensor([bb_1_a]))
+ )
+ self.assertTrue(
+ torch.tensor(sample_1["labels"]).equal(
+ torch.tensor([[action_1_a]])
+ )
+ )
+ sample_2 = next(dataset)
+ self.assertTrue(sample_2["video"].equal(video_2[:, 45:75, :, :]))
+ self.assertTrue(
+ torch.tensor(sample_2["boxes"]).equal(
+ torch.tensor([bb_2_a, bb_2_b])
+ )
+ )
+ self.assertTrue(
+ torch.tensor(sample_2["labels"]).equal(
+ torch.tensor([[action_2_a], [action_2_b]])
+ )
+ )
diff --git a/code/pytorchvideo/tests/test_data_charades_dataset.py b/code/pytorchvideo/tests/test_data_charades_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2727ce9b9fd1b31e1d5c057a0d0f1d9258f7f4e
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_charades_dataset.py
@@ -0,0 +1,102 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import contextlib
+import pathlib
+import tempfile
+import unittest
+
+from pytorchvideo.data import Charades
+from pytorchvideo.data.clip_sampling import make_clip_sampler
+from torch.utils.data import SequentialSampler
+from utils import temp_frame_video, temp_frame_video_dataset
+
+
+@contextlib.contextmanager
+def temp_charades_dataset():
+
+ # Create csv containing 2 test frame videos.
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as f:
+ f.write("original_vido_id video_id frame_id path labels\n".encode())
+
+ with temp_frame_video_dataset() as (video_frames, _):
+ for (
+ original_video_id,
+ video_id,
+ frame_id,
+ path,
+ label,
+ _,
+ ) in video_frames:
+ f.write(
+ f"{original_video_id} {video_id} {frame_id} {path} {label}\n".encode()
+ )
+
+ f.close()
+ yield f.name, video_frames[0][-1], video_frames[1][-1]
+
+
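+# Charades provides per-frame labels, so each sampled clip carries one label
+# list per frame: a 3-frame clip of a video labelled "0" yields [[0], [0], [0]],
+# and comma-separated labels such as "0,100" become multi-label lists.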
+class TestCharadesDataset(unittest.TestCase):
+ def test_single_clip_per_video_works(self):
+ with temp_charades_dataset() as (filename, video_1, video_2):
+ clip_sampler = make_clip_sampler(
+ "uniform", 0.1 # Total duration of 3 frames at 30fps is 0.1 seconds.
+ )
+ dataset = Charades(
+ filename, clip_sampler=clip_sampler, video_sampler=SequentialSampler
+ )
+ expected = [([[0], [0], [0]], video_1), ([[1], [1], [1]], video_2)]
+ for sample, expected_sample in zip(dataset, expected):
+ self.assertEqual(sample["label"], expected_sample[0])
+ self.assertTrue(sample["video"].equal(expected_sample[1]))
+
+ def test_multiple_clips_per_video_works(self):
+ with temp_charades_dataset() as (filename, video_1, video_2):
+ clip_sampler = make_clip_sampler(
+ "uniform", 0.033 # Expects each clip to have 1 frame each.
+ )
+ dataset = Charades(
+ filename, clip_sampler=clip_sampler, video_sampler=SequentialSampler
+ )
+
+ expected = [
+ ([[0]], video_1[:, 0:1]),
+ ([[0]], video_1[:, 1:2]),
+ ([[0]], video_1[:, 2:3]),
+ ([[1]], video_2[:, 0:1]),
+ ([[1]], video_2[:, 1:2]),
+ ([[1]], video_2[:, 2:3]),
+ ]
+ for sample, expected_sample in zip(dataset, expected):
+ self.assertEqual(sample["label"], expected_sample[0])
+ self.assertTrue(sample["video"].equal(expected_sample[1]))
+
+ def test_multiple_labels_per_frame(self):
+ frame_names = [f"{str(i)}.png" for i in range(3)]
+
+ # Create csv containing a test frame videos.
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as f:
+ f.write("original_vido_id video_id frame_id path labels\n".encode())
+ with temp_frame_video(frame_names) as (frame_1_video_dir, data_1):
+ for i, frame_name in enumerate(frame_names):
+ original_video_id = str(frame_1_video_dir)
+ video_id = "1"
+ frame_id = str(i)
+ path = pathlib.Path(frame_1_video_dir) / frame_name
+ label = "0,100"
+ f.write(
+ f"{original_video_id} {video_id} {frame_id} {path} {label}\n".encode()
+ )
+
+ f.close()
+
+ clip_sampler = make_clip_sampler(
+ "random",
+ 0.1, # Total duration of 3 frames at 30fps is 0.1 seconds.
+ )
+ dataset = Charades(
+ f.name, clip_sampler=clip_sampler, video_sampler=SequentialSampler
+ )
+
+ sample = next(dataset)
+ self.assertEqual(sample["label"], [[0, 100], [0, 100], [0, 100]])
+ self.assertTrue(sample["video"].equal(data_1))
diff --git a/code/pytorchvideo/tests/test_data_dataset_manifest_utils.py b/code/pytorchvideo/tests/test_data_dataset_manifest_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4019e18a283d34e9fb81beb126697432f2be1e8d
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_dataset_manifest_utils.py
@@ -0,0 +1,148 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+import unittest.mock
+
+from pytorchvideo.data.dataset_manifest_utils import (
+ EncodedVideoInfo,
+ VideoDataset,
+ VideoFrameInfo,
+ VideoInfo,
+)
+from utils import get_flat_video_frames, MOCK_VIDEO_IDS, MOCK_VIDEO_INFOS
+
+
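+# The manifest dataclasses coerce the all-string csv columns into typed fields,
+# and the VideoDataset helpers rebuild zero-padded frame file paths and drop
+# videos whose frame manifests do not match their VideoInfo metadata.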
+class TestDatasetManifestUtils(unittest.TestCase):
+ def test_VideoFrameInfo(self):
+ video_frame_info = VideoFrameInfo(
+ # Values are passed as strings because the underlying
+ # annotation files store these columns as strings
+ **{
+ "video_id": "P01_012",
+ "location": "c:/",
+ "frame_file_stem": "P01_012_",
+ "frame_string_length": "20",
+ "min_frame_number": "0",
+ "max_frame_number": "22",
+ "file_extension": "png",
+ }
+ )
+ self.assertEqual(video_frame_info.video_id, "P01_012")
+ self.assertEqual(video_frame_info.location, "c:/")
+ self.assertEqual(video_frame_info.frame_file_stem, "P01_012_")
+ self.assertEqual(video_frame_info.frame_string_length, 20)
+ self.assertEqual(video_frame_info.min_frame_number, 0)
+ self.assertEqual(video_frame_info.max_frame_number, 22)
+ self.assertEqual(video_frame_info.file_extension, "png")
+
+ def test_EncodedVideoInfo(self):
+ encoded_video_info = EncodedVideoInfo(
+ # Values are passed as strings because the underlying epic-kitchen
+ # annotation files store these columns as strings
+ **{"video_id": "P01_12", "file_path": "c:/P01_12.mp4"}
+ )
+ self.assertEqual(encoded_video_info.video_id, "P01_12")
+ self.assertEqual(encoded_video_info.file_path, "c:/P01_12.mp4")
+
+ def test_VideoInfo(self):
+ video_info = VideoInfo(
+ # Values are passed as strings because the underlying epic-kitchen
+ # annotation files store these columns as strings
+ **{
+ "video_id": "P01_01",
+ "resolution": "1000x200",
+ "duration": "123.45",
+ "fps": "59.9",
+ }
+ )
+ self.assertEqual(video_info.video_id, "P01_01")
+ self.assertEqual(video_info.resolution, "1000x200")
+ self.assertEqual(video_info.duration, 123.45)
+ self.assertEqual(video_info.fps, 59.9)
+
+ def test_frame_number_to_filepath(self):
+ file_names_vid4 = VideoDataset._frame_number_to_filepaths(
+ MOCK_VIDEO_IDS[3],
+ get_flat_video_frames("testdirectory", "jpg"),
+ MOCK_VIDEO_INFOS,
+ )
+ file_path = file_names_vid4[100]
+ self.assertEqual(
+ file_path, f"testdirectory/{MOCK_VIDEO_IDS[3]}/frame_0000000101.jpg"
+ )
+ with self.assertRaises(IndexError):
+ file_path = file_names_vid4[10000]
+ file_path = file_names_vid4[-1]
+ self.assertEqual(
+ file_path, f"testdirectory/{MOCK_VIDEO_IDS[3]}/frame_0000001530.jpg"
+ )
+
+ file_names_vid2 = VideoDataset._frame_number_to_filepaths(
+ MOCK_VIDEO_IDS[1],
+ get_flat_video_frames("testdirectory2", "png"),
+ MOCK_VIDEO_INFOS,
+ )
+ file_path = file_names_vid2[0]
+ self.assertEqual(
+ file_path, f"testdirectory2/{MOCK_VIDEO_IDS[1]}/frame_0000000002.png"
+ )
+ file_path = file_names_vid2[2999]
+ self.assertEqual(
+ file_path, f"testdirectory2/{MOCK_VIDEO_IDS[1]}/frame_0000003001.png"
+ )
+ with self.assertRaises(IndexError):
+ file_path = file_names_vid2[3000]
+
+ def test_remove_video_info_missing_or_incomplete_videos(self):
+ video_infos_a = MOCK_VIDEO_INFOS.copy()
+ video_frames_a = get_flat_video_frames("testdirectory2", "jpg")
+ video_frames_a_copy = video_frames_a.copy()
+
+ # No-Op
+ VideoDataset._remove_video_info_missing_or_incomplete_videos(
+ video_frames_a, video_infos_a
+ )
+
+ self.assertEqual(len(video_infos_a), len(MOCK_VIDEO_INFOS))
+ for video_id in video_infos_a:
+ self.assertEqual(video_infos_a[video_id], MOCK_VIDEO_INFOS[video_id])
+
+ self.assertEqual(len(video_frames_a), len(video_frames_a_copy))
+ for video_id in video_frames_a:
+ self.assertEqual(video_frames_a[video_id], video_frames_a_copy[video_id])
+
+ video_infos_b = MOCK_VIDEO_INFOS.copy()
+ video_frames_b = video_frames_a_copy.copy()
+
+ # Unmatched video info, should be removed
+ video_infos_b["P07_001"] = VideoInfo(
+ video_id="P07_001", resolution="720x1280", duration=17.001, fps=30
+ )
+
+ # Unmatched video frame entry, should be removed
+ video_frames_b["P07_002"]: VideoFrameInfo(
+ min_frame_number=1, max_frame_number=1530, frame_string_length=8
+ )
+
+ # Video info that defines approximately 6000 frames with 600 present from frame manifest
+ # Should be dropped
+ video_frames_b["P08_001"]: VideoFrameInfo(
+ min_frame_number=1, max_frame_number=600, frame_string_length=8
+ )
+
+ video_infos_b["P08_001"] = VideoInfo(
+ video_id="P08_001", resolution="720x1280", duration=100, fps=60
+ )
+
+ VideoDataset._remove_video_info_missing_or_incomplete_videos(
+ video_frames_b, video_infos_b
+ )
+
+ # All newly added fields should be removed
+ self.assertEqual(len(video_infos_b), len(MOCK_VIDEO_INFOS))
+ for video_id in video_infos_b:
+ self.assertEqual(video_infos_b[video_id], MOCK_VIDEO_INFOS[video_id])
+
+ self.assertEqual(len(video_frames_b), len(video_frames_a_copy))
+ for video_id in video_frames_b:
+ self.assertEqual(video_frames_b[video_id], video_frames_a_copy[video_id])
diff --git a/code/pytorchvideo/tests/test_data_domsev_dataset.py b/code/pytorchvideo/tests/test_data_domsev_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..12ddaff808d222c837b5affafeda1cf867a36fe8
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_domsev_dataset.py
@@ -0,0 +1,230 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import tempfile
+import unittest
+import unittest.mock
+from contextlib import ExitStack
+from pathlib import Path
+
+import torch
+from parameterized import parameterized
+from pytorchvideo.data.dataset_manifest_utils import VideoClipInfo, VideoDatasetType
+from pytorchvideo.data.domsev import (
+ _get_overlap_for_time_range_pair,
+ _seconds_to_frame_index,
+ DomsevVideoDataset,
+ LabelData,
+)
+from pytorchvideo.data.utils import save_dataclass_objs_to_headered_csv
+from utils import (
+ get_encoded_video_infos,
+ get_flat_video_frames,
+ MOCK_VIDEO_IDS,
+ MOCK_VIDEO_INFOS,
+)
+
+
+class TestDomsevVideoDataset(unittest.TestCase):
+
+ # video_id: str
+ # start_time: float # Start time of the label, in seconds
+ # stop_time: float # Stop time of the label, in seconds
+ # start_frame: int # 0-indexed ID of the start frame (inclusive)
+ # stop_frame: int # 0-indexed ID of the stop frame (inclusive)
+ # label_id: int
+ # label_name: str
+ LABELS_DATA = {
+ MOCK_VIDEO_IDS[0]: [
+ LabelData(
+ MOCK_VIDEO_IDS[0],
+ 0.0,
+ 6.0,
+ 1,
+ 181,
+ 1,
+ "walking",
+ ),
+ LabelData(
+ MOCK_VIDEO_IDS[0],
+ 6.0333333,
+ 10.0,
+ 182,
+ 301,
+ 2,
+ "running",
+ ),
+ LabelData(
+ MOCK_VIDEO_IDS[0],
+ 10.033333,
+ 20.0,
+ 302,
+ 601,
+ 0,
+ "none",
+ ),
+ ],
+ MOCK_VIDEO_IDS[1]: [
+ LabelData(
+ MOCK_VIDEO_IDS[1],
+ 3.0,
+ 5.0,
+ 181,
+ 301,
+ 7,
+ "cooking",
+ ),
+ ],
+ MOCK_VIDEO_IDS[2]: [
+ LabelData(
+ MOCK_VIDEO_IDS[2],
+ 100.0,
+ 200.0,
+ 3001,
+ 6001,
+ 9,
+ "observing",
+ ),
+ ],
+ MOCK_VIDEO_IDS[3]: [
+ LabelData(
+ MOCK_VIDEO_IDS[3],
+ 10.0,
+ 20.0,
+ 901,
+ 1801,
+ 5,
+ "driving",
+ ),
+ ],
+ }
+
+ def setUp(self):
+ pass
+
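+ # _seconds_to_frame_index is expected to behave like floor(seconds * fps),
+ # shifted by +1 when frame numbering is 1-indexed rather than 0-indexed.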
+ def test_seconds_to_frame_index(self):
+ self.assertEqual(_seconds_to_frame_index(10.56, 1, zero_indexed=True), 10)
+ self.assertEqual(_seconds_to_frame_index(10.56, 1, zero_indexed=False), 11)
+
+ self.assertEqual(_seconds_to_frame_index(9.99, 1, zero_indexed=True), 9)
+ self.assertEqual(_seconds_to_frame_index(9.99, 1, zero_indexed=False), 10)
+
+ self.assertEqual(_seconds_to_frame_index(1.01, 10, zero_indexed=True), 10)
+ self.assertEqual(_seconds_to_frame_index(1.01, 10, zero_indexed=False), 11)
+
+ def test_get_overlap_for_time_range_pair(self):
+ self.assertEqual(_get_overlap_for_time_range_pair(0, 1, 0.1, 0.2), (0.1, 0.2))
+ self.assertEqual(_get_overlap_for_time_range_pair(0.1, 0.2, 0, 1), (0.1, 0.2))
+ self.assertEqual(_get_overlap_for_time_range_pair(0, 1, 0.9, 1.1), (0.9, 1.0))
+ self.assertEqual(_get_overlap_for_time_range_pair(0, 0.2, 0.1, 1), (0.1, 0.2))
+
+ @parameterized.expand([(VideoDatasetType.Frame,), (VideoDatasetType.EncodedVideo,)])
+ def test__len__(self, dataset_type):
+ with tempfile.TemporaryDirectory(prefix=f"{TestDomsevVideoDataset}") as tempdir:
+ tempdir = Path(tempdir)
+
+ video_info_file = tempdir / "test_video_info.csv"
+ save_dataclass_objs_to_headered_csv(
+ list(MOCK_VIDEO_INFOS.values()), video_info_file
+ )
+ label_file = tempdir / "activity_video_info.csv"
+ labels = []
+ for label_list in self.LABELS_DATA.values():
+ for label_data in label_list:
+ labels.append(label_data)
+ save_dataclass_objs_to_headered_csv(labels, label_file)
+
+ video_data_manifest_file_path = (
+ tempdir / "video_data_manifest_file_path.json"
+ )
+ with ExitStack() as stack:
+ if dataset_type == VideoDatasetType.Frame:
+ video_data_dict = get_flat_video_frames(tempdir, "jpg")
+ elif dataset_type == VideoDatasetType.EncodedVideo:
+ video_data_dict = get_encoded_video_infos(tempdir, stack)
+
+ save_dataclass_objs_to_headered_csv(
+ list(video_data_dict.values()), video_data_manifest_file_path
+ )
+ video_ids = list(self.LABELS_DATA)
+ dataset = DomsevVideoDataset(
+ video_data_manifest_file_path=str(video_data_manifest_file_path),
+ video_info_file_path=str(video_info_file),
+ labels_file_path=str(label_file),
+ dataset_type=dataset_type,
+ clip_sampler=lambda x, y: [
+ VideoClipInfo(video_ids[i // 2], i * 2.0, i * 2.0 + 0.9)
+ for i in range(0, 7)
+ ],
+ )
+
+ self.assertEqual(len(dataset._videos), 4)
+ total_labels = [
+ label_data
+ for video_labels in list(dataset._labels_per_video.values())
+ for label_data in video_labels
+ ]
+ self.assertEqual(len(total_labels), 6)
+ self.assertEqual(len(dataset), 7) # Num clips
+
+ @parameterized.expand([(VideoDatasetType.Frame,), (VideoDatasetType.EncodedVideo,)])
+ def test__getitem__(self, dataset_type):
+ with tempfile.TemporaryDirectory(prefix=f"{TestDomsevVideoDataset}") as tempdir:
+ tempdir = Path(tempdir)
+
+ video_info_file = tempdir / "test_video_info.csv"
+ save_dataclass_objs_to_headered_csv(
+ list(MOCK_VIDEO_INFOS.values()), video_info_file
+ )
+ label_file = tempdir / "activity_video_info.csv"
+ labels = []
+ for label_list in self.LABELS_DATA.values():
+ for label_data in label_list:
+ labels.append(label_data)
+ save_dataclass_objs_to_headered_csv(labels, label_file)
+
+ video_data_manifest_file_path = (
+ tempdir / "video_data_manifest_file_path.json"
+ )
+ with ExitStack() as stack:
+ if dataset_type == VideoDatasetType.Frame:
+ video_data_dict = get_flat_video_frames(tempdir, "jpg")
+ elif dataset_type == VideoDatasetType.EncodedVideo:
+ video_data_dict = get_encoded_video_infos(tempdir, stack)
+
+ save_dataclass_objs_to_headered_csv(
+ list(video_data_dict.values()), video_data_manifest_file_path
+ )
+ video_ids = list(self.LABELS_DATA)
+ dataset = DomsevVideoDataset(
+ video_data_manifest_file_path=str(video_data_manifest_file_path),
+ video_info_file_path=str(video_info_file),
+ labels_file_path=str(label_file),
+ dataset_type=dataset_type,
+ clip_sampler=lambda x, y: [
+ VideoClipInfo(video_ids[i // 2], i * 2.0, i * 2.0 + 0.9)
+ for i in range(0, 7)
+ ],
+ )
+
+ get_clip_string = (
+ "pytorchvideo.data.frame_video.FrameVideo.get_clip"
+ if dataset_type == VideoDatasetType.Frame
+ else "pytorchvideo.data.encoded_video.EncodedVideo.get_clip"
+ )
+ with unittest.mock.patch(
+ get_clip_string,
+ return_value=({"video": torch.rand(3, 5, 10, 20), "audio": []}),
+ ) as _:
+ clip_1 = dataset.__getitem__(1)
+ for i, a in enumerate(clip_1["labels"]):
+ self.assertEqual(a, self.LABELS_DATA[video_ids[0]][i])
+ self.assertEqual(clip_1["start_time"], 2.0)
+ self.assertEqual(clip_1["stop_time"], 2.9)
+ self.assertEqual(clip_1["video_id"], MOCK_VIDEO_IDS[0])
+
+ clip_2 = dataset.__getitem__(2)
+ for i, a in enumerate(clip_2["labels"]):
+ self.assertEqual(a, self.LABELS_DATA[video_ids[1]][i])
+ self.assertEqual(clip_2["start_time"], 4.0)
+ self.assertEqual(clip_2["stop_time"], 4.9)
+ self.assertEqual(clip_2["video_id"], MOCK_VIDEO_IDS[1])
diff --git a/code/pytorchvideo/tests/test_data_encoded_video.py b/code/pytorchvideo/tests/test_data_encoded_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eaa74525841e89e146cd974f6b544811d9d02c1
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_encoded_video.py
@@ -0,0 +1,145 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import tempfile
+import unittest
+
+import pytest
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
+from utils import temp_encoded_video, temp_encoded_video_with_audio
+
+
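+# EncodedVideo.get_clip(start, end) decodes frames with an end-exclusive range;
+# the "audio" entry is None when the file has no audio stream or when the video
+# is opened with decode_audio=False, and fully out-of-range clips return None.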
+class TestEncodedVideo(unittest.TestCase):
+ # Clip sampling is end time exclusive so we need to add _EPS to sample
+ # all the frames of a video.
+ _EPS = 1e-9
+
+ def test_video_works(self):
+ num_frames = 11
+ fps = 5
+ with temp_encoded_video(num_frames=num_frames, fps=fps) as (file_name, data):
+ test_video = EncodedVideo.from_path(file_name)
+ self.assertAlmostEqual(test_video.duration, num_frames / fps)
+
+ # All frames (0 - test_video.duration seconds)
+ clip = test_video.get_clip(0, test_video.duration + self._EPS)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertTrue(frames.equal(data))
+ self.assertEqual(audio_samples, None)
+
+ # Half frames
+ clip = test_video.get_clip(0, test_video.duration / 2)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertTrue(frames.equal(data[:, : round(num_frames / 2)]))
+ self.assertEqual(audio_samples, None)
+
+ # No frames
+ clip = test_video.get_clip(test_video.duration + 1, test_video.duration + 3)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertEqual(frames, None)
+ self.assertEqual(audio_samples, None)
+ test_video.close()
+
+ def test_video_with_shorter_audio_works(self):
+ num_audio_samples = 8000
+ num_frames = 5
+ fps = 5
+ audio_rate = 8000
+ with temp_encoded_video_with_audio(
+ num_frames=num_frames,
+ fps=fps,
+ num_audio_samples=num_audio_samples,
+ audio_rate=audio_rate,
+ ) as (file_name, video_data, audio_data):
+ test_video = EncodedVideo.from_path(file_name)
+
+ # Duration is the max over both streams, so the video duration is the expected value.
+ self.assertEqual(test_video.duration, num_frames / fps)
+
+ # All frames and audio (0 - 1 seconds)
+ clip = test_video.get_clip(0, test_video.duration + self._EPS)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertTrue(frames.equal(video_data))
+ self.assertTrue(audio_samples.equal(audio_data))
+
+ # Half frames
+ clip = test_video.get_clip(0, test_video.duration / 2)
+ frames, audio_samples = clip["video"], clip["audio"]
+
+ self.assertTrue(frames.equal(video_data[:, : num_frames // 2]))
+ self.assertTrue(audio_samples.equal(audio_data))
+
+ test_video.close()
+
+ def test_video_with_longer_audio_works(self):
+ audio_rate = 10000
+ fps = 5
+ num_frames = 5
+ num_audio_samples = 40000
+ with temp_encoded_video_with_audio(
+ num_frames=num_frames,
+ fps=fps,
+ num_audio_samples=num_audio_samples,
+ audio_rate=audio_rate,
+ ) as (file_name, video_data, audio_data):
+ test_video = EncodedVideo.from_path(file_name)
+
+ # All frames and audio
+ clip = test_video.get_clip(0, test_video.duration + self._EPS)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertTrue(frames.equal(video_data))
+ self.assertTrue(audio_samples.equal(audio_data))
+
+ # No frames (clip range is past the end of the video)
+ clip = test_video.get_clip(test_video.duration + 1, test_video.duration + 2)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertEqual(frames, None)
+ self.assertEqual(audio_samples, None)
+
+ test_video.close()
+
+ def test_decode_audio_is_false(self):
+ audio_rate = 10000
+ fps = 5
+ num_frames = 5
+ num_audio_samples = 40000
+ with temp_encoded_video_with_audio(
+ num_frames=num_frames,
+ fps=fps,
+ num_audio_samples=num_audio_samples,
+ audio_rate=audio_rate,
+ ) as (file_name, video_data, audio_data):
+ test_video = EncodedVideo.from_path(file_name, decode_audio=False)
+
+ # Full clip; audio is not decoded
+ clip = test_video.get_clip(0, test_video.duration + self._EPS)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertTrue(frames.equal(video_data))
+ self.assertEqual(audio_samples, None)
+
+ test_video.close()
+
+ def test_file_api(self):
+ num_frames = 11
+ fps = 5
+ with temp_encoded_video(num_frames=num_frames, fps=fps) as (file_name, data):
+ with open(file_name, "rb") as f:
+ test_video = EncodedVideoPyAV(f)
+
+ self.assertAlmostEqual(test_video.duration, num_frames / fps)
+ clip = test_video.get_clip(0, test_video.duration + self._EPS)
+ frames, audio_samples = clip["video"], clip["audio"]
+ self.assertTrue(frames.equal(data))
+ self.assertEqual(audio_samples, None)
+
+ def test_open_video_failure(self):
+ with pytest.raises(FileNotFoundError):
+ test_video = EncodedVideo.from_path("non_existent_file.txt")
+ test_video.close()
+
+ def test_decode_video_failure(self):
+ with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
+ f.write(b"This is not an mp4 file")
+ with pytest.raises(RuntimeError):
+ test_video = EncodedVideo.from_path(f.name)
+ test_video.close()
diff --git a/code/pytorchvideo/tests/test_data_epic_kitchen_dataset.py b/code/pytorchvideo/tests/test_data_epic_kitchen_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4100c522c9945091c89259f114c2ffa8a9fb9604
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_epic_kitchen_dataset.py
@@ -0,0 +1,258 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import tempfile
+import unittest
+import unittest.mock
+from contextlib import ExitStack
+from pathlib import Path
+
+import torch
+from parameterized import parameterized
+from pytorchvideo.data.dataset_manifest_utils import VideoClipInfo, VideoDatasetType
+from pytorchvideo.data.epic_kitchen import ActionData, EpicKitchenDataset
+from pytorchvideo.data.utils import save_dataclass_objs_to_headered_csv
+from utils import (
+ get_encoded_video_infos,
+ get_flat_video_frames,
+ MOCK_VIDEO_IDS,
+ MOCK_VIDEO_INFOS,
+)
+
+
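+# EpicKitchenDataset joins the video-info csv, the frame/encoded-video manifest
+# and the action annotation csv; clips come from the caller-supplied
+# clip_sampler, and each clip is returned with the actions of its source video.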
+class TestEpicKitchenDataset(unittest.TestCase):
+
+ ACTIONS_DATAS = {
+ MOCK_VIDEO_IDS[0]: [
+ ActionData(
+ "P01",
+ "P01_01",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "close door",
+ "00:00:05.00",
+ "00:00:07.00",
+ 418,
+ 569,
+ "close",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "close fridge",
+ "00:01:1.91",
+ "01:00:5.33",
+ 1314,
+ 1399,
+ "close",
+ 3,
+ "fridge",
+ 10,
+ "['fridge']",
+ "[10]",
+ ),
+ ],
+ MOCK_VIDEO_IDS[1]: [
+ ActionData(
+ "P02",
+ "P02_002",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ )
+ ],
+ MOCK_VIDEO_IDS[2]: [
+ ActionData(
+ "P02",
+ "P02_005",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ )
+ ],
+ MOCK_VIDEO_IDS[3]: [
+ ActionData(
+ "P07",
+ "P07_002",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ )
+ ],
+ }
+
+ def test_ActionData(self):
+
+ action = ActionData(
+ # Values are passed as strings because the underlying epic-kitchen
+ # annotation files store these columns as strings
+ **{
+ "participant_id": "P07",
+ "video_id": "P07_002",
+ "narration": "turn on light",
+ "start_timestamp": "00:00:04.00",
+ "stop_timestamp": "00:00:06.50",
+ "start_frame": "262",
+ "stop_frame": "370",
+ "verb": "turn-on",
+ "verb_class": "12",
+ "noun": "light",
+ "noun_class": "113",
+ "all_nouns": "['light', 'finger', 'wall']",
+ "all_noun_classes": "[113, 1232, 1]",
+ }
+ )
+ self.assertEqual(action.video_id, "P07_002")
+ self.assertEqual(action.start_time, 4.0)
+ self.assertEqual(action.stop_time, 6.5)
+ self.assertEqual(action.verb_class, 12)
+ self.assertEqual(action.noun_class, 113)
+ self.assertEqual(action.all_nouns, ["light", "finger", "wall"])
+
+ self.assertEqual(action.all_noun_classes, [113, 1232, 1])
+
+ @parameterized.expand([(VideoDatasetType.Frame,), (VideoDatasetType.EncodedVideo,)])
+ def test__len__(self, dataset_type):
+ with tempfile.TemporaryDirectory(prefix=f"{TestEpicKitchenDataset}") as tempdir:
+ tempdir = Path(tempdir)
+
+ video_info_file = tempdir / "test_video_info.csv"
+ save_dataclass_objs_to_headered_csv(
+ list(MOCK_VIDEO_INFOS.values()), video_info_file
+ )
+ action_file = tempdir / "action_video_info.csv"
+ actions = []
+ for action_list in self.ACTIONS_DATAS.values():
+ for action in action_list:
+ actions.append(action)
+ save_dataclass_objs_to_headered_csv(actions, action_file)
+
+ video_data_manifest_file_path = (
+ tempdir / "video_data_manifest_file_path.json"
+ )
+ with ExitStack() as stack:
+ if dataset_type == VideoDatasetType.Frame:
+ video_data_dict = get_flat_video_frames(tempdir, "jpg")
+ elif dataset_type == VideoDatasetType.EncodedVideo:
+ video_data_dict = get_encoded_video_infos(tempdir, stack)
+
+ save_dataclass_objs_to_headered_csv(
+ list(video_data_dict.values()), video_data_manifest_file_path
+ )
+
+ dataset = EpicKitchenDataset(
+ video_info_file_path=str(video_info_file),
+ actions_file_path=str(action_file),
+ clip_sampler=lambda x, y: [
+ VideoClipInfo(str(i), i * 2.0, i * 2.0 + 0.9)
+ for i in range(0, 7)
+ ],
+ video_data_manifest_file_path=str(video_data_manifest_file_path),
+ dataset_type=dataset_type,
+ )
+
+ self.assertEqual(len(dataset), 7)
+
+ @parameterized.expand([(VideoDatasetType.Frame,), (VideoDatasetType.EncodedVideo,)])
+ def test__getitem__(self, dataset_type):
+ with tempfile.TemporaryDirectory(prefix=f"{TestEpicKitchenDataset}") as tempdir:
+ tempdir = Path(tempdir)
+
+ video_info_file = tempdir / "test_video_info.csv"
+ save_dataclass_objs_to_headered_csv(
+ list(MOCK_VIDEO_INFOS.values()), video_info_file
+ )
+ action_file = tempdir / "action_video_info.csv"
+ actions = []
+ for action_list in self.ACTIONS_DATAS.values():
+ for action in action_list:
+ actions.append(action)
+ save_dataclass_objs_to_headered_csv(actions, action_file)
+
+ video_data_manifest_file_path = (
+ tempdir / "video_data_manifest_file_path.json"
+ )
+ with ExitStack() as stack:
+ if dataset_type == VideoDatasetType.Frame:
+ video_data_dict = get_flat_video_frames(tempdir, "jpg")
+ elif dataset_type == VideoDatasetType.EncodedVideo:
+ video_data_dict = get_encoded_video_infos(tempdir, stack)
+
+ save_dataclass_objs_to_headered_csv(
+ list(video_data_dict.values()), video_data_manifest_file_path
+ )
+ video_ids = list(self.ACTIONS_DATAS)
+ dataset = EpicKitchenDataset(
+ video_info_file_path=str(video_info_file),
+ actions_file_path=str(action_file),
+ clip_sampler=lambda x, y: [
+ VideoClipInfo(video_ids[i // 2], i * 2.0, i * 2.0 + 0.9)
+ for i in range(0, 7)
+ ],
+ video_data_manifest_file_path=str(video_data_manifest_file_path),
+ dataset_type=dataset_type,
+ )
+
+ get_clip_string = (
+ "pytorchvideo.data.frame_video.FrameVideo.get_clip"
+ if dataset_type == VideoDatasetType.Frame
+ else "pytorchvideo.data.encoded_video.EncodedVideo.get_clip"
+ )
+ with unittest.mock.patch(
+ get_clip_string,
+ return_value=({"video": torch.rand(3, 5, 10, 20), "audio": []}),
+ ) as _:
+ clip_1 = dataset.__getitem__(1)
+ for i, a in enumerate(clip_1["actions"]):
+ self.assertEqual(a, self.ACTIONS_DATAS[video_ids[0]][i])
+ self.assertEqual(clip_1["start_time"], 2.0)
+ self.assertEqual(clip_1["stop_time"], 2.9)
+ self.assertEqual(clip_1["video_id"], MOCK_VIDEO_IDS[0])
+
+ clip_2 = dataset.__getitem__(2)
+ for i, a in enumerate(clip_2["actions"]):
+ self.assertEqual(a, self.ACTIONS_DATAS[video_ids[1]][i])
+ self.assertEqual(clip_2["start_time"], 4.0)
+ self.assertEqual(clip_2["stop_time"], 4.9)
+ self.assertEqual(clip_2["video_id"], MOCK_VIDEO_IDS[1])
diff --git a/code/pytorchvideo/tests/test_data_epic_kitchen_forecasting.py b/code/pytorchvideo/tests/test_data_epic_kitchen_forecasting.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a7983e9e4050b660e3a32d70d551207ead43451
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_epic_kitchen_forecasting.py
@@ -0,0 +1,424 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+import unittest.mock
+
+import torch
+from pytorchvideo.data import EpicKitchenForecasting
+from pytorchvideo.data.epic_kitchen import ActionData
+from pytorchvideo.data.epic_kitchen_forecasting import ClipSampling
+from pytorchvideo.data.frame_video import FrameVideo
+
+
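+# The forecasting task pairs a window of short input sub-clips with the next
+# num_forecast_actions actions starting after the window ends; the helpers
+# below (_transform_generator, _frame_filter_generator and
+# _define_clip_structure_generator) are exercised on hand-built ActionData.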
+class TestEpicKitchenForecasting(unittest.TestCase):
+ def test_transform_generator(self):
+ clip = {
+ "start_time": 2.5,
+ "stop_time": 6.5,
+ "video": torch.rand(3, 8, 10, 20),
+ "actions": [
+ ActionData(
+ "P01",
+ "P01_01",
+ "turn off light",
+ "00:00:01.00",
+ "00:00:02.00",
+ 262,
+ 370,
+ "turn-off",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "close door",
+ "00:00:06.00",
+ "00:00:07.00",
+ 418,
+ 569,
+ "close",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "slam door",
+ "00:00:10.00",
+ "00:00:11.00",
+ 408,
+ 509,
+ "slam",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "slam door",
+ "00:00:11.00",
+ "00:00:12.00",
+ 408,
+ 509,
+ "slam",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "slam door",
+ "00:00:12.00",
+ "00:00:13.00",
+ 408,
+ 509,
+ "slam",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ],
+ }
+
+ def additional_transform(clip):
+ clip["video"] = clip["video"].permute(1, 2, 3, 4, 0)
+ return clip
+
+ transform_fn = EpicKitchenForecasting._transform_generator(
+ additional_transform,
+ num_forecast_actions=3,
+ num_input_clips=2,
+ frames_per_clip=4,
+ )
+
+ transformed_clip = transform_fn(clip)
+
+ self.assertEqual(len(transformed_clip["actions"]), 3)
+
+ self.assertEqual(transformed_clip["actions"][0].narration, "slam door")
+ self.assertEqual(transformed_clip["actions"][1].narration, "slam door")
+ self.assertEqual(transformed_clip["actions"][2].narration, "slam door")
+
+ self.assertEqual(transformed_clip["actions"][0].start_time, 10.0)
+ self.assertEqual(transformed_clip["actions"][1].start_time, 11.0)
+ self.assertEqual(transformed_clip["actions"][2].start_time, 12.0)
+
+ self.assertEqual(transformed_clip["start_time"], 2.5)
+ self.assertEqual(transformed_clip["stop_time"], 6.5)
+
+ self.assertEqual(
+ transformed_clip["video"].size(), torch.Size([3, 4, 10, 20, 2])
+ )
+
+ def test_frame_filter_generator(self):
+ # 11 seconds of video at 4 fps
+ input_list = list(range(44))
+
+ # 11 second clip at 4 fps, all frames are included
+ frame_filter_fn = EpicKitchenForecasting._frame_filter_generator(
+ seconds_per_clip=1,
+ num_input_clips=11,
+ frames_per_clip=4,
+ clip_time_stride=1,
+ )
+
+ all_elements = frame_filter_fn(input_list)
+
+ self.assertEqual(all_elements, input_list)
+
+ # 11 second clip at 4 fps, seconds 0-1 and 10-11 are included
+ frame_filter_fn = EpicKitchenForecasting._frame_filter_generator(
+ seconds_per_clip=1,
+ num_input_clips=2,
+ frames_per_clip=4,
+ clip_time_stride=10,
+ )
+ elements_2_clips = frame_filter_fn(input_list)
+ self.assertEqual(len(elements_2_clips), 8)
+ self.assertEqual(elements_2_clips, input_list[:4] + input_list[-4:])
+
+ # 11 second clip at 2 fps, seconds 0-1 and 10-11 are included
+ frame_filter_fn = EpicKitchenForecasting._frame_filter_generator(
+ seconds_per_clip=1,
+ num_input_clips=2,
+ frames_per_clip=2,
+ clip_time_stride=10,
+ )
+ elements_2_clips_2fps = frame_filter_fn(input_list)
+ self.assertEqual(len(elements_2_clips_2fps), 4)
+ self.assertEqual(elements_2_clips_2fps, [0, 2, 40, 42])
+
+ def test_define_clip_structure_generator(self):
+ frame_videos = {
+ "P01_003": FrameVideo.from_frame_paths(
+ [f"root/P01_003/frame_{i}" for i in range(200)], 10
+ ),
+ "P02_004": FrameVideo.from_frame_paths(
+ [f"root/P02_004/frame_{i}" for i in range(300)], 10
+ ),
+ "P11_010": FrameVideo.from_frame_paths(
+ [f"root/P11_010/frame_{i}" for i in range(600)], 30
+ ),
+ }
+ actions = {
+ "P01_003": [
+ ActionData(
+ "P01",
+ "P01_003",
+ "turn off light",
+ "00:00:01.00",
+ "00:00:02.00",
+ 262,
+ 370,
+ "turn-off",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P01",
+ "P01_003",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:05.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P01",
+ "P01_003",
+ "close door",
+ "00:00:06.00",
+ "00:00:07.00",
+ 418,
+ 569,
+ "close",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P01",
+ "P01_003",
+ "slam door",
+ "00:00:10.00",
+ "00:00:11.00",
+ 408,
+ 509,
+ "slam",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ],
+ "P02_004": [
+ ActionData(
+ "P02",
+ "P02_004",
+ "turn off light",
+ "00:00:04.00",
+ "00:00:05.00",
+ 262,
+ 370,
+ "turn-off",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P02",
+ "P02_004",
+ "turn on light",
+ "00:00:05.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P02",
+ "P02_004",
+ "close door",
+ "00:00:08.00",
+ "00:00:09.00",
+ 418,
+ 569,
+ "close",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P02",
+ "P02_004",
+ "slam door",
+ "00:00:10.00",
+ "00:00:11.00",
+ 408,
+ 509,
+ "slam",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ],
+ "P11_010": [
+ ActionData(
+ "P11",
+ "P11_010",
+ "turn off light",
+ "00:00:01.00",
+ "00:00:02.00",
+ 262,
+ 370,
+ "turn-off",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P11",
+ "P11_010",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:05.50",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P11",
+ "P11_010",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P11",
+ "P11_010",
+ "close door",
+ "00:00:06.00",
+ "00:00:07.00",
+ 418,
+ 569,
+ "close",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P11",
+ "P11_010",
+ "slam door",
+ "00:00:10.00",
+ "00:00:11.00",
+ 408,
+ 509,
+ "slam",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ],
+ }
+ random_value = 0.5
+ with unittest.mock.patch("random.random", return_value=random_value) as _:
+ define_clip_structure_fn = (
+ EpicKitchenForecasting._define_clip_structure_generator(
+ seconds_per_clip=1,
+ clip_time_stride=3,
+ num_input_clips=2,
+ num_forecast_actions=2,
+ clip_sampling=ClipSampling.Random,
+ )
+ )
+ clips = define_clip_structure_fn(frame_videos, actions)
+ sorted_clips = sorted(clips, key=lambda c: c.start_time) # For stability
+ for clip in sorted_clips:
+ self.assertEqual(clip.stop_time - clip.start_time, 4.0)
+
+ clips_P01_003 = [c for c in sorted_clips if c.video_id == "P01_003"]
+ self.assertEqual(len(clips_P01_003), 1)
+
+            self.assertEqual(
+                clips_P01_003[0].start_time, actions["P01_003"][1].stop_time
+            )
+
+ clips_P02_004 = [c for c in sorted_clips if c.video_id == "P02_004"]
+ self.assertEqual(len(clips_P02_004), 2)
+            self.assertEqual(
+                clips_P02_004[0].start_time, actions["P02_004"][0].stop_time
+            )
+            self.assertEqual(
+                clips_P02_004[1].start_time, actions["P02_004"][1].stop_time
+            )
+
+ clips_P11_010 = [c for c in sorted_clips if c.video_id == "P11_010"]
+ self.assertEqual(len(clips_P11_010), 1)
+            self.assertEqual(
+                clips_P11_010[0].start_time, actions["P11_010"][1].stop_time
+            )
diff --git a/code/pytorchvideo/tests/test_data_epic_kitchen_recognition.py b/code/pytorchvideo/tests/test_data_epic_kitchen_recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde826ad9c38d3c6f9bd6eea91d65b7620dcc9aa
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_epic_kitchen_recognition.py
@@ -0,0 +1,166 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+import unittest.mock
+
+import torch
+from pytorchvideo.data import EpicKitchenRecognition
+from pytorchvideo.data.epic_kitchen import ActionData
+from pytorchvideo.data.epic_kitchen_recognition import ClipSampling
+from pytorchvideo.data.frame_video import FrameVideo
+
+
+class TestEpicKitchenRecognition(unittest.TestCase):
+ def test_transform_generator(self):
+ clip = {
+ "start_time": 2.5,
+ "stop_time": 6.5,
+ "video": torch.rand(3, 4, 10, 20),
+ "actions": [
+ ActionData(
+ "P01",
+ "P01_01",
+ "turn off light",
+ "00:00:01.00",
+ "00:00:02.00",
+ 262,
+ 370,
+ "turn-off",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "turn on light",
+ "00:00:04.00",
+ "00:00:06.00",
+ 262,
+ 370,
+ "turn-on",
+ 12,
+ "light",
+ 113,
+ "['light']",
+ "[113]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "close door",
+ "00:00:06.00",
+ "00:00:07.00",
+ 418,
+ 569,
+ "close",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ActionData(
+ "P01",
+ "P01_01",
+ "slam door",
+ "00:00:10.00",
+ "00:00:11.00",
+ 408,
+ 509,
+ "slam",
+ 3,
+ "door",
+ 8,
+ "['door']",
+ "[8]",
+ ),
+ ],
+ }
+
+ def additional_transform(clip):
+ clip["video"] = clip["video"].permute(1, 2, 3, 0)
+ return clip
+
+ transform_fn = EpicKitchenRecognition._transform_generator(additional_transform)
+
+ transformed_clip = transform_fn(clip)
+
+ self.assertEqual(len(transformed_clip["actions"]), 2)
+ # Sort for stability
+ sorted_actions = sorted(transformed_clip["actions"], key=lambda a: a.start_time)
+
+ self.assertEqual(sorted_actions[0].narration, "turn on light")
+ self.assertEqual(sorted_actions[1].narration, "close door")
+
+ self.assertEqual(transformed_clip["start_time"], 2.5)
+ self.assertEqual(transformed_clip["stop_time"], 6.5)
+
+ self.assertEqual(transformed_clip["video"].size(), torch.Size([4, 10, 20, 3]))
+
+ def test_frame_filter_generator(self):
+ input_list = list(range(10))
+
+ frame_filter_fn = EpicKitchenRecognition._frame_filter_generator(10)
+ all_elements = frame_filter_fn(input_list)
+ self.assertEqual(all_elements, input_list)
+
+ frame_filter_fn = EpicKitchenRecognition._frame_filter_generator(5)
+ half_elements = frame_filter_fn(input_list)
+ self.assertEqual(len(half_elements), 5)
+ self.assertEqual(half_elements, [i for i in input_list if not i % 2])
+
+ frame_filter_fn = EpicKitchenRecognition._frame_filter_generator(1)
+ half_elements = frame_filter_fn(input_list)
+ self.assertEqual(len(half_elements), 1)
+ self.assertEqual(half_elements[0], 0)
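+        # _frame_filter_generator(n) appears to uniformly subsample the incoming
+        # frame indices down to n entries: 10 -> 5 keeps every other index, and
+        # 10 -> 1 keeps only the first.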
+
+ def test_define_clip_structure_generator(self):
+ seconds_per_clip = 5
+ define_clip_structure_fn = (
+ EpicKitchenRecognition._define_clip_structure_generator(
+ seconds_per_clip=5, clip_sampling=ClipSampling.RandomOffsetUniform
+ )
+ )
+ frame_videos = {
+ "P01_003": FrameVideo.from_frame_paths(
+ [f"root/P01_003/frame_{i}" for i in range(100)], 10
+ ),
+ "P02_004": FrameVideo.from_frame_paths(
+ [f"root/P02_004/frame_{i}" for i in range(300)], 10
+ ),
+ "P11_010": FrameVideo.from_frame_paths(
+ [f"root/P11_010/frame_{i}" for i in range(600)], 30
+ ),
+ }
+ actions = {video_id: [] for video_id in frame_videos}
+ random_value = 0.5
+ with unittest.mock.patch("random.random", return_value=random_value) as _:
+ clips = define_clip_structure_fn(frame_videos, actions)
+ sorted_clips = sorted(clips, key=lambda c: c.start_time) # For stability
+
+ for clip in sorted_clips:
+ self.assertEqual(clip.stop_time - clip.start_time, seconds_per_clip)
+
+ clips_P01_003 = [c for c in sorted_clips if c.video_id == "P01_003"]
+ self.assertEqual(len(clips_P01_003), 1)
+ for i in range(len(clips_P01_003)):
+ self.assertEqual(
+ clips_P01_003[i].start_time, seconds_per_clip * (i + random_value)
+ )
+
+ clips_P02_004 = [c for c in sorted_clips if c.video_id == "P02_004"]
+ self.assertEqual(len(clips_P02_004), 5)
+ for i in range(len(clips_P02_004)):
+ self.assertEqual(
+ clips_P02_004[i].start_time, seconds_per_clip * (i + random_value)
+ )
+
+ clips_P11_010 = [c for c in sorted_clips if c.video_id == "P11_010"]
+ self.assertEqual(len(clips_P11_010), 3)
+ for i in range(len(clips_P11_010)):
+ self.assertEqual(
+ clips_P11_010[i].start_time, seconds_per_clip * (i + random_value)
+ )
diff --git a/code/pytorchvideo/tests/test_data_epic_kitchen_utils.py b/code/pytorchvideo/tests/test_data_epic_kitchen_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0395a245abf3ad1f4da68cf5dc785e78c3ae27
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_epic_kitchen_utils.py
@@ -0,0 +1,190 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+import tempfile
+import unittest
+import unittest.mock
+from pathlib import Path
+
+from pytorchvideo.data.dataset_manifest_utils import EncodedVideoInfo, VideoFrameInfo
+from pytorchvideo.data.epic_kitchen.utils import (
+ build_encoded_manifest_from_nested_directory,
+ build_frame_manifest_from_flat_directory,
+ build_frame_manifest_from_nested_directory,
+)
+
+
+def write_mock_frame_files(video_frames, tempdir, ext):
+ tempdir = Path(tempdir)
+ for _, video_frame_info in video_frames.items():
+ if not os.path.isdir(video_frame_info.location):
+ os.mkdir(video_frame_info.location)
+
+        # Write the frames in reverse order to verify that the manifest builders do
+        # not rely on the directory listing returning frames in temporal order.
+        for frame_num in reversed(
+            range(
+                video_frame_info.min_frame_number, video_frame_info.max_frame_number + 1
+            )
+        ):
+ frame_num_str = str(frame_num)
+ stem = video_frame_info.frame_file_stem
+ frame_num_zeros = "0" * (
+ video_frame_info.frame_string_length - len(frame_num_str) - len(stem)
+ )
+ frame_file_name = f"{stem}{frame_num_zeros}{frame_num_str}.{ext}"
+ with open(f"{video_frame_info.location}/{frame_file_name}", "w") as f:
+ f.write("0")
+
+
+def get_flat_video_frames(directory, file_extension):
+ return {
+ "P02_001": VideoFrameInfo(
+ video_id="P02_001",
+ location=f"{directory}/P02_001",
+ frame_file_stem="frame_",
+ frame_string_length=16,
+ min_frame_number=1,
+ max_frame_number=3000,
+ file_extension=file_extension,
+ ),
+ "P02_002": VideoFrameInfo(
+ video_id="P02_002",
+ location=f"{directory}/P02_002",
+ frame_file_stem="frame_",
+ frame_string_length=16,
+ min_frame_number=2,
+ max_frame_number=3001,
+ file_extension=file_extension,
+ ),
+ "P02_005": VideoFrameInfo(
+ video_id="P02_005",
+ location=f"{directory}/P02_005",
+ frame_file_stem="frame_",
+ frame_string_length=16,
+ min_frame_number=1,
+ max_frame_number=30003,
+ file_extension=file_extension,
+ ),
+ "P07_002": VideoFrameInfo(
+ video_id="P07_002",
+ location=f"{directory}/P07_002",
+ frame_file_stem="frame_",
+ frame_string_length=16,
+ min_frame_number=2,
+ max_frame_number=1530,
+ file_extension=file_extension,
+ ),
+ }
+
+
+def get_nested_video_frames(directory, file_extension):
+ return {
+ "P02_001": VideoFrameInfo(
+ video_id="P02_001",
+ location=f"{directory}/P02",
+ frame_file_stem="P02_001_",
+ frame_string_length=16,
+ min_frame_number=1,
+ max_frame_number=3000,
+ file_extension=file_extension,
+ ),
+ "P02_002": VideoFrameInfo(
+ video_id="P02_002",
+ location=f"{directory}/P02",
+ frame_file_stem="P02_002_",
+ frame_string_length=16,
+ min_frame_number=2,
+ max_frame_number=3001,
+ file_extension=file_extension,
+ ),
+ "P02_005": VideoFrameInfo(
+ video_id="P02_005",
+ location=f"{directory}/P02",
+ frame_file_stem="P02_005_",
+ frame_string_length=16,
+ min_frame_number=1,
+ max_frame_number=30003,
+ file_extension=file_extension,
+ ),
+ "P07_002": VideoFrameInfo(
+ video_id="P07_002",
+ location=f"{directory}/P07",
+ frame_file_stem="P07_002_",
+ frame_string_length=16,
+ min_frame_number=2,
+ max_frame_number=1530,
+ file_extension=file_extension,
+ ),
+ }
+
+
+class TestEpicKitchenUtils(unittest.TestCase):
+ def test_build_frame_manifest_from_flat_directory_sync(self):
+ self.test_build_frame_manifest_from_flat_directory(multithreading=False)
+
+ def test_build_frame_manifest_from_flat_directory(self, multithreading=True):
+ with tempfile.TemporaryDirectory(prefix="TestEpicKitchenUtils") as tempdir:
+ video_frames_expected = get_flat_video_frames(tempdir, "jpg")
+ write_mock_frame_files(video_frames_expected, tempdir, "jpg")
+
+ video_frames = build_frame_manifest_from_flat_directory(
+ tempdir, multithreading
+ )
+
+ self.assertEqual(len(video_frames_expected), len(video_frames))
+ for video_id in video_frames_expected:
+ self.assertEqual(
+ video_frames[video_id], video_frames_expected[video_id]
+ )
+
+ def test_build_frame_manifest_from_nested_directory_sync(self):
+ self.test_build_frame_manifest_from_nested_directory(multithreading=False)
+
+ def test_build_frame_manifest_from_nested_directory(self, multithreading=True):
+ with tempfile.TemporaryDirectory(prefix="TestEpicKitchenUtils") as tempdir:
+ video_frames_expected = get_nested_video_frames(tempdir, "png")
+ write_mock_frame_files(video_frames_expected, tempdir, "png")
+
+ video_frames = build_frame_manifest_from_nested_directory(
+ tempdir, multithreading
+ )
+ self.assertEqual(len(video_frames_expected), len(video_frames))
+ for video_id in video_frames_expected:
+ self.assertEqual(
+ video_frames[video_id], video_frames_expected[video_id]
+ )
+
+ def test_build_encoded_manifest_from_nested_directory(self):
+ file_names = ["P01_01.mp4", "P01_07.mp4", "P23_11.mp4", "P11_00.mp4"]
+ with tempfile.TemporaryDirectory(prefix="TestEpicKitchenUtils") as tempdir:
+
+ for file_name in file_names:
+ participant_path = Path(tempdir) / file_name[:3]
+ if not os.path.isdir(participant_path):
+ os.mkdir(participant_path)
+
+ with open(participant_path / file_name, "w") as f:
+ f.write("0")
+
+ encoded_video_dict = build_encoded_manifest_from_nested_directory(tempdir)
+
+ self.assertEqual(
+ sorted(encoded_video_dict), ["P01_01", "P01_07", "P11_00", "P23_11"]
+ )
+ self.assertEqual(
+ encoded_video_dict["P01_01"],
+ EncodedVideoInfo("P01_01", str(Path(tempdir) / "P01/P01_01.mp4")),
+ )
+ self.assertEqual(
+ encoded_video_dict["P01_07"],
+ EncodedVideoInfo("P01_07", str(Path(tempdir) / "P01/P01_07.mp4")),
+ )
+ self.assertEqual(
+ encoded_video_dict["P11_00"],
+ EncodedVideoInfo("P11_00", str(Path(tempdir) / "P11/P11_00.mp4")),
+ )
+ self.assertEqual(
+ encoded_video_dict["P23_11"],
+ EncodedVideoInfo("P23_11", str(Path(tempdir) / "P23/P23_11.mp4")),
+ )
diff --git a/code/pytorchvideo/tests/test_data_frame_video.py b/code/pytorchvideo/tests/test_data_frame_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea3f5f2ae19490a33932f85d0ab9da2df46576f
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_frame_video.py
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import pytest
+from pytorchvideo.data.frame_video import FrameVideo
+from utils import temp_frame_video
+
+
+class TestFrameVideo(unittest.TestCase):
+ def test_frame_video_works(self):
+ frame_names = [f"{str(i)}.png" for i in range(3)]
+ with temp_frame_video(frame_names) as (f_name, data):
+ frame_paths = [f_name / x for x in frame_names]
+ test_video = FrameVideo.from_frame_paths(frame_paths)
+ expected_duration = (
+ 0.1 # Total duration of 3 frames at 30fps is 0.1 seconds.
+ )
+ self.assertEqual(test_video.duration, expected_duration)
+
+ # All frames (0 - 0.1 seconds)
+ clip = test_video.get_clip(0, 0.1)
+ frames, indices = clip["video"], clip["frame_indices"]
+ self.assertTrue(frames.equal(data))
+ self.assertEqual(indices, [0, 1, 2])
+
+            # All frames (0 - 0.1 seconds), filtered to the middle frame
+ clip = test_video.get_clip(0, 0.1, lambda lst: lst[1:2])
+ frames, indices = clip["video"], clip["frame_indices"]
+ self.assertTrue(frames.equal(data[:, 1:2]))
+ self.assertEqual(indices, [1])
+
+ # 2 frames (0 - 0.066 seconds)
+ clip = test_video.get_clip(0, 0.066)
+ frames, indices = clip["video"], clip["frame_indices"]
+ self.assertTrue(frames.equal(data[:, :2]))
+ self.assertEqual(indices, [0, 1])
+
+ # No frames (3 - 5 seconds)
+ result = test_video.get_clip(3, 5)
+ self.assertEqual(result, None)
+
+ def test_open_video_failure(self):
+ test_video = FrameVideo.from_frame_paths(["non_existent_file.txt"])
+ with pytest.raises(Exception):
+            # The duration is 1/30 s: a single frame at the default 30 fps.
+            test_video.get_clip(0, 0.01)
+
+ def test_empty_frames_failure(self):
+ with pytest.raises(AssertionError):
+ FrameVideo.from_frame_paths([])
diff --git a/code/pytorchvideo/tests/test_data_json_dataset.py b/code/pytorchvideo/tests/test_data_json_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a836e54217b50b9cc87942a068d83811e6ac849e
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_json_dataset.py
@@ -0,0 +1,98 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import contextlib
+import json
+import tempfile
+import unittest
+
+from pytorchvideo.data import json_dataset
+from pytorchvideo.data.clip_sampling import make_clip_sampler
+from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
+from utils import temp_frame_video_dataset
+
+
+class TestJsonDatasets(unittest.TestCase):
+ def setUp(self):
+ LabeledVideoDataset._MAX_CONSECUTIVE_FAILURES = 1
+
+ def test_recognition_random_clip_sampler(self):
+ total_duration = 0.05
+ with mock_json_annotations() as (annotation_json, labels, duration):
+ clip_sampler = make_clip_sampler("random", total_duration)
+ dataset = json_dataset.clip_recognition_dataset(
+ data_path=annotation_json,
+ clip_sampler=clip_sampler,
+ decode_audio=False,
+ )
+
+ self.assertEqual(dataset.num_videos, 4)
+ self.assertEqual(len(list(iter(dataset))), 4)
+
+ def test_recognition_uniform_clip_sampler(self):
+ total_duration = 0.05
+ with mock_json_annotations() as (annotation_json, labels, duration):
+ clip_sampler = make_clip_sampler("uniform", total_duration)
+ dataset = json_dataset.clip_recognition_dataset(
+ data_path=annotation_json,
+ clip_sampler=clip_sampler,
+ decode_audio=False,
+ )
+
+ self.assertEqual(dataset.num_videos, 4)
+ self.assertEqual(len(list(iter(dataset))), 4)
+
+ def test_video_only_frame_video_dataset(self):
+ total_duration = 2.0
+ with mock_json_annotations() as (annotation_json, labels, duration):
+ clip_sampler = make_clip_sampler("random", total_duration)
+ dataset = json_dataset.video_only_dataset(
+ data_path=annotation_json,
+ clip_sampler=clip_sampler,
+ decode_audio=False,
+ )
+
+ self.assertEqual(dataset.num_videos, 2)
+ self.assertEqual(len(list(iter(dataset))), 2)
+
+
+@contextlib.contextmanager
+def mock_json_annotations():
+ with temp_frame_video_dataset() as (_, videos):
+ label_videos = []
+ json_dict = {}
+ for video in videos:
+ label_videos.append((video[-3], video[-2]))
+ name = str(video[0])
+ json_dict[name] = {
+ "benchmarks": {
+ "forecasting_hands_objects": [
+ {
+ "critical_frame_selection_parent_start_sec": 0.001,
+ "critical_frame_selection_parent_end_sec": 0.012,
+ "taxonomy": {
+ "noun": video[-3],
+ "verb": video[-3],
+ "noun_unsure": False,
+ "verb_unsure": False,
+ },
+ },
+ {
+ "critical_frame_selection_parent_start_sec": 0.01,
+ "critical_frame_selection_parent_end_sec": 0.05,
+ "taxonomy": {
+ "noun": video[-3],
+ "verb": video[-3],
+ "noun_unsure": False,
+ "verb_unsure": False,
+ },
+ },
+ ]
+ }
+ }
+
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="wt") as f:
+ json.dump(json_dict, f)
+ f.close()
+
+ min_duration = min(videos[0][-1], videos[1][-1])
+ yield f.name, label_videos, min_duration
diff --git a/code/pytorchvideo/tests/test_data_labeled_video_dataset.py b/code/pytorchvideo/tests/test_data_labeled_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f899a6f9685e0d2f292fce92f80ac3c5082529f5
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_labeled_video_dataset.py
@@ -0,0 +1,759 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import collections
+import contextlib
+import itertools
+import math
+import multiprocessing
+import os
+import pathlib
+import tempfile
+import unittest
+import unittest.mock
+from fractions import Fraction
+from typing import List, Tuple
+from unittest.mock import Mock, patch
+
+# av import has to be added for `buck test` to work.
+import av # noqa: F401
+import torch
+import torch.distributed as dist
+from parameterized import parameterized
+from pytorchvideo.data import Hmdb51
+from pytorchvideo.data.clip_sampling import make_clip_sampler
+from pytorchvideo.data.labeled_video_dataset import (
+ labeled_video_dataset,
+ LabeledVideoDataset,
+)
+from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
+from pytorchvideo.data.utils import MultiProcessSampler, thwc_to_cthw
+from torch.multiprocessing import Process
+from torch.utils.data import (
+ DataLoader,
+ DistributedSampler,
+ RandomSampler,
+ SequentialSampler,
+ TensorDataset,
+)
+from utils import (
+ create_dummy_video_frames,
+ temp_encoded_video,
+ temp_frame_video_dataset,
+)
+
+
+DECODER_LIST = [("pyav",), ("torchvision",), ("decord",)]
+
+
+class TestLabeledVideoDataset(unittest.TestCase):
+ def setUp(self):
+ # Fail fast for tests
+ LabeledVideoDataset._MAX_CONSECUTIVE_FAILURES = 1
+
+ @parameterized.expand(DECODER_LIST)
+ def test_single_clip_per_video_works(self, decoder):
+ with mock_encoded_video_dataset_file() as (mock_csv, expected, total_duration):
+ clip_sampler = make_clip_sampler("uniform", total_duration)
+ dataset = labeled_video_dataset(
+ data_path=mock_csv,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+ test_dataloader = DataLoader(dataset, batch_size=None, num_workers=2)
+
+ for _ in range(2):
+ actual = [
+ (sample["label"], sample["video"]) for sample in test_dataloader
+ ]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ @parameterized.expand(DECODER_LIST)
+ def test_video_name_with_whitespace_works(self, decoder):
+ num_frames = 10
+ fps = 5
+ with temp_encoded_video(num_frames=num_frames, fps=fps, prefix="pre fix") as (
+ video_file_name,
+ data,
+ ):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
+ f.write(f"{video_file_name} 0\n".encode())
+ f.write(f"{video_file_name} 1\n".encode())
+
+ total_duration = num_frames / fps
+ clip_sampler = make_clip_sampler("uniform", total_duration)
+ labeled_video_paths = LabeledVideoPaths.from_path(f.name)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ expected = [(0, data), (1, data)]
+ for i, sample in enumerate(dataset):
+ self.assertTrue(sample["video"].equal(expected[i][1]))
+ self.assertEqual(sample["label"], expected[i][0])
+
+ @parameterized.expand(DECODER_LIST)
+ def test_random_clip_sampling_works(self, decoder):
+ with mock_encoded_video_dataset_file() as (
+ mock_csv,
+ label_videos,
+ total_duration,
+ ):
+ half_duration = total_duration / 2
+ clip_sampler = make_clip_sampler("random", half_duration)
+ labeled_video_paths = LabeledVideoPaths.from_path(mock_csv)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ expected_labels = [label for label, _ in label_videos]
+ for i, sample in enumerate(dataset):
+ expected_t_shape = 5
+ self.assertEqual(sample["video"].shape[1], expected_t_shape)
+ self.assertEqual(sample["label"], expected_labels[i])
+
+ @parameterized.expand(DECODER_LIST)
+ def test_random_multi_clip_sampling_works(self, decoder):
+ with mock_encoded_video_dataset_file() as (
+ mock_csv,
+ label_videos,
+ total_duration,
+ ):
+ half_duration = total_duration / 2
+ num_clip = 3
+ clip_sampler = make_clip_sampler("random_multi", half_duration, num_clip)
+ labeled_video_paths = LabeledVideoPaths.from_path(mock_csv)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ expected_labels = [label for label, _ in label_videos]
+ for i, sample in enumerate(dataset):
+ expected_t_shape = 5
+ self.assertTrue(isinstance(sample["video"], list))
+ self.assertEqual(len(sample["video"]), num_clip)
+ self.assertEqual(sample["video"][0].shape[1], expected_t_shape)
+ self.assertEqual(sample["video"][-1].shape[1], expected_t_shape)
+ self.assertEqual(sample["label"], expected_labels[i])
+
+ @parameterized.expand(DECODER_LIST)
+ def test_reading_from_directory_structure_hmdb51(self, decoder):
+ # For an unknown reason this import has to be here for `buck test` to work.
+ import torchvision.io as io
+
+ with tempfile.TemporaryDirectory() as root_dir:
+
+ # Create test directory structure with two classes and a video in each.
+ root_dir_name = pathlib.Path(root_dir)
+ action_1 = "running"
+ action_2 = "cleaning_windows"
+
+ videos_root_dir = root_dir_name / "videos"
+ videos_root_dir.mkdir()
+
+ test_class_1 = videos_root_dir / action_1
+ test_class_1.mkdir()
+ data_1 = create_dummy_video_frames(15, 10, 10)
+ test_class_2 = videos_root_dir / action_2
+ test_class_2.mkdir()
+ data_2 = create_dummy_video_frames(20, 15, 15)
+
+ test_splits = root_dir_name / "folds"
+ test_splits.mkdir()
+
+ with tempfile.NamedTemporaryFile(
+ suffix="_u_nm_np1_ba_goo_19.avi", dir=test_class_1
+ ) as f_1, tempfile.NamedTemporaryFile(
+ suffix="_u_nm_np1_fr_med_1.avi", dir=test_class_2
+ ) as f_2:
+ f_1.close()
+ f_2.close()
+
+ # Write lossless video for each class.
+ io.write_video(
+ f_1.name,
+ data_1,
+ fps=30,
+ video_codec="libx264rgb",
+ options={"crf": "0"},
+ )
+ io.write_video(
+ f_2.name,
+ data_2,
+ fps=30,
+ video_codec="libx264rgb",
+ options={"crf": "0"},
+ )
+
+ _, video_name_1 = os.path.split(f_1.name)
+ _, video_name_2 = os.path.split(f_2.name)
+
+ with open(
+ os.path.join(test_splits, action_1 + "_test_split1.txt"), "w"
+ ) as f:
+ f.write(f"{video_name_1} 1\n")
+
+ with open(
+ os.path.join(test_splits, action_2 + "_test_split1.txt"), "w"
+ ) as f:
+ f.write(f"{video_name_2} 1\n")
+
+ clip_sampler = make_clip_sampler("uniform", 3)
+ dataset = Hmdb51(
+ data_path=test_splits,
+ video_path_prefix=root_dir_name / "videos",
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ split_id=1,
+ split_type="train",
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+                # The two samples may come back in either order, so normalize below
+                # so that sample_1 corresponds to action_1 ("running").
+ sample_1 = next(dataset)
+ sample_2 = next(dataset)
+
+ self.assertTrue(sample_1["label"] in [action_1, action_2])
+ if sample_1["label"] == action_2:
+ sample_1, sample_2 = sample_2, sample_1
+
+ self.assertEqual(sample_1["label"], action_1)
+ self.assertEqual(5, len(sample_1["meta_tags"]))
+ self.assertTrue(
+ sample_1["video"].equal(thwc_to_cthw(data_1).to(torch.float32))
+ )
+
+ self.assertEqual(sample_2["label"], action_2)
+ self.assertEqual(5, len(sample_2["meta_tags"]))
+ self.assertTrue(
+ sample_2["video"].equal(thwc_to_cthw(data_2).to(torch.float32))
+ )
+
+ @parameterized.expand(DECODER_LIST)
+ def test_constant_clips_per_video_sampling_works_with_fraction(self, decoder):
+ # Make one video with 15 frames and one with 10 frames, producing 3 clips and 2
+ # clips respectively.
+ num_frames = 10
+ fps = 5
+ with temp_encoded_video(num_frames=int(num_frames * 1.5), fps=fps) as (
+ video_file_name_1,
+ data_1,
+ ):
+ with temp_encoded_video(num_frames=num_frames, fps=fps) as (
+ video_file_name_2,
+ data_2,
+ ):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
+ f.write(f"{video_file_name_1} 0\n".encode())
+ f.write(f"{video_file_name_2} 1\n".encode())
+
+ clip_frames = 2
+ duration_for_frames = Fraction(clip_frames, fps)
+ clip_sampler = make_clip_sampler(
+ "constant_clips_per_video", duration_for_frames, 2
+ )
+ labeled_video_paths = LabeledVideoPaths.from_path(f.name)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+                # The dataset has 2 videos. From each video, two evenly spaced clips
+                # of clip_frames frames are sampled. The first clip of each video
+                # always starts at second 0; the second clip starts at the next frame
+                # after time (total_duration - clip_duration) / 2.
+ half_frames_1 = math.ceil((data_1.shape[1]) / 2)
+ half_frames_2 = math.ceil((data_2.shape[1]) / 2)
+ expected = [
+ (0, data_1[:, :clip_frames]),
+ (0, data_1[:, half_frames_1 : half_frames_1 + clip_frames]),
+ (1, data_2[:, :clip_frames]),
+ (1, data_2[:, half_frames_2 : half_frames_2 + clip_frames]),
+ ]
+ for i, sample in enumerate(dataset):
+ self.assertTrue(sample["video"].equal(expected[i][1]))
+ self.assertEqual(sample["label"], expected[i][0])
+
+ @parameterized.expand(DECODER_LIST)
+ def test_reading_from_directory_structure(self, decoder):
+ # For an unknown reason this import has to be here for `buck test` to work.
+ import torchvision.io as io
+
+ with tempfile.TemporaryDirectory() as root_dir:
+
+ # Create test directory structure with two classes and a video in each.
+ root_dir_name = pathlib.Path(root_dir)
+ test_class_1 = root_dir_name / "running"
+ test_class_1.mkdir()
+ data_1 = create_dummy_video_frames(15, 10, 10)
+ test_class_2 = root_dir_name / "cleaning windows"
+ test_class_2.mkdir()
+ data_2 = create_dummy_video_frames(20, 15, 15)
+ with tempfile.NamedTemporaryFile(
+ suffix=".mp4", dir=test_class_1
+ ) as f_1, tempfile.NamedTemporaryFile(
+ suffix=".mp4", dir=test_class_2
+ ) as f_2:
+ f_1.close()
+ f_2.close()
+
+ # Write lossless video for each class.
+ io.write_video(
+ f_1.name,
+ data_1,
+ fps=30,
+ video_codec="libx264rgb",
+ options={"crf": "0"},
+ )
+ io.write_video(
+ f_2.name,
+ data_2,
+ fps=30,
+ video_codec="libx264rgb",
+ options={"crf": "0"},
+ )
+
+ clip_sampler = make_clip_sampler("uniform", 3)
+ labeled_video_paths = LabeledVideoPaths.from_path(root_dir)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ # Videos are sorted alphabetically so "cleaning windows" (i.e. data_2)
+ # will be first.
+ sample_1 = next(dataset)
+ self.assertEqual(sample_1["label"], 0)
+ self.assertTrue(
+ sample_1["video"].equal(thwc_to_cthw(data_2).to(torch.float32))
+ )
+
+ sample_2 = next(dataset)
+ self.assertEqual(sample_2["label"], 1)
+ self.assertTrue(
+ sample_2["video"].equal(thwc_to_cthw(data_1).to(torch.float32))
+ )
+
+ def test_frame_video_dataset_works(self):
+ with mock_frame_video_dataset_file() as (mock_csv, expected, total_duration):
+ clip_sampler = make_clip_sampler("uniform", total_duration)
+ dataset = labeled_video_dataset(
+ data_path=mock_csv,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder="frame",
+ )
+
+ test_dataloader = DataLoader(dataset, batch_size=None, num_workers=2)
+
+ for _ in range(2):
+ actual = [
+ (sample["label"], sample["video"]) for sample in test_dataloader
+ ]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ @parameterized.expand(DECODER_LIST)
+ def test_random_video_sampler(self, decoder):
+ with mock_encoded_video_dataset_file() as (mock_csv, expected, total_duration):
+ clip_sampler = make_clip_sampler("uniform", total_duration)
+ dataset = labeled_video_dataset(
+ data_path=mock_csv,
+ clip_sampler=clip_sampler,
+ video_sampler=RandomSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ for _ in range(2):
+ actual = [(sample["label"], sample["video"]) for sample in dataset]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ @parameterized.expand(itertools.product([0, 1, 2], ["pyav", "torchvision"]))
+ def test_random_video_sampler_multiprocessing(self, num_workers, decoder):
+ with mock_encoded_video_dataset_file() as (mock_csv, expected, total_duration):
+ clip_sampler = make_clip_sampler("uniform", total_duration)
+ dataset = labeled_video_dataset(
+ data_path=mock_csv,
+ clip_sampler=clip_sampler,
+ video_sampler=RandomSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+ test_dataloader = DataLoader(
+ dataset, batch_size=None, num_workers=num_workers
+ )
+
+ for _ in range(2):
+ actual = [
+ (sample["label"], sample["video"]) for sample in test_dataloader
+ ]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ @parameterized.expand(DECODER_LIST)
+ def test_sampling_with_multiple_processes(self, decoder):
+ with mock_encoded_video_dataset_file() as (
+ mock_csv,
+ label_videos,
+ total_duration,
+ ):
+ half_duration = total_duration / 2
+ clip_sampler = make_clip_sampler("uniform", half_duration)
+ labeled_video_paths = LabeledVideoPaths.from_path(mock_csv)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ # Split each full video into two clips.
+ expected = []
+ for label, data in label_videos:
+ num_frames = data.shape[0]
+ half_frames = num_frames // 2
+ first_half_data = data[:, :half_frames]
+ second_half_data = data[:, half_frames:]
+ expected.append((label, first_half_data))
+ expected.append((label, second_half_data))
+
+ test_dataloader = DataLoader(dataset, batch_size=None, num_workers=2)
+ actual = [(sample["label"], sample["video"]) for sample in test_dataloader]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ @parameterized.expand(DECODER_LIST)
+ def test_sampling_with_non_divisible_processes_by_videos(self, decoder):
+ with mock_encoded_video_dataset_file() as (
+ mock_csv,
+ label_videos,
+ total_duration,
+ ):
+ half_duration = total_duration / 2
+ clip_sampler = make_clip_sampler("uniform", half_duration)
+ labeled_video_paths = LabeledVideoPaths.from_path(mock_csv)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ # Split each full video into two clips.
+ expected = []
+ for label, data in label_videos:
+ num_frames = data.shape[0]
+ half_frames = num_frames // 2
+ first_half_data = data[:, :half_frames]
+ second_half_data = data[:, half_frames:]
+ expected.append((label, first_half_data))
+ expected.append((label, second_half_data))
+
+ test_dataloader = DataLoader(dataset, batch_size=None, num_workers=4)
+ actual = [(sample["label"], sample["video"]) for sample in test_dataloader]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ @parameterized.expand(DECODER_LIST)
+ def test_sampling_with_more_processes_than_videos(self, decoder):
+ with mock_encoded_video_dataset_file() as (
+ mock_csv,
+ label_videos,
+ total_duration,
+ ):
+ half_duration = total_duration / 2
+ clip_sampler = make_clip_sampler("uniform", half_duration)
+ labeled_video_paths = LabeledVideoPaths.from_path(mock_csv)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ # Split each full video into two clips.
+ expected = []
+ for label, data in label_videos:
+ num_frames = data.shape[0]
+ half_frames = num_frames // 2
+ first_half_data = data[:, :half_frames]
+ second_half_data = data[:, half_frames:]
+ expected.append((label, first_half_data))
+ expected.append((label, second_half_data))
+
+ test_dataloader = DataLoader(dataset, batch_size=None, num_workers=16)
+ actual = [(sample["label"], sample["video"]) for sample in test_dataloader]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ @parameterized.expand(DECODER_LIST)
+ def test_sampling_with_non_divisible_processes_by_clips(self, decoder):
+
+ # Make one video with 15 frames and one with 10 frames, producing 3 clips and 2
+ # clips respectively.
+ num_frames = 10
+ fps = 5
+ with temp_encoded_video(num_frames=int(num_frames * 1.5), fps=fps) as (
+ video_file_name_1,
+ data_1,
+ ):
+ with temp_encoded_video(num_frames=num_frames, fps=fps) as (
+ video_file_name_2,
+ data_2,
+ ):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
+ f.write(f"{video_file_name_1} 0\n".encode())
+ f.write(f"{video_file_name_2} 1\n".encode())
+
+ total_duration = num_frames / fps
+ half_duration = total_duration / 2
+ clip_sampler = make_clip_sampler("uniform", half_duration)
+ labeled_video_paths = LabeledVideoPaths.from_path(f.name)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+
+ half_frames = num_frames // 2
+ expected = {
+                    (0, data_1[:, half_frames * 2 :]),  # 3/3 clip
+                    (0, data_1[:, half_frames : half_frames * 2]),  # 2/3 clip
+                    (0, data_1[:, :half_frames]),  # 1/3 clip
+ (1, data_2[:, :half_frames]), # First half
+ (1, data_2[:, half_frames:]), # Second half
+ }
+
+ test_dataloader = DataLoader(dataset, batch_size=None, num_workers=2)
+ actual = [
+ (sample["label"], sample["video"]) for sample in test_dataloader
+ ]
+ assert_unordered_list_compare_true(self, expected, actual)
+
+ def test_multi_process_sampler(self):
+ # Test coverage ignores multi-process lines of code so we need to mock out
+ # the multiprocess environment information to test in a single process.
+ with patch("torch.utils.data.get_worker_info") as get_worker_info:
+ get_worker_info.return_value = Mock(id=2, num_workers=3)
+ inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ dataset = TensorDataset(inps, tgts)
+ sampler = iter(MultiProcessSampler(SequentialSampler(dataset)))
+
+ # Sampler indices will be split into 3. The last worker (id=2) will have the
+ # last 3 indices (7, 8, 9).
+ self.assertEqual(list(sampler), [7, 8, 9])
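+            # The 10 sequential indices are partitioned across the 3 mocked workers;
+            # only the last worker's share is asserted here, while the first two
+            # workers split the remaining indices 0-6 between them.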
+
+ @parameterized.expand(DECODER_LIST)
+ def test_sampling_with_distributed_sampler(self, decoder):
+
+ # Make one video with 15 frames and one with 10 frames, producing 3 clips and 2
+ # clips respectively.
+ num_frames = 10
+ fps = 5
+ with temp_encoded_video(num_frames=int(num_frames * 1.5), fps=fps) as (
+ video_file_name_1,
+ data_1,
+ ):
+ with temp_encoded_video(num_frames=num_frames, fps=fps) as (
+ video_file_name_2,
+ data_2,
+ ):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
+ f.write(f"{video_file_name_1} 0\n".encode())
+ f.write(f"{video_file_name_2} 1\n".encode())
+
+ total_duration = num_frames / fps
+ half_duration = total_duration / 2
+
+ # Create several processes initialized in a PyTorch distributed process
+ # group so that distributed sampler is setup correctly when dataset is
+ # constructed.
+ num_processes = 2
+ processes = []
+ return_dict = multiprocessing.Manager().dict()
+ for rank in range(num_processes):
+ p = Process(
+ target=run_distributed,
+ args=(
+ rank,
+ num_processes,
+ decoder,
+ half_duration,
+ f.name,
+ return_dict,
+ ),
+ )
+ p.start()
+ processes.append(p)
+
+ for p in processes:
+ p.join()
+
+ # After joining all distributed processes we expect all these label,
+ # video pairs to be returned in random order.
+ half_frames = num_frames // 2
+ expected = {
+ (0, data_1[:, :half_frames]), # 1/3 clip
+ (0, data_1[:, half_frames : half_frames * 2]), # 2/3 clip
+ (0, data_1[:, half_frames * 2 :]), # 3/3 clip
+ (1, data_2[:, :half_frames]), # First half
+ (1, data_2[:, half_frames:]), # Second half
+ }
+
+ epoch_results = collections.defaultdict(list)
+ for v in return_dict.values():
+ for k_2, v_2 in v.items():
+ epoch_results[k_2].extend(v_2)
+
+ assert_unordered_list_compare_true(
+ self, expected, epoch_results["epoch_1"]
+ )
+ assert_unordered_list_compare_true(
+ self, expected, epoch_results["epoch_2"]
+ )
+
+
+def assert_unordered_list_compare_true(
+ self,
+ expected: List[Tuple[int, torch.Tensor]],
+ actual: List[Tuple[int, torch.Tensor]],
+):
+ """
+    Asserts that every (label, clip) tuple in expected is found in actual and that
+    the lengths are equal.
+    """
+    expected_str = str([(label, clip.shape) for label, clip in expected])
+    actual_str = str([(label, clip.shape) for label, clip in actual])
+    failure_str = f"Expected set: {expected_str}\n actual set: {actual_str}"
+    self.assertTrue(unordered_list_compare(expected, actual), msg=failure_str)
+
+
+def unordered_list_compare(
+ expected: List[Tuple[int, torch.Tensor]], actual: List[Tuple[int, torch.Tensor]]
+):
+ """
+ Returns:
+        True if every (label, clip) tuple in expected is found in actual and the
+        lengths are equal.
+ """
+ if len(actual) != len(expected):
+ return False
+
+ for expected_x in expected:
+
+ # Uses torch comparator for Tensor.
+ if not any(
+ actual_x[0] == expected_x[0] and actual_x[1].equal(expected_x[1])
+ for actual_x in actual
+ ):
+ return False
+
+ return True
+
+
+def run_distributed(rank, size, decoder, clip_duration, data_name, return_dict):
+ """
+ This function is run by each distributed process. It samples videos
+ based on the distributed split (determined by the
+ DistributedSampler) and returns the dataset clips in the return_dict.
+ """
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
+ os.environ["MASTER_PORT"] = "29500"
+ dist.init_process_group("gloo", rank=rank, world_size=size)
+ clip_sampler = make_clip_sampler("uniform", clip_duration)
+ labeled_video_paths = LabeledVideoPaths.from_path(data_name)
+ dataset = LabeledVideoDataset(
+ labeled_video_paths,
+ clip_sampler=clip_sampler,
+ video_sampler=DistributedSampler,
+ decode_audio=False,
+ decoder=decoder,
+ )
+ test_dataloader = DataLoader(dataset, batch_size=None, num_workers=1)
+
+ # Run two epochs, simulating use in a training loop
+ dataset.video_sampler.set_epoch(0)
+ epoch_1 = [(sample["label"], sample["video"]) for sample in test_dataloader]
+ dataset.video_sampler.set_epoch(1)
+ epoch_2 = [(sample["label"], sample["video"]) for sample in test_dataloader]
+ return_dict[rank] = {"epoch_1": epoch_1, "epoch_2": epoch_2}
+
+
+@contextlib.contextmanager
+def mock_encoded_video_dataset_file():
+ """
+    Creates a temporary mock encoded video dataset with 4 entries labeled 0 - 3.
+    Returns the path of a labeled video file which points to this mock encoded video
+    dataset, the ordered (label, video) tuples and the video duration in seconds.
+ """
+ num_frames = 10
+ fps = 5
+ with temp_encoded_video(num_frames=num_frames, fps=fps) as (
+ video_file_name_1,
+ data_1,
+ ):
+ with temp_encoded_video(num_frames=num_frames, fps=fps) as (
+ video_file_name_2,
+ data_2,
+ ):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
+ f.write(f"{video_file_name_1} 0\n".encode())
+ f.write(f"{video_file_name_2} 1\n".encode())
+ f.write(f"{video_file_name_1} 2\n".encode())
+ f.write(f"{video_file_name_2} 3\n".encode())
+
+ label_videos = [
+ (0, data_1),
+ (1, data_2),
+ (2, data_1),
+ (3, data_2),
+ ]
+ video_duration = num_frames / fps
+ yield f.name, label_videos, video_duration
+
+
+@contextlib.contextmanager
+def mock_frame_video_dataset_file():
+ """
+    Creates a temporary mock frame video dataset with 4 entries labeled 0 - 3.
+    Returns the path of a labeled video file which points to this mock frame video
+    dataset, the ordered (label, video) tuples and the video duration in seconds.
+ """
+ with temp_frame_video_dataset() as (_, videos):
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
+ f.write(f"{videos[0][0]} 0\n".encode())
+ f.write(f"{videos[1][0]} 1\n".encode())
+ f.write(f"{videos[0][0]} 2\n".encode())
+ f.write(f"{videos[1][0]} 3\n".encode())
+
+ label_videos = [
+ (0, videos[0][-2]),
+ (1, videos[1][-2]),
+            (2, videos[0][-2]),
+            (3, videos[1][-2]),
+ ]
+
+ min_duration = min(videos[0][-1], videos[1][-1])
+ yield f.name, label_videos, min_duration
diff --git a/code/pytorchvideo/tests/test_data_ssv2_dataset.py b/code/pytorchvideo/tests/test_data_ssv2_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e32319d6a356e80a11b3b57677bc289ffc1b370d
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_ssv2_dataset.py
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import contextlib
+import json
+import pathlib
+import tempfile
+import unittest
+
+from pytorchvideo.data import SSv2
+from pytorchvideo.data.clip_sampling import make_clip_sampler
+from torch.utils.data import SequentialSampler
+from utils import temp_frame_video
+
+
+@contextlib.contextmanager
+def temp_ssv2_dataset():
+ frame_names = [f"{str(i)}.png" for i in range(7)]
+
+ # Create json file for label names.
+ labels = [
+ "Approaching something with your camera",
+ "Attaching something to something",
+ ]
+ label_names = {labels[0]: "0", labels[1]: "1"}
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
+ json.dump(label_names, f)
+ label_name_file = f.name
+
+ # Create csv containing 2 test frame videos.
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as f:
+ f.write("original_vido_id video_id frame_id path labels\n".encode())
+
+ # Frame video 1
+ with temp_frame_video(frame_names) as (frame_1_video_dir, data_1):
+ for i, frame_name in enumerate(frame_names):
+ original_video_id = str(frame_1_video_dir)
+ video_id = "1"
+ frame_id = str(i)
+ path = pathlib.Path(frame_1_video_dir) / frame_name
+ f.write(
+ f"{original_video_id} {video_id} {frame_id} {path} ''\n".encode()
+ )
+
+ # Frame video 2
+ with temp_frame_video(frame_names) as (frame_2_video_dir, data_2):
+ for i, frame_name in enumerate(frame_names):
+ original_video_id = str(frame_2_video_dir)
+ video_id = "2"
+ frame_id = str(i)
+ path = pathlib.Path(frame_2_video_dir) / frame_name
+ f.write(
+ f"{original_video_id} {video_id} {frame_id} {path} ''\n".encode()
+ )
+
+ f.close()
+ video_path_file = f.name
+
+            # Create json file mapping each video to its label.
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".json"
+ ) as f:
+ videos = [
+ {"id": str(frame_1_video_dir), "template": labels[0]},
+ {"id": str(frame_2_video_dir), "template": labels[1]},
+ ]
+ json.dump(videos, f)
+ video_label_file = f.name
+
+ yield label_name_file, video_label_file, video_path_file, data_1, data_2
+
+
+class TestSSv2Dataset(unittest.TestCase):
+ def test_single_clip_per_video_works(self):
+ with temp_ssv2_dataset() as (
+ label_name_file,
+ video_label_file,
+ video_path_file,
+ video_1,
+ video_2,
+ ):
+
+            # Use an arbitrary clip duration since SSv2 always consumes the full video.
+            clip_sampler = make_clip_sampler("constant_clips_per_video", 1.0, 1)
+            # Expect 2 frames to be sampled (indices 1 and 4 of the 7 frames).
+ dataset = SSv2(
+ label_name_file,
+ video_label_file,
+ video_path_file,
+ clip_sampler=clip_sampler,
+ video_sampler=SequentialSampler,
+ frames_per_clip=2,
+ )
+ expected = [(0, video_1), (1, video_2)]
+ for sample, expected_sample in zip(dataset, expected):
+ self.assertEqual(sample["label"], expected_sample[0])
+ self.assertTrue(sample["video"].equal(expected_sample[1][:, (1, 4)]))
diff --git a/code/pytorchvideo/tests/test_data_utils.py b/code/pytorchvideo/tests/test_data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..534e7792953e39e289fb3c471fe96ee50d6ac644
--- /dev/null
+++ b/code/pytorchvideo/tests/test_data_utils.py
@@ -0,0 +1,175 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+import tempfile
+import unittest
+import unittest.mock
+from dataclasses import dataclass
+from pathlib import Path
+
+from pytorchvideo.data.utils import (
+ DataclassFieldCaster,
+ export_video_array,
+ load_dataclass_dict_from_csv,
+ save_dataclass_objs_to_headered_csv,
+)
+from pytorchvideo.data.video import VideoPathHandler
+from utils import temp_encoded_video
+
+
+@dataclass
+class TestDataclass(DataclassFieldCaster):
+ a: str
+ b: int
+ b_plus_1: int = DataclassFieldCaster.complex_initialized_dataclass_field(
+ lambda v: int(v) + 1
+ )
+ c: float
+ d: list
+ e: dict = DataclassFieldCaster.complex_initialized_dataclass_field(lambda v: {v: v})
+
+
+@dataclass
+class TestDataclass2(DataclassFieldCaster):
+ a: str
+ b: int
+
+
+class TestDataUtils(unittest.TestCase):
+ def test_DataclassFieldCaster(self):
+ test_obj = TestDataclass("1", "1", "1", "1", "abc", "k")
+
+ self.assertEqual(test_obj.a, "1")
+ self.assertEqual(type(test_obj.a), str)
+
+ self.assertEqual(test_obj.b, 1)
+ self.assertEqual(type(test_obj.b), int)
+ self.assertEqual(test_obj.b_plus_1, 2)
+
+ self.assertEqual(test_obj.c, 1.0)
+ self.assertEqual(type(test_obj.c), float)
+
+ self.assertEqual(test_obj.d, ["a", "b", "c"])
+ self.assertEqual(type(test_obj.d), list)
+
+ self.assertEqual(test_obj.e, {"k": "k"})
+ self.assertEqual(type(test_obj.e), dict)
+
+ def _export_video_array(
+ self,
+ video_codec="libx264rgb",
+ height=10,
+ width=10,
+ num_frames=10,
+ fps=5,
+ options=None,
+ epsilon=3,
+ ):
+ with temp_encoded_video(
+ num_frames=num_frames, fps=fps, height=height, width=width
+ ) as (video_file_name, data,), tempfile.TemporaryDirectory(
+ prefix="video_stop_gap_test"
+ ) as tempdir:
+ exported_video_path = os.path.join(tempdir, "video.mp4")
+ export_video_array(
+ data,
+ output_path=exported_video_path,
+ rate=fps,
+ video_codec=video_codec,
+ options=options,
+ )
+ vp_handler = VideoPathHandler()
+ video = vp_handler.video_from_path(exported_video_path, decode_audio=False)
+ reloaded_data = video.get_clip(0, video.duration)["video"]
+ self.assertLessEqual((data - reloaded_data).abs().mean(), epsilon)
+
+ def test_export_video_array_mult(self):
+ self._export_video_array(
+ video_codec="libx264rgb",
+ height=10,
+ width=10,
+ num_frames=10,
+ fps=5,
+ options={"crf": "0"},
+ epsilon=1e-6,
+ )
+ self._export_video_array(
+ video_codec="mpeg4", height=10, width=10, num_frames=10, fps=5
+ )
+ self._export_video_array(
+ video_codec="mpeg4", height=480, width=640, num_frames=30, fps=30
+ )
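+        # The lossless libx264rgb/crf=0 round trip is checked with a tight epsilon
+        # (1e-6), while the lossy mpeg4 encodes fall back to the helper's default
+        # epsilon of 3 on 0-255 pixel values.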
+
+ def test_load_dataclass_dict_from_csv_value_dict(self):
+ dataclass_objs = [
+ TestDataclass2("a", 1),
+ TestDataclass2("b", 2),
+ TestDataclass2("c", 3),
+ TestDataclass2("d", 4),
+ ]
+ with tempfile.TemporaryDirectory(prefix=f"{TestDataUtils}") as tempdir:
+ csv_file_name = Path(tempdir) / "data.csv"
+ save_dataclass_objs_to_headered_csv(dataclass_objs, csv_file_name)
+
+ test_dict = load_dataclass_dict_from_csv(
+ csv_file_name, TestDataclass2, "a", list_per_key=False
+ )
+ self.assertEqual(len(test_dict), 4)
+ self.assertEqual(test_dict["c"].b, 3)
+
+ def test_load_dataclass_dict_from_csv_list_dict(self):
+ dataclass_objs = [
+ TestDataclass2("a", 1),
+ TestDataclass2("a", 2),
+ TestDataclass2("b", 3),
+ TestDataclass2("c", 4),
+ TestDataclass2("c", 4),
+ TestDataclass2("c", 4),
+ ]
+ with tempfile.TemporaryDirectory(prefix=f"{TestDataUtils}") as tempdir:
+ csv_file_name = Path(tempdir) / "data.csv"
+ save_dataclass_objs_to_headered_csv(dataclass_objs, csv_file_name)
+ test_dict = load_dataclass_dict_from_csv(
+ csv_file_name, TestDataclass2, "a", list_per_key=True
+ )
+ self.assertEqual(len(test_dict), 3)
+ self.assertEqual([x.b for x in test_dict["a"]], [1, 2])
+ self.assertEqual([x.b for x in test_dict["b"]], [3])
+ self.assertEqual([x.b for x in test_dict["c"]], [4, 4, 4])
+
+ def test_load_dataclass_dict_from_csv_throws(self):
+ dataclass_objs = [
+ TestDataclass2("a", 1),
+ TestDataclass2("a", 2),
+ TestDataclass2("b", 3),
+ TestDataclass2("c", 4),
+ TestDataclass2("c", 4),
+ TestDataclass2("c", 4),
+ ]
+ with tempfile.TemporaryDirectory(prefix=f"{TestDataUtils}") as tempdir:
+ csv_file_name = Path(tempdir) / "data.csv"
+ save_dataclass_objs_to_headered_csv(dataclass_objs, csv_file_name)
+ self.assertRaises(
+ AssertionError,
+ lambda: load_dataclass_dict_from_csv(
+ csv_file_name, TestDataclass2, "a", list_per_key=False
+ ),
+ )
+
+ def test_save_dataclass_objs_to_headered_csv(self):
+ dataclass_objs = [
+ TestDataclass2("a", 1),
+ TestDataclass2("a", 2),
+ TestDataclass2("b", 3),
+ ]
+
+ with tempfile.TemporaryDirectory(prefix=f"{TestDataUtils}") as tempdir:
+ csv_file_name = Path(tempdir) / "data.csv"
+ save_dataclass_objs_to_headered_csv(dataclass_objs, csv_file_name)
+ with open(csv_file_name) as f:
+ lines = list(f.readlines())
+ self.assertEqual(len(lines), 4)
+ self.assertEqual(lines[0], "a,b\n")
+ self.assertEqual(lines[1], "a,1\n")
+ self.assertEqual(lines[2], "a,2\n")
+ self.assertEqual(lines[3], "b,3\n")
diff --git a/code/pytorchvideo/tests/test_fuse_bn.py b/code/pytorchvideo/tests/test_fuse_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..eefa1a26a84a322b5b8310019d92ddc31dc9c29e
--- /dev/null
+++ b/code/pytorchvideo/tests/test_fuse_bn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import torch
+from pytorchvideo.models.vision_transformers import (
+ create_multiscale_vision_transformers,
+)
+
+
+class TestFuseBN(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_fuse_bn(self):
+ model = create_multiscale_vision_transformers(
+ spatial_size=224,
+ temporal_size=8,
+ norm="batchnorm",
+ embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
+ atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
+ pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+ pool_kv_stride_adaptive=[1, 8, 8],
+ pool_kvq_kernel=[3, 3, 3],
+ cls_embed_on=False,
+ )
+
+ for blk in model.blocks:
+ blk.norm1 = rand_init_bn(blk.norm1)
+ blk.norm2 = rand_init_bn(blk.norm2)
+ if blk.attn.norm_q:
+ blk.attn.norm_q = rand_init_bn(blk.attn.norm_q)
+ if blk.attn.norm_k:
+ blk.attn.norm_k = rand_init_bn(blk.attn.norm_k)
+ if blk.attn.norm_v:
+ blk.attn.norm_v = rand_init_bn(blk.attn.norm_v)
+
+ model.eval()
+
+ x = torch.randn((4, 3, 8, 224, 224))
+ expected_output = model(x)
+ model.fuse_bn()
+ output = model(x)
+ self.assertTrue(torch.all(torch.isclose(output, expected_output, atol=1e-5)))
+ self.assertTrue(
+ len(
+ [
+ layer
+ for layer in model.modules()
+ if isinstance(layer, (torch.nn.BatchNorm1d, torch.nn.BatchNorm3d))
+ ]
+ )
+ == 0
+ )
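+        # After fuse_bn() the BatchNorm statistics are folded into the surrounding
+        # layers, so the fused model reproduces the original outputs (within
+        # atol=1e-5) and no BatchNorm1d/BatchNorm3d modules remain.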
+
+
+def rand_init_bn(bn):
+ bn.weight.data.uniform_(0.5, 1.5)
+ bn.bias.data.uniform_(-0.5, 0.5)
+ bn.running_var.data.uniform_(0.5, 1.5)
+ bn.running_mean.data.uniform_(-0.5, 0.5)
+ return bn
diff --git a/code/pytorchvideo/tests/test_layers_attention.py b/code/pytorchvideo/tests/test_layers_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..088a8bbb023468bd656ebc8cfc28d5a236cf258c
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_attention.py
@@ -0,0 +1,191 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers import Mlp, MultiScaleAttention, MultiScaleBlock
+
+
+class TestMLP(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_MultiScaleAttention(self):
+ seq_len = 21
+ c_dim = 10
+ c_dim_out = 20
+
+ # Test MultiScaleAttention without dim expansion; i.e., no dim_out
+ multiscale_attention = MultiScaleAttention(c_dim, num_heads=2)
+ fake_input = torch.rand(8, seq_len, c_dim)
+ input_shape = (2, 2, 5)
+ output, output_shape = multiscale_attention(fake_input, input_shape)
+        self.assertEqual(output.shape, fake_input.shape)
+
+ # Test MultiScaleAttention with dim expansion
+ multiscale_attention = MultiScaleAttention(
+ c_dim, dim_out=c_dim_out, num_heads=2
+ )
+ fake_input = torch.rand(8, seq_len, c_dim)
+ input_shape = (2, 2, 5)
+ output, output_shape = multiscale_attention(fake_input, input_shape)
+ gt_shape_tensor = torch.rand(8, seq_len, c_dim_out)
+        self.assertEqual(output.shape, gt_shape_tensor.shape)
+
+ # Test pooling kernel without dim expansion.
+ multiscale_attention = MultiScaleAttention(
+ c_dim,
+ num_heads=2,
+ stride_q=(2, 2, 1),
+ )
+ output, output_shape = multiscale_attention(fake_input, input_shape)
+ gt_shape_tensor = torch.rand(8, 6, c_dim)
+        gt_output_shape = [1, 1, 5]
+        self.assertEqual(output.shape, gt_shape_tensor.shape)
+        self.assertEqual(output_shape, gt_output_shape)
+
+ # Test pooling kernel with dim expansion.
+ multiscale_attention = MultiScaleAttention(
+ c_dim,
+ dim_out=c_dim_out,
+ num_heads=2,
+ stride_q=(2, 2, 1),
+ )
+ output, output_shape = multiscale_attention(fake_input, input_shape)
+ gt_shape_tensor = torch.rand(8, 6, c_dim_out)
+        gt_output_shape = [1, 1, 5]
+        self.assertEqual(output.shape, gt_shape_tensor.shape)
+        self.assertEqual(output_shape, gt_output_shape)
+
+ # Test pooling kernel with no cls.
+ seq_len = 20
+ c_dim = 10
+ fake_input = torch.rand(8, seq_len, c_dim)
+ multiscale_attention = MultiScaleAttention(
+ c_dim, num_heads=2, stride_q=(2, 2, 1), has_cls_embed=False
+ )
+ output, output_shape = multiscale_attention(fake_input, input_shape)
+ gt_shape_tensor = torch.rand(8, int(seq_len / 2 / 2), c_dim)
+ gt_output_shape = [1, 1, 5]
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+ self.assertEqual(output_shape, gt_output_shape)
+
+ def test_MultiScaleBlock(self):
+ seq_len = 21
+ c_dim = 10
+ batch_dim = 8
+ fake_input = torch.rand(batch_dim, seq_len, c_dim)
+ input_shape = (2, 2, 5)
+
+ # Change of output dimension.
+ block = MultiScaleBlock(10, 20, 2)
+ output, output_shape = block(fake_input, input_shape)
+ gt_shape_tensor = torch.rand(8, seq_len, 20)
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+ self.assertEqual(output_shape, input_shape)
+
+ # Test dimension multiplication in attention
+ block = MultiScaleBlock(10, 20, 2, dim_mul_in_att=True)
+ output, output_shape = block(fake_input, input_shape)
+ gt_shape_tensor = torch.rand(8, seq_len, 20)
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+ self.assertEqual(output_shape, input_shape)
+
+ # Test pooling.
+ block = MultiScaleBlock(10, 20, 2, stride_q=(2, 2, 1))
+ output, output_shape = block(fake_input, input_shape)
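+ # stride_q=(2, 2, 1) pools the (2, 2, 5) token grid down to (1, 1, 5); the +1 below accounts for the cls token.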
+ gt_shape_tensor = torch.rand(8, int((seq_len - 1) / 2 / 2) + 1, 20)
+ gt_out_shape = [1, 1, 5]
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+ self.assertEqual(output_shape, gt_out_shape)
+
+ def test_Mlp(self):
+ fake_input = torch.rand((8, 64))
+ in_features = [10, 20, 30]
+ hidden_features = [10, 20, 20]
+ out_features = [10, 20, 30]
+ act_layers = [nn.GELU, nn.ReLU, nn.Sigmoid]
+ drop_rates = [0.0, 0.1, 0.5]
+ batch_size = 8
+ for in_feat, hidden_feat, out_feat, act_layer, drop_rate in itertools.product(
+ in_features, hidden_features, out_features, act_layers, drop_rates
+ ):
+ mlp_block = Mlp(
+ in_features=in_feat,
+ hidden_features=hidden_feat,
+ out_features=out_feat,
+ act_layer=act_layer,
+ dropout_rate=drop_rate,
+ )
+ fake_input = torch.rand((batch_size, in_feat))
+ output = mlp_block(fake_input)
+ self.assertEqual(output.shape, torch.Size([batch_size, out_feat]))
+
+ def test_MultiScaleBlock_is_scriptable(self):
+ iter_qkv_bias = [True, False]
+ iter_separate_qkv = [True, False]
+ iter_dropout_rate = [0.0, 0.1]
+ iter_droppath_rate = [0.0, 0.1]
+ iter_norm_layer = [nn.LayerNorm]
+ iter_attn_norm_layer = [nn.LayerNorm]
+ iter_dim_mul_in_att = [True, False]
+ iter_pool_mode = ["conv", "avg", "max"]
+ iter_has_cls_embed = [True, False]
+ iter_pool_first = [True, False]
+ iter_residual_pool = [True, False]
+ iter_depthwise_conv = [True, False]
+ iter_bias_on = [True, False]
+
+ for (
+ qkv_bias,
+ dropout_rate,
+ droppath_rate,
+ norm_layer,
+ attn_norm_layer,
+ dim_mul_in_att,
+ pool_mode,
+ has_cls_embed,
+ pool_first,
+ residual_pool,
+ depthwise_conv,
+ bias_on,
+ separate_qkv,
+ ) in itertools.product(
+ iter_qkv_bias,
+ iter_dropout_rate,
+ iter_droppath_rate,
+ iter_norm_layer,
+ iter_attn_norm_layer,
+ iter_dim_mul_in_att,
+ iter_pool_mode,
+ iter_has_cls_embed,
+ iter_pool_first,
+ iter_residual_pool,
+ iter_depthwise_conv,
+ iter_bias_on,
+ iter_separate_qkv,
+ ):
+ msb = MultiScaleBlock(
+ dim=10,
+ dim_out=20,
+ num_heads=2,
+ stride_q=(2, 2, 1),
+ qkv_bias=qkv_bias,
+ dropout_rate=dropout_rate,
+ droppath_rate=droppath_rate,
+ norm_layer=norm_layer,
+ attn_norm_layer=attn_norm_layer,
+ dim_mul_in_att=dim_mul_in_att,
+ pool_mode=pool_mode,
+ has_cls_embed=has_cls_embed,
+ pool_first=pool_first,
+ residual_pool=residual_pool,
+ depthwise_conv=depthwise_conv,
+ bias_on=bias_on,
+ separate_qkv=separate_qkv,
+ )
+ torch.jit.script(msb)
diff --git a/code/pytorchvideo/tests/test_layers_convolutions.py b/code/pytorchvideo/tests/test_layers_convolutions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f79129a533a3c753ec24bd0d653aeb3a81e73e2e
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_convolutions.py
@@ -0,0 +1,219 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+
+import numpy as np
+import torch
+from pytorchvideo.layers.convolutions import (
+ Conv2plus1d,
+ ConvReduce3D,
+ create_conv_2plus1d,
+)
+from torch import nn
+
+
+class TestConvReduce3D(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_stack_conv(self):
+ """
+ Test ConvReduce3D.
+ """
+ for input_dim, output_dim in itertools.product((2, 4), (4, 8, 16)):
+ model = ConvReduce3D(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=((1, 1, 1), (3, 3, 3), (1, 3, 3)),
+ stride=((1, 1, 1), (1, 1, 1), None),
+ padding=((0, 0, 0), (1, 1, 1), (0, 1, 1)),
+ dilation=((2, 2, 2), (1, 1, 1), None),
+ groups=(1, 2, None),
+ bias=(True, False, None),
+ )
+ model_gt_list = [
+ nn.Conv3d(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=(1, 1, 1),
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ dilation=(2, 2, 2),
+ groups=1,
+ bias=True,
+ ),
+ nn.Conv3d(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=(3, 3, 3),
+ stride=(1, 1, 1),
+ padding=(1, 1, 1),
+ dilation=(1, 1, 1),
+ groups=2,
+ bias=False,
+ ),
+ nn.Conv3d(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=(1, 3, 3),
+ padding=(0, 1, 1),
+ ),
+ ]
+ model.convs[0].load_state_dict(
+ model_gt_list[0].state_dict(), strict=True
+ ) # explicitly use strict mode.
+ model.convs[1].load_state_dict(
+ model_gt_list[1].state_dict(), strict=True
+ ) # explicitly use strict mode.
+ model.convs[2].load_state_dict(
+ model_gt_list[2].state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for tensor in TestConvReduce3D._get_inputs(input_dim):
+ if tensor.shape[1] != input_dim:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(tensor)
+ continue
+ else:
+ output_tensor = model(tensor)
+ output_gt = []
+ for ind in range(3):
+ output_gt.append(model_gt_list[ind](tensor))
+ output_tensor_gt = torch.stack(output_gt, dim=0).sum(
+ dim=0, keepdim=False
+ )
+
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+
+ @staticmethod
+ def _get_inputs(input_dim: int = 3) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random tensor as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, input_dim, 3, 7, 7),
+ (1, input_dim, 5, 7, 7),
+ (1, input_dim, 7, 7, 7),
+ (2, input_dim, 3, 7, 7),
+ (4, input_dim, 3, 7, 7),
+ (8, input_dim, 3, 7, 7),
+ (2, input_dim, 3, 7, 14),
+ (2, input_dim, 3, 14, 7),
+ (2, input_dim, 3, 14, 14),
+ # Forward failed.
+ (8, input_dim * 2, 3, 7, 7),
+ (8, input_dim * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
+
+
+class TestConv2plus1d(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_2plus1d_conv(self):
+ """
+ Test Conv2plus1d.
+ """
+ for input_dim, output_dim in itertools.product((2, 4), (4, 8, 16)):
+ model = Conv2plus1d(
+ conv_t=nn.Conv3d(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=(3, 1, 1),
+ stride=(2, 1, 1),
+ padding=(1, 0, 0),
+ bias=False,
+ ),
+ norm=nn.BatchNorm3d(output_dim),
+ activation=nn.ReLU(),
+ conv_xy=nn.Conv3d(
+ in_channels=output_dim,
+ out_channels=output_dim,
+ kernel_size=(1, 3, 3),
+ stride=(1, 2, 2),
+ padding=(0, 1, 1),
+ bias=False,
+ ),
+ )
+
+ model_gt = create_conv_2plus1d(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=(3, 3, 3),
+ stride=(2, 2, 2),
+ padding=(1, 1, 1),
+ bias=False,
+ norm=nn.BatchNorm3d,
+ norm_eps=1e-5,
+ norm_momentum=0.1,
+ activation=nn.ReLU,
+ )
+
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestConv2plus1d._get_inputs():
+ with torch.no_grad():
+ if input_tensor.shape[1] != input_dim:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ @staticmethod
+ def _get_inputs(input_dim: int = 3) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random tensor as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, input_dim, 3, 7, 7),
+ (1, input_dim, 5, 7, 7),
+ (1, input_dim, 7, 7, 7),
+ (2, input_dim, 3, 7, 7),
+ (4, input_dim, 3, 7, 7),
+ (8, input_dim, 3, 7, 7),
+ (2, input_dim, 3, 7, 14),
+ (2, input_dim, 3, 14, 7),
+ (2, input_dim, 3, 14, 14),
+ # Forward failed.
+ (8, input_dim * 2, 3, 7, 7),
+ (8, input_dim * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_layers_drop_path.py b/code/pytorchvideo/tests/test_layers_drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd84a02b23e8bd92129a3d06dfbb7c563f712793
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_drop_path.py
@@ -0,0 +1,24 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import torch
+from pytorchvideo.layers import DropPath
+
+
+class TestDropPath(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_dropPath(self):
+ # Input should be same if drop_prob = 0.
+ net_drop_path = DropPath(drop_prob=0.0)
+ fake_input = torch.rand(64, 10, 20)
+ output = net_drop_path(fake_input)
+ self.assertTrue(output.equal(fake_input))
+ # Test when drop_prob > 0.
+ net_drop_path = DropPath(drop_prob=0.5)
+ fake_input = torch.rand(64, 10, 20)
+ output = net_drop_path(fake_input)
+ self.assertEqual(output.shape, fake_input.shape)
diff --git a/code/pytorchvideo/tests/test_layers_fusion.py b/code/pytorchvideo/tests/test_layers_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc35b9f4504e46def9b40d158060524329abd152
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_fusion.py
@@ -0,0 +1,64 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import torch
+from pytorchvideo.layers import make_fusion_layer
+
+
+class TestFusion(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ self.fake_input_1 = torch.Tensor(
+ [[[4, -2], [3, 0]], [[0, 2], [4, 3]], [[3, 1], [5, 2]]]
+ ).float()
+ self.fake_input_2 = torch.Tensor(
+ [[[1, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [2, 1]]]
+ ).float()
+
+ def test_reduce_fusion_layers(self):
+ expected_output_for_method = {
+ "max": torch.Tensor(
+ [[[4, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [5, 2]]]
+ ).float(),
+ "sum": torch.Tensor(
+ [[[5, 0], [6, 4]], [[5, 8], [10, 8]], [[7, 4], [7, 3]]]
+ ).float(),
+ "prod": torch.Tensor(
+ [[[4, -4], [9, 0]], [[0, 12], [24, 15]], [[12, 3], [10, 2]]]
+ ).float(),
+ }
+
+ for method, expected_output in expected_output_for_method.items():
+ model = make_fusion_layer(
+ method, [self.fake_input_1.shape[-1], self.fake_input_2.shape[-1]]
+ )
+ output = model([self.fake_input_1, self.fake_input_2])
+ self.assertTrue(torch.equal(output, expected_output))
+ self.assertEqual(model.output_dim, self.fake_input_1.shape[-1])
+
+ def test_concat_fusion(self):
+ model = make_fusion_layer(
+ "concat", [self.fake_input_1.shape[-1], self.fake_input_2.shape[-1]]
+ )
+ input_list = [self.fake_input_1, self.fake_input_2]
+ output = model(input_list)
+ expected_output = torch.cat(input_list, dim=-1)
+ self.assertTrue(torch.equal(output, expected_output))
+
+ expected_shape = self.fake_input_1.shape[-1] + self.fake_input_2.shape[-1]
+ self.assertEqual(model.output_dim, expected_shape)
+
+ def test_temporal_concat_fusion(self):
+ model = make_fusion_layer(
+ "temporal_concat",
+ [self.fake_input_1.shape[-1], self.fake_input_2.shape[-1]],
+ )
+ input_list = [self.fake_input_1, self.fake_input_2]
+ output = model(input_list)
+
+ expected_output = torch.cat(input_list, dim=-2)
+ self.assertTrue(torch.equal(output, expected_output))
+ self.assertEqual(model.output_dim, self.fake_input_2.shape[-1])
diff --git a/code/pytorchvideo/tests/test_layers_mlp.py b/code/pytorchvideo/tests/test_layers_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..051ad318733666320aa442410782f9a56f89f572
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_mlp.py
@@ -0,0 +1,36 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers import make_multilayer_perceptron
+
+
+class TestMLP(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_make_multilayer_perceptron(self):
+ fake_input = torch.rand((8, 64))
+ fcs = [64, 128, 64, 32]
+ mid_activations = [nn.ReLU, nn.Sigmoid]
+ final_activations = [nn.ReLU, nn.Sigmoid, None]
+ norms = [nn.LayerNorm, nn.BatchNorm1d, None]
+ for mid_act, final_act, norm in itertools.product(
+ mid_activations, final_activations, norms
+ ):
+ mlp, output_dim = make_multilayer_perceptron(
+ fully_connected_dims=fcs,
+ mid_activation=mid_act,
+ final_activation=final_act,
+ norm=norm,
+ dropout_rate=0.5,
+ )
+
+ self.assertEqual(output_dim, 32)
+
+ output = mlp(fake_input)
+ self.assertEqual(output.shape, torch.Size([8, 32]))
diff --git a/code/pytorchvideo/tests/test_layers_nonlocal_net.py b/code/pytorchvideo/tests/test_layers_nonlocal_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f35aa55b3f04b8c8f2e84da8bfc8ee8193ab6b8
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_nonlocal_net.py
@@ -0,0 +1,159 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+from typing import Iterable
+
+import numpy as np
+import torch
+from pytorchvideo.layers.nonlocal_net import create_nonlocal, NonLocal
+from torch import nn
+
+
+class TestNonlocal(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_build_nonlocal(self):
+ """
+ Test Nonlocal model builder.
+ """
+ for dim_in, dim_inner, pool, norm, instantiation in itertools.product(
+ (4, 8),
+ (2, 4),
+ (None, nn.MaxPool3d(2)),
+ (None, nn.BatchNorm3d),
+ ("dot_product", "softmax"),
+ ):
+ model = NonLocal(
+ conv_theta=nn.Conv3d(
+ dim_in, dim_inner, kernel_size=1, stride=1, padding=0
+ ),
+ conv_phi=nn.Conv3d(
+ dim_in, dim_inner, kernel_size=1, stride=1, padding=0
+ ),
+ conv_g=nn.Conv3d(dim_in, dim_inner, kernel_size=1, stride=1, padding=0),
+ conv_out=nn.Conv3d(
+ dim_inner, dim_in, kernel_size=1, stride=1, padding=0
+ ),
+ pool=pool,
+ norm=norm(dim_in) if norm is not None else None,
+ instantiation=instantiation,
+ )
+
+ # Test forwarding.
+ for input_tensor in TestNonlocal._get_inputs(input_dim=dim_in):
+ if input_tensor.shape[1] != dim_in:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+
+ self.assertEqual(
+ input_shape,
+ output_shape,
+ "Input shape {} is different from output shape {}".format(
+ input_shape, output_shape
+ ),
+ )
+
+ def test_nonlocal_builder(self):
+ """
+ Test builder `create_nonlocal`.
+ """
+ for dim_in, dim_inner, pool_size, norm, instantiation in itertools.product(
+ (4, 8),
+ (2, 4),
+ ((1, 1, 1), (2, 2, 2)),
+ (None, nn.BatchNorm3d),
+ ("dot_product", "softmax"),
+ ):
+ conv_theta = nn.Conv3d(
+ dim_in, dim_inner, kernel_size=1, stride=1, padding=0
+ )
+ conv_phi = nn.Conv3d(dim_in, dim_inner, kernel_size=1, stride=1, padding=0)
+ conv_g = nn.Conv3d(dim_in, dim_inner, kernel_size=1, stride=1, padding=0)
+ conv_out = nn.Conv3d(dim_inner, dim_in, kernel_size=1, stride=1, padding=0)
+ if norm is None:
+ norm_model = None
+ else:
+ norm_model = norm(num_features=dim_in)
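+ # Only build a max-pool when pool_size actually downsamples; otherwise no pooling layer is used.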
+ if isinstance(pool_size, Iterable) and any(size > 1 for size in pool_size):
+ pool_model = nn.MaxPool3d(
+ kernel_size=pool_size, stride=pool_size, padding=[0, 0, 0]
+ )
+ else:
+ pool_model = None
+
+ model = create_nonlocal(
+ dim_in=dim_in,
+ dim_inner=dim_inner,
+ pool_size=pool_size,
+ instantiation=instantiation,
+ norm=norm,
+ )
+
+ model_gt = NonLocal(
+ conv_theta=conv_theta,
+ conv_phi=conv_phi,
+ conv_g=conv_g,
+ conv_out=conv_out,
+ pool=pool_model,
+ norm=norm_model,
+ instantiation=instantiation,
+ )
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestNonlocal._get_inputs(input_dim=dim_in):
+ with torch.no_grad():
+ if input_tensor.shape[1] != dim_in:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ @staticmethod
+ def _get_inputs(input_dim: int = 8) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random tensor as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, input_dim, 5, 7, 7),
+ (2, input_dim, 5, 7, 7),
+ (4, input_dim, 5, 7, 7),
+ (4, input_dim, 5, 7, 7),
+ (4, input_dim, 7, 7, 7),
+ (4, input_dim, 7, 7, 14),
+ (4, input_dim, 7, 14, 7),
+ (4, input_dim, 7, 14, 14),
+ # Forward failed.
+ (8, input_dim * 2, 3, 7, 7),
+ (8, input_dim * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_layers_positional_encoding.py b/code/pytorchvideo/tests/test_layers_positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..babd7bd5088ff4c08a85e517213f65fb1034c460
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_positional_encoding.py
@@ -0,0 +1,143 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import math
+import unittest
+
+import torch
+from pytorchvideo.layers import PositionalEncoding, SpatioTemporalClsPositionalEncoding
+
+
+class TestPositionalEncoding(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ self.batch_size = 4
+ self.seq_len = 16
+ self.feature_dim = 8
+ self.fake_input = torch.randn(
+ (self.batch_size, self.seq_len, self.feature_dim)
+ ).float()
+ lengths = torch.Tensor([16, 0, 14, 15, 16, 16, 16, 16])
+ self.mask = torch.lt(
+ torch.arange(self.seq_len)[None, :], lengths[:, None].long()
+ )
+
+ def test_positional_encoding(self):
+ model = PositionalEncoding(self.feature_dim, self.seq_len)
+ output = model(self.fake_input)
+ delta = output - self.fake_input
+
+ pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
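+ # Reconstruct the sinusoidal table: pe[pos, 2i] = sin(pos / 10000^(2i/d)), pe[pos, 2i+1] = cos(pos / 10000^(2i/d)).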
+ position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.feature_dim, 2).float()
+ * (-math.log(10000.0) / self.feature_dim)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ for n in range(0, self.batch_size):
+ self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))
+
+ def test_positional_encoding_with_different_pe_and_data_dimensions(self):
+ """Test that model executes even if input data dimensions
+ differs from the dimension of initialized postional encoding model"""
+
+ # When self.seq_len < positional_encoding_seq_len, pe is added to input
+ positional_encoding_seq_len = self.seq_len * 3
+ model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
+ output = model(self.fake_input)
+
+ delta = output - self.fake_input
+ pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
+ position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.feature_dim, 2).float()
+ * (-math.log(10000.0) / self.feature_dim)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ for n in range(0, self.batch_size):
+ self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))
+
+ # When self.seq_len > positional_encoding_seq_len, assertion error
+ positional_encoding_seq_len = self.seq_len // 2
+ model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
+ with self.assertRaises(AssertionError):
+ output = model(self.fake_input)
+
+ def test_SpatioTemporalClsPositionalEncoding(self):
+ # Test with cls token.
+ batch_dim = 4
+ dim = 16
+ video_shape = (1, 2, 4)
+ video_sum = video_shape[0] * video_shape[1] * video_shape[2]
+ has_cls = True
+ model = SpatioTemporalClsPositionalEncoding(
+ embed_dim=dim,
+ patch_embed_shape=video_shape,
+ has_cls=has_cls,
+ )
+ fake_input = torch.rand(batch_dim, video_sum, dim)
+ output = model(fake_input)
+ output_gt_shape = (batch_dim, video_sum + 1, dim)
+ self.assertEqual(tuple(output.shape), output_gt_shape)
+
+ def test_SpatioTemporalClsPositionalEncoding_nocls(self):
+ # Test without cls token.
+ batch_dim = 4
+ dim = 16
+ video_shape = (1, 2, 4)
+ video_sum = video_shape[0] * video_shape[1] * video_shape[2]
+ has_cls = False
+ model = SpatioTemporalClsPositionalEncoding(
+ embed_dim=dim,
+ patch_embed_shape=video_shape,
+ has_cls=has_cls,
+ )
+ fake_input = torch.rand(batch_dim, video_sum, dim)
+ output = model(fake_input)
+ output_gt_shape = (batch_dim, video_sum, dim)
+ self.assertEqual(tuple(output.shape), output_gt_shape)
+
+ def test_SpatioTemporalClsPositionalEncoding_mismatch(self):
+ # Mismatch in dimension for patch_embed_shape.
+ with self.assertRaises(AssertionError):
+ SpatioTemporalClsPositionalEncoding(
+ embed_dim=16,
+ patch_embed_shape=(1, 2),
+ )
+
+ def test_SpatioTemporalClsPositionalEncoding_scriptable(self):
+ iter_embed_dim = [1, 2, 4, 32]
+ iter_patch_embed_shape = [(1, 1, 1), (1, 2, 4), (32, 16, 1)]
+ iter_sep_pos_embed = [True, False]
+ iter_has_cls = [True, False]
+
+ for (
+ embed_dim,
+ patch_embed_shape,
+ sep_pos_embed,
+ has_cls,
+ ) in itertools.product(
+ iter_embed_dim,
+ iter_patch_embed_shape,
+ iter_sep_pos_embed,
+ iter_has_cls,
+ ):
+ stcpe = SpatioTemporalClsPositionalEncoding(
+ embed_dim=embed_dim,
+ patch_embed_shape=patch_embed_shape,
+ sep_pos_embed=sep_pos_embed,
+ has_cls=has_cls,
+ )
+ stcpe_scripted = torch.jit.script(stcpe)
+ batch_dim = 4
+ video_dim = math.prod(patch_embed_shape)
+ fake_input = torch.rand(batch_dim, video_dim, embed_dim)
+ expected = stcpe(fake_input)
+ actual = stcpe_scripted(fake_input)
+ torch.testing.assert_allclose(expected, actual)
diff --git a/code/pytorchvideo/tests/test_layers_squeeze_excitation.py b/code/pytorchvideo/tests/test_layers_squeeze_excitation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ceaffe41996b82169d5e2014bdc3481ffde977e
--- /dev/null
+++ b/code/pytorchvideo/tests/test_layers_squeeze_excitation.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import copy
+import unittest
+
+import torch
+import torch.nn as nn
+from pytorchvideo.layers.squeeze_excitation import (
+ create_audio_2d_squeeze_excitation_block,
+)
+
+
+class Test2DSqueezeExcitationBlock(unittest.TestCase):
+ def setUp(self):
+
+ self.layer_args = {
+ "dim_in": 32,
+ "dim_out": 32,
+ "use_se": True,
+ "se_reduction_ratio": 16,
+ "branch_fusion": lambda x, y: x + y,
+ "conv_a_kernel_size": 3,
+ "conv_a_stride": 1,
+ "conv_a_padding": 1,
+ "conv_b_kernel_size": 3,
+ "conv_b_stride": 1,
+ "conv_b_padding": 1,
+ "norm": nn.BatchNorm2d,
+ "norm_eps": 1e-5,
+ "norm_momentum": 0.1,
+ "activation": nn.ReLU,
+ }
+
+ self.batchsize = 1
+ self.forward_pass_configs = [
+ {
+ "input": torch.rand(self.batchsize, self.layer_args["dim_in"], 100, 40),
+ "output_shape": torch.Size(
+ [self.batchsize, self.layer_args["dim_out"], 100, 40]
+ ),
+ },
+ ]
+
+ def test_forward_pass(self):
+ for split_config in self.forward_pass_configs:
+ layer_args = copy.deepcopy(self.layer_args)
+ model = create_audio_2d_squeeze_excitation_block(**layer_args)
+
+ out = model(split_config["input"])
+ self.assertTrue(isinstance(out, torch.Tensor))
+ self.assertEqual(out.size(), split_config["output_shape"])
diff --git a/code/pytorchvideo/tests/test_losses_soft_target_cross_entropy.py b/code/pytorchvideo/tests/test_losses_soft_target_cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2477441964ab5d7b82b75a1a7866d99a7a4c8f2e
--- /dev/null
+++ b/code/pytorchvideo/tests/test_losses_soft_target_cross_entropy.py
@@ -0,0 +1,73 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+
+import torch
+import torch.nn.functional as F
+from pytorchvideo.losses.soft_target_cross_entropy import SoftTargetCrossEntropyLoss
+
+
+class TestSoftTargetCrossEntropyLoss(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_soft_target_cross_entropy_loss(self):
+ """
+ Test the soft target cross entropy loss.
+ """
+ for batch_size, num_class, use_1D_target in itertools.product(
+ (1, 8), (2, 10), (True, False)
+ ):
+ loss = SoftTargetCrossEntropyLoss()
+
+ # Test forwarding.
+ for (
+ input_tensor,
+ target_tensor,
+ ) in TestSoftTargetCrossEntropyLoss._get_inputs(
+ batch_size=batch_size, num_class=num_class, use_1D_target=use_1D_target
+ ):
+ output_tensor = loss(input_tensor, target_tensor)
+ output_shape = output_tensor.shape
+
+ self.assertEqual(
+ output_shape,
+ torch.Size([]),
+ "Output shape {} is different from expected.".format(output_shape),
+ )
+
+ # If target is normalized, output_tensor must match direct eval
+ if target_tensor.ndim == 1 or all(target_tensor.sum(dim=-1) == 1):
+
+ _target_tensor = target_tensor
+ if target_tensor.ndim == 1:
+ _target_tensor = torch.nn.functional.one_hot(
+ target_tensor, num_class
+ )
+
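+ # Reference value: plain cross entropy, i.e. the batch mean of -sum(target * log_softmax(input)).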
+ _output_tensor = torch.sum(
+ -_target_tensor * F.log_softmax(input_tensor, dim=-1), dim=-1
+ ).mean()
+
+ self.assertTrue(abs(_output_tensor - output_tensor) < 1e-6)
+
+ @staticmethod
+ def _get_inputs(
+ batch_size: int = 16, num_class: int = 400, use_1D_target: bool = True
+ ) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random tensor as test cases.
+ if use_1D_target:
+ target_shape = (batch_size,)
+ else:
+ target_shape = (batch_size, num_class)
+ input_shape = (batch_size, num_class)
+
+ yield torch.rand(input_shape), torch.randint(num_class, target_shape)
diff --git a/code/pytorchvideo/tests/test_models_audio_visual_slowfast.py b/code/pytorchvideo/tests/test_models_audio_visual_slowfast.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e66410108b714e3ecac94f1d4ad895cd1a6f5d9
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_audio_visual_slowfast.py
@@ -0,0 +1,68 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+from typing import Tuple
+
+import torch
+from pytorchvideo.models.audio_visual_slowfast import create_audio_visual_slowfast
+from pytorchvideo.transforms.functional import uniform_temporal_subsample_repeated
+from torch import nn
+
+
+class TestAVSlowFast(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_avslowfast_with_callable(self):
+ """
+ Test builder `create_audio_visual_slowfast` with callable inputs.
+ """
+ for (norm, activation) in itertools.product(
+ (nn.BatchNorm3d, None), (nn.ReLU, nn.Sigmoid, None)
+ ):
+ input_channel = 3
+
+ model = create_audio_visual_slowfast(
+ input_channels=(input_channel, input_channel, 1),
+ model_depth=18,
+ norm=norm,
+ activation=activation,
+ )
+
+ # Test forwarding.
+ for tensor in TestAVSlowFast._get_inputs(input_channel):
+ with torch.no_grad():
+ if tensor[0].shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor)
+ continue
+
+ model(tensor)
+
+ @staticmethod
+ def _get_inputs(
+ channel: int = 3,
+ clip_length: int = 64,
+ audio_clip_length: int = 128,
+ crop_size: int = 224,
+ audio_size: int = 80,
+ frame_ratios: Tuple[int] = (8, 2),
+ audio_frame_ratio: int = 1,
+ ) -> Tuple[torch.Tensor]:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ Tuple[torch.Tensor]: tensors as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shape = (1, channel, clip_length, crop_size, crop_size)
+ audio_shape = (1, 1, audio_clip_length, 1, audio_size)
+ output = uniform_temporal_subsample_repeated(
+ torch.rand(shape), frame_ratios=frame_ratios, temporal_dim=2
+ )
+ yield output + uniform_temporal_subsample_repeated(
+ torch.rand(audio_shape), frame_ratios=(audio_frame_ratio,), temporal_dim=2
+ )
diff --git a/code/pytorchvideo/tests/test_models_byol.py b/code/pytorchvideo/tests/test_models_byol.py
new file mode 100644
index 0000000000000000000000000000000000000000..6007439a284135ffe6cf535d223c7085d41c1a3d
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_byol.py
@@ -0,0 +1,36 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import torch
+from pytorchvideo.models.byol import BYOL
+from torch import nn
+
+
+class TestBYOL(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_byol(self):
+ byol = BYOL(
+ backbone=nn.Linear(8, 4),
+ projector=nn.Linear(4, 4),
+ feature_dim=4,
+ norm=nn.BatchNorm1d,
+ )
+ for crop1, crop2 in TestBYOL._get_inputs():
+ byol(crop1, crop2)
+
+ @staticmethod
+ def _get_inputs() -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = ((2, 8),)
+ for shape in shapes:
+ yield torch.rand(shape), torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_models_csn.py b/code/pytorchvideo/tests/test_models_csn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b96b88fb3696b6315553a7b6ec605fde4355c70e
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_csn.py
@@ -0,0 +1,120 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import os
+import unittest
+
+import numpy as np
+import torch
+from pytorchvideo.models.csn import create_csn
+from pytorchvideo.models.resnet import create_bottleneck_block
+from torch import nn
+
+
+class TestCSN(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_csn(self):
+ """
+ Test simple CSN with different inputs.
+ """
+ for input_channel, input_clip_length, input_crop_size in itertools.product(
+ (3, 2), (4, 8), (56, 64)
+ ):
+ stage_spatial_stride = (1, 2, 2, 2)
+ stage_temporal_stride = (1, 2, 2, 1)
+
+ total_spatial_stride = 2 * np.prod(stage_spatial_stride)
+ total_temporal_stride = np.prod(stage_temporal_stride)
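+ # The stem adds an extra 2x spatial stride; the head pool then covers the remaining T x H x W extent.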
+ head_pool_kernel_size = (
+ input_clip_length // total_temporal_stride,
+ input_crop_size // total_spatial_stride,
+ input_crop_size // total_spatial_stride,
+ )
+
+ model = create_csn(
+ input_channel=input_channel,
+ model_depth=50,
+ model_num_class=400,
+ dropout_rate=0,
+ norm=nn.BatchNorm3d,
+ activation=nn.ReLU,
+ stem_dim_out=8,
+ stem_conv_kernel_size=(3, 7, 7),
+ stem_conv_stride=(1, 2, 2),
+ stage_conv_a_kernel_size=(1, 1, 1),
+ stage_conv_b_kernel_size=(3, 3, 3),
+ stage_conv_b_width_per_group=1,
+ stage_spatial_stride=(1, 2, 2, 2),
+ stage_temporal_stride=(1, 2, 2, 1),
+ bottleneck=create_bottleneck_block,
+ head_pool=nn.AvgPool3d,
+ head_pool_kernel_size=head_pool_kernel_size,
+ head_output_size=(1, 1, 1),
+ head_activation=nn.Softmax,
+ )
+
+ # Test forwarding.
+ for tensor in TestCSN._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ output_shape = out.shape
+ output_shape_gt = (tensor.shape[0], 400)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ @staticmethod
+ def _get_inputs(
+ channel: int = 3, clip_length: int = 4, crop_size: int = 112
+ ) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ (1, channel, clip_length, crop_size, crop_size),
+ (2, channel, clip_length, crop_size, crop_size),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
+
+ def test_load_hubconf(self):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ input_channel = 3
+ input_clip_length = 4
+ input_crop_size = 56
+ model = torch.hub.load(
+ repo_or_dir=path, source="local", model="csn_r101", pretrained=False
+ )
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ for tensor in TestCSN._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor)
+ continue
diff --git a/code/pytorchvideo/tests/test_models_head.py b/code/pytorchvideo/tests/test_models_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9e980870265ee2f5e17e81246b2e5451bd495d5
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_head.py
@@ -0,0 +1,402 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+
+import numpy as np
+import torch
+from pytorchvideo.models.head import (
+ create_res_basic_head,
+ create_res_roi_pooling_head,
+ create_vit_basic_head,
+ ResNetBasicHead,
+ ResNetRoIHead,
+ SequencePool,
+)
+from torch import nn
+from torchvision.ops import RoIAlign
+
+
+class TestHeadHelper(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_build_simple_head(self):
+ """
+ Test simple ResNetBasicHead (without dropout and activation layers).
+ """
+ for input_dim, output_dim in itertools.product((4, 8), (4, 8, 16)):
+ model = ResNetBasicHead(
+ proj=nn.Linear(input_dim, output_dim),
+ pool=nn.AdaptiveAvgPool3d(1),
+ output_pool=nn.AdaptiveAvgPool3d(1),
+ )
+
+ # Test forwarding.
+ for input_tensor in TestHeadHelper._get_inputs(input_dim=input_dim):
+ if input_tensor.shape[1] != input_dim:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+ output_shape_gt = (input_shape[0], output_dim)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_build_complex_head(self):
+ """
+ Test complex ResNetBasicHead.
+ """
+ for input_dim, output_dim in itertools.product((4, 8), (4, 8, 16)):
+ model = ResNetBasicHead(
+ proj=nn.Linear(input_dim, output_dim),
+ activation=nn.Softmax(),
+ pool=nn.AdaptiveAvgPool3d(1),
+ dropout=nn.Dropout(0.5),
+ output_pool=nn.AdaptiveAvgPool3d(1),
+ )
+
+ # Test forwarding.
+ for input_tensor in TestHeadHelper._get_inputs(input_dim=input_dim):
+ if input_tensor.shape[1] != input_dim:
+ with self.assertRaises(Exception):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+ output_shape_gt = (input_shape[0], output_dim)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_build_head_with_callable(self):
+ """
+ Test builder `create_res_basic_head`.
+ """
+ for (pool, activation) in itertools.product(
+ (nn.AvgPool3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d, None),
+ (nn.ReLU, nn.Softmax, nn.Sigmoid, None),
+ ):
+ if activation is None:
+ activation_model = None
+ elif activation == nn.Softmax:
+ activation_model = activation(dim=1)
+ else:
+ activation_model = activation()
+
+ if pool is None:
+ pool_model = None
+ elif pool == nn.AdaptiveAvgPool3d:
+ pool_model = pool(1)
+ else:
+ pool_model = pool(kernel_size=[5, 7, 7], stride=[1, 1, 1])
+
+ model = create_res_basic_head(
+ in_features=16,
+ out_features=32,
+ pool=pool,
+ pool_kernel_size=(5, 7, 7),
+ output_size=(1, 1, 1),
+ dropout_rate=0.0,
+ activation=activation,
+ output_with_global_average=True,
+ )
+ model_gt = ResNetBasicHead(
+ proj=nn.Linear(16, 32),
+ activation=activation_model,
+ pool=pool_model,
+ dropout=None,
+ output_pool=nn.AdaptiveAvgPool3d(1),
+ )
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestHeadHelper._get_inputs(input_dim=16):
+ with torch.no_grad():
+ if input_tensor.shape[1] != 16:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ @staticmethod
+ def _get_inputs(input_dim: int = 8) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random tensor as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, input_dim, 5, 7, 7),
+ (2, input_dim, 5, 7, 7),
+ (4, input_dim, 5, 7, 7),
+ (4, input_dim, 5, 7, 7),
+ (4, input_dim, 7, 7, 7),
+ (4, input_dim, 7, 7, 14),
+ (4, input_dim, 7, 14, 7),
+ (4, input_dim, 7, 14, 14),
+ # Forward failed.
+ (8, input_dim * 2, 3, 7, 7),
+ (8, input_dim * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
+
+
+class TestRoIHeadHelper(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_build_simple_head(self):
+ """
+ Test simple ResNetRoIHead
+ (without pool_spatial, roi, dropout and activation layers).
+ """
+ for input_dim, output_dim in itertools.product((4, 8), (4, 8, 16)):
+ model = ResNetRoIHead(
+ proj=nn.Linear(input_dim, output_dim),
+ pool=nn.AdaptiveAvgPool3d(1),
+ output_pool=nn.AdaptiveAvgPool3d(1),
+ )
+ bboxes = None
+
+ # Test forwarding.
+ for input_tensor in TestHeadHelper._get_inputs(input_dim=input_dim):
+ if input_tensor.shape[1] != input_dim:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor, bboxes)
+ continue
+ else:
+ output_tensor = model(input_tensor, bboxes)
+
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+ output_shape_gt = (input_shape[0], output_dim)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_vit_basic_head(self):
+ batch_size = 8
+ seq_len = 10
+ input_dim = 10
+ out_dim = 20
+ head = create_vit_basic_head(
+ in_features=input_dim,
+ out_features=out_dim,
+ )
+ fake_input = torch.rand(batch_size, seq_len, input_dim)
+ output = head(fake_input)
+ gt_shape = (batch_size, out_dim)
+ self.assertEqual(tuple(output.shape), gt_shape)
+
+ def test_sequence_pool(self):
+ model = SequencePool("cls")
+ fake_input = torch.rand(8, 10, 10)
+ output = model(fake_input)
+ self.assertTrue(torch.equal(output, fake_input[:, 0]))
+ model = SequencePool("mean")
+ output = model(fake_input)
+ self.assertTrue(torch.equal(output, fake_input.mean(1)))
+
+ def test_build_complex_head(self):
+ """
+ Test complex ResNetRoIHead.
+ """
+ # ROI layer configs
+ resolution = (10, 15)
+ spatial_scale = 1.0 / 5.0
+ sampling_ratio = 0
+ roi_layer = RoIAlign(
+ resolution, spatial_scale=spatial_scale, sampling_ratio=sampling_ratio
+ )
+
+ for input_dim, output_dim in itertools.product((4, 8), (4, 8, 16)):
+
+ model = ResNetRoIHead(
+ proj=nn.Linear(input_dim, output_dim),
+ activation=nn.Softmax(),
+ pool=nn.AdaptiveAvgPool3d(1),
+ pool_spatial=nn.MaxPool2d(resolution, stride=1),
+ roi_layer=roi_layer,
+ dropout=nn.Dropout(0.5),
+ output_pool=nn.AdaptiveAvgPool3d(1),
+ )
+
+ # Test forwarding.
+ for (input_tensor, bboxes) in TestRoIHeadHelper._get_inputs(
+ input_dim=input_dim
+ ):
+ if input_tensor.shape[1] != input_dim:
+ with self.assertRaises(Exception):
+ output_tensor = model(input_tensor, bboxes)
+ continue
+ output_tensor = model(input_tensor, bboxes)
+
+ bboxes_shape = bboxes.shape
+ output_shape = output_tensor.shape
+ output_shape_gt = (bboxes_shape[0], output_dim)
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_build_head_with_callable(self):
+ """
+ Test builder `create_res_roi_pooling_head`.
+ """
+ # ROI layer configs
+ resolution = (10, 15)
+ spatial_scale = 1.0 / 5.0
+ sampling_ratio = 0
+ roi_layer = RoIAlign(
+ resolution, spatial_scale=spatial_scale, sampling_ratio=sampling_ratio
+ )
+
+ for (pool, activation) in itertools.product(
+ (nn.AvgPool3d, nn.MaxPool3d, nn.AdaptiveAvgPool3d, None),
+ (nn.ReLU, nn.Softmax, nn.Sigmoid, None),
+ ):
+ if activation is None:
+ activation_model = None
+ elif activation == nn.Softmax:
+ activation_model = activation(dim=1)
+ else:
+ activation_model = activation()
+
+ if pool is None:
+ pool_model = None
+ elif pool == nn.AdaptiveAvgPool3d:
+ pool_model = pool(1)
+ else:
+ pool_model = pool(kernel_size=[5, 1, 1], stride=[1, 1, 1])
+
+ model = create_res_roi_pooling_head(
+ in_features=16,
+ out_features=32,
+ resolution=resolution,
+ spatial_scale=spatial_scale,
+ sampling_ratio=sampling_ratio,
+ roi=RoIAlign,
+ pool=pool,
+ pool_spatial=nn.MaxPool2d,
+ pool_kernel_size=(5, 1, 1),
+ output_size=(1, 1, 1),
+ dropout_rate=0.0,
+ activation=activation,
+ output_with_global_average=True,
+ )
+ model_gt = ResNetRoIHead(
+ proj=nn.Linear(16, 32),
+ activation=activation_model,
+ pool=pool_model,
+ pool_spatial=nn.MaxPool2d(resolution, stride=1),
+ roi_layer=roi_layer,
+ dropout=None,
+ output_pool=nn.AdaptiveAvgPool3d(1),
+ )
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for (input_tensor, bboxes) in TestRoIHeadHelper._get_inputs(input_dim=16):
+ with torch.no_grad():
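+ # A failure is expected when the channels mismatch, no pool is given, or the clip length does not fit the fixed (5, 1, 1) pooling kernel.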
+ if (
+ input_tensor.shape[1] != 16
+ or (pool is None)
+ or (
+ input_tensor.shape[-3] != 5 and pool != nn.AdaptiveAvgPool3d
+ )
+ ):
+ with self.assertRaises(Exception):
+ output_tensor = model(input_tensor, bboxes)
+ continue
+ else:
+ output_tensor = model(input_tensor, bboxes)
+ output_tensor_gt = model_gt(input_tensor, bboxes)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ @staticmethod
+ def _get_inputs(input_dim: int = 8) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ (torch.tensor): tensor as test case bboxes.
+ """
+ # Prepare random tensor as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, input_dim, 5, 7, 7),
+ (2, input_dim, 5, 7, 7),
+ (4, input_dim, 5, 7, 7),
+ (4, input_dim, 5, 7, 7),
+ (4, input_dim, 7, 7, 7),
+ (4, input_dim, 7, 7, 14),
+ (4, input_dim, 7, 14, 7),
+ (4, input_dim, 7, 14, 14),
+ # Forward failed.
+ (8, input_dim * 2, 3, 7, 7),
+ (8, input_dim * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ input_tensor = torch.rand(shape)
+ bboxes = [[i, 1, 2, 3, 4] for i in range(input_tensor.shape[0])]
+ bboxes = torch.Tensor(bboxes)
+ yield (input_tensor, bboxes)
diff --git a/code/pytorchvideo/tests/test_models_hub_vision_transformers.py b/code/pytorchvideo/tests/test_models_hub_vision_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..814bc12cba68d8477cd2ac79669de9775f51caac
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_hub_vision_transformers.py
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+import unittest
+
+import torch
+import torch.nn as nn
+from pytorchvideo.models.hub.utils import hub_model_builder
+
+
+class TestHubVisionTransformers(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_load_hubconf(self):
+ def test_load_mvit_(model_name, pretrained):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ model = torch.hub.load(
+ repo_or_dir=path,
+ source="local",
+ model=model_name,
+ pretrained=pretrained,
+ )
+ self.assertIsNotNone(model)
+
+ models = [
+ "mvit_base_16x4",
+ "mvit_base_16",
+ "mvit_base_32x3",
+ ]
+ pretrains = [False, False, False]
+
+ for model_name, pretrain in zip(models, pretrains):
+ test_load_mvit_(model_name, pretrain)
+
+ def test_hub_model_builder(self):
+ def _fake_model(in_features=10, out_features=10) -> nn.Module:
+ """
+ A fake model builder with a linear layer.
+ """
+ model = nn.Linear(in_features, out_features)
+ return model
+
+ in_fea = 5
+ default_config = {"in_features": in_fea}
+ model = hub_model_builder(
+ model_builder_func=_fake_model, default_config=default_config
+ )
+ self.assertEqual(model.in_features, in_fea)
+ self.assertEqual(model.out_features, 10)
+
+ # Test case where add_config overwrites default_config.
+ in_fea = 5
+ default_config = {"in_features": in_fea}
+ add_in_fea = 2
+ add_out_fea = 3
+
+ model = hub_model_builder(
+ model_builder_func=_fake_model,
+ default_config=default_config,
+ in_features=add_in_fea,
+ out_features=add_out_fea,
+ )
+ self.assertEqual(model.in_features, add_in_fea)
+ self.assertEqual(model.out_features, add_out_fea)
+
+ # Test assertions.
+ self.assertRaises(
+ AssertionError,
+ hub_model_builder,
+ model_builder_func=_fake_model,
+ pretrained=True,
+ default_config={},
+ fake_input=None,
+ )
diff --git a/code/pytorchvideo/tests/test_models_masked_multistream.py b/code/pytorchvideo/tests/test_models_masked_multistream.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa97b48797ea9d1623aa47fcafad834b85d1ca3c
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_masked_multistream.py
@@ -0,0 +1,130 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import copy
+import unittest
+
+import torch
+import torch.nn
+from pytorchvideo.layers import make_multilayer_perceptron, PositionalEncoding
+from pytorchvideo.models.masked_multistream import (
+ LearnMaskedDefault,
+ LSTM,
+ MaskedSequential,
+ MaskedTemporalPooling,
+ TransposeMultiheadAttention,
+ TransposeTransformerEncoder,
+)
+
+
+class TestMaskedMultiStream(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_masked_multistream_model(self):
+ feature_dim = 8
+ mlp, out_dim = make_multilayer_perceptron([feature_dim, 2])
+ input_stream = MaskedSequential(
+ PositionalEncoding(feature_dim),
+ TransposeMultiheadAttention(feature_dim),
+ MaskedTemporalPooling(method="avg"),
+ torch.nn.LayerNorm(feature_dim),
+ mlp,
+ LearnMaskedDefault(out_dim),
+ )
+
+ seq_len = 10
+ input_tensor = torch.rand([4, seq_len, feature_dim])
+ mask = _lengths2mask(
+ torch.tensor([seq_len, seq_len, seq_len, seq_len]), input_tensor.shape[1]
+ )
+ output = input_stream(input=input_tensor, mask=mask)
+ self.assertEqual(output.shape, torch.Size([4, out_dim]))
+
+ def test_masked_temporal_pooling(self):
+ fake_input = torch.Tensor(
+ [[[4, -2], [3, 0]], [[0, 2], [4, 3]], [[3, 1], [5, 2]]]
+ ).float()
+ valid_lengths = torch.Tensor([2, 1, 0]).int()
+ valid_mask = _lengths2mask(valid_lengths, fake_input.shape[1])
+ expected_output_for_method = {
+ "max": torch.Tensor([[4, 0], [0, 2], [0, 0]]).float(),
+ "avg": torch.Tensor([[3.5, -1], [0, 2], [0, 0]]).float(),
+ "sum": torch.Tensor([[7, -2], [0, 2], [0, 0]]).float(),
+ }
+ for method, expected_output in expected_output_for_method.items():
+ model = MaskedTemporalPooling(method)
+ output = model(copy.deepcopy(fake_input), mask=valid_mask)
+ self.assertTrue(torch.equal(output, expected_output))
+
+ def test_transpose_attention(self):
+ feature_dim = 8
+ seq_len = 10
+ fake_input = torch.rand([4, seq_len, feature_dim])
+ mask = _lengths2mask(
+ torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
+ )
+ model = TransposeMultiheadAttention(feature_dim, num_heads=2)
+ output = model(fake_input, mask=mask)
+ self.assertEqual(output.shape, fake_input.shape)
+
+ def test_masked_lstm(self):
+ feature_dim = 8
+ seq_len = 10
+ fake_input = torch.rand([4, seq_len, feature_dim])
+ mask = _lengths2mask(
+ torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
+ )
+ hidden_dim = 128
+
+ model = LSTM(feature_dim, hidden_dim=hidden_dim, bidirectional=False)
+ output = model(fake_input, mask=mask)
+ self.assertEqual(output.shape, (fake_input.shape[0], hidden_dim))
+
+ model = LSTM(feature_dim, hidden_dim=hidden_dim, bidirectional=True)
+ output = model(fake_input, mask=mask)
+ self.assertEqual(output.shape, (fake_input.shape[0], hidden_dim * 2))
+
+ def test_masked_transpose_transformer_encoder(self):
+ feature_dim = 8
+ seq_len = 10
+ fake_input = torch.rand([4, seq_len, feature_dim])
+ mask = _lengths2mask(
+ torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
+ )
+
+ model = TransposeTransformerEncoder(feature_dim)
+ output = model(fake_input, mask=mask)
+ self.assertEqual(output.shape, (fake_input.shape[0], feature_dim))
+
+ def test_learn_masked_default(self):
+ feature_dim = 8
+ seq_len = 10
+ fake_input = torch.rand([4, feature_dim])
+
+ # All valid mask
+ all_valid_mask = _lengths2mask(
+ torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
+ )
+ model = LearnMaskedDefault(feature_dim)
+ output = model(fake_input, mask=all_valid_mask)
+ self.assertTrue(output.equal(fake_input))
+
+ # No valid mask
+ no_valid_mask = _lengths2mask(torch.tensor([0, 0, 0, 0]), fake_input.shape[1])
+ model = LearnMaskedDefault(feature_dim)
+ output = model(fake_input, mask=no_valid_mask)
+ self.assertTrue(output.equal(model._learned_defaults.repeat(4, 1)))
+
+ # Half valid mask
+ half_valid_mask = _lengths2mask(torch.tensor([1, 1, 0, 0]), fake_input.shape[1])
+ model = LearnMaskedDefault(feature_dim)
+ output = model(fake_input, mask=half_valid_mask)
+ self.assertTrue(output[:2].equal(fake_input[:2]))
+ self.assertTrue(output[2:].equal(model._learned_defaults.repeat(2, 1)))
+
+
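+# Builds a boolean (batch, seq_len) mask that is True at positions before each sequence's length.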
+def _lengths2mask(lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
+ return torch.lt(
+ torch.arange(seq_len, device=lengths.device)[None, :], lengths[:, None].long()
+ )
diff --git a/code/pytorchvideo/tests/test_models_memory_bank.py b/code/pytorchvideo/tests/test_models_memory_bank.py
new file mode 100644
index 0000000000000000000000000000000000000000..19aa4b450ece2420346c7236597ba0882c55dc2c
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_memory_bank.py
@@ -0,0 +1,37 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import torch
+from pytorchvideo.models.memory_bank import MemoryBank
+from torch import nn
+
+
+class TestMemoryBank(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_memory_bank(self):
+ simclr = MemoryBank(
+ backbone=nn.Linear(8, 4),
+ mlp=nn.Linear(4, 2),
+ temperature=0.07,
+ bank_size=8,
+ dim=2,
+ )
+ for crop, ind in TestMemoryBank._get_inputs():
+ simclr(crop, ind)
+
+ @staticmethod
+ def _get_inputs(bank_size: int = 8) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = ((2, 8),)
+ for shape in shapes:
+ yield torch.rand(shape), torch.randint(0, bank_size, size=(shape[0],))
diff --git a/code/pytorchvideo/tests/test_models_r2plus1d.py b/code/pytorchvideo/tests/test_models_r2plus1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f86e38b9e8a306b582d928ab75861aa90b1fa428
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_r2plus1d.py
@@ -0,0 +1,102 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+
+import numpy as np
+import torch
+from pytorchvideo.models.r2plus1d import (
+ create_2plus1d_bottleneck_block,
+ create_r2plus1d,
+)
+from pytorchvideo.models.resnet import create_bottleneck_block
+from torch import nn
+
+
+class TestR2plus1d(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_r2plus1d(self):
+ """
+ Test simple r2plus1d with different inputs.
+ """
+ for input_channel, input_clip_length, input_crop_size in itertools.product(
+ (3, 2), (4, 8), (56, 64)
+ ):
+ stage_spatial_stride = (2, 2, 2, 2)
+ stage_temporal_stride = (1, 1, 2, 2)
+
+ total_spatial_stride = 2 * np.prod(stage_spatial_stride)
+ total_temporal_stride = np.prod(stage_temporal_stride)
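+ # Account for the stem's extra 2x spatial stride, then size the head pool to match the remaining output extent.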
+ head_pool_kernel_size = (
+ input_clip_length // total_temporal_stride,
+ input_crop_size // total_spatial_stride,
+ input_crop_size // total_spatial_stride,
+ )
+
+ model = create_r2plus1d(
+ input_channel=input_channel,
+ model_depth=50,
+ model_num_class=400,
+ dropout_rate=0.0,
+ norm=nn.BatchNorm3d,
+ activation=nn.ReLU,
+ stem_dim_out=8,
+ stem_conv_kernel_size=(1, 7, 7),
+ stem_conv_stride=(1, 2, 2),
+ stage_conv_b_kernel_size=((3, 3, 3),) * 4,
+ stage_spatial_stride=stage_spatial_stride,
+ stage_temporal_stride=stage_temporal_stride,
+ stage_bottleneck=(
+ create_bottleneck_block,
+ create_2plus1d_bottleneck_block,
+ create_2plus1d_bottleneck_block,
+ create_2plus1d_bottleneck_block,
+ ),
+ head_pool=nn.AvgPool3d,
+ head_pool_kernel_size=head_pool_kernel_size,
+ head_output_size=(1, 1, 1),
+ head_activation=nn.Softmax,
+ )
+
+ # Test forwarding.
+ for tensor in TestR2plus1d._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ output_shape = out.shape
+ output_shape_gt = (tensor.shape[0], 400)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ @staticmethod
+ def _get_inputs(
+ channel: int = 3, clip_length: int = 16, crop_size: int = 224
+ ) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ (1, channel, clip_length, crop_size, crop_size),
+ (2, channel, clip_length, crop_size, crop_size),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_models_resnet.py b/code/pytorchvideo/tests/test_models_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a95c2f4233295e80746eada10a7fcae709346a37
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_resnet.py
@@ -0,0 +1,1420 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import os
+import unittest
+
+import numpy as np
+import torch
+from pytorchvideo.models.head import ResNetBasicHead
+from pytorchvideo.models.net import Net
+from pytorchvideo.models.resnet import (
+ BottleneckBlock,
+ create_acoustic_bottleneck_block,
+ create_acoustic_resnet,
+ create_bottleneck_block,
+ create_res_block,
+ create_res_stage,
+ create_resnet,
+ create_resnet_with_roi_head,
+ ResBlock,
+ ResStage,
+ SeparableBottleneckBlock,
+)
+from pytorchvideo.models.stem import ResNetBasicStem
+from torch import nn
+
+
+class TestBottleneckBlock(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_simple_bottleneck_block(self):
+ """
+ Test simple BottleneckBlock with different dimensions.
+ """
+ for dim_in, dim_inner, dim_out in itertools.product(
+ (4, 8, 16), (2, 4), (4, 8, 16)
+ ):
+ model = BottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_in, dim_inner, kernel_size=1, stride=1, padding=0, bias=False
+ ),
+ norm_a=nn.BatchNorm3d(dim_inner),
+ act_a=nn.ReLU(),
+ conv_b=nn.Conv3d(
+ dim_inner, dim_inner, kernel_size=3, stride=1, padding=1, bias=False
+ ),
+ norm_b=nn.BatchNorm3d(dim_inner),
+ act_b=nn.ReLU(),
+ conv_c=nn.Conv3d(
+ dim_inner, dim_out, kernel_size=1, stride=1, padding=0, bias=False
+ ),
+ norm_c=nn.BatchNorm3d(dim_out),
+ )
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in):
+ if input_tensor.shape[1] != dim_in:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+
+ output_shape_gt = (
+ input_shape[0],
+ dim_out,
+ input_shape[2],
+ input_shape[3],
+ input_shape[4],
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_complex_bottleneck_block(self):
+ """
+ Test complex BottleneckBlock with different dimensions.
+ """
+ for dim_in, dim_inner, dim_out in itertools.product(
+ (4, 8, 16), (2, 4), (4, 8, 16)
+ ):
+ model = BottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_in,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[2, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=nn.BatchNorm3d(dim_inner),
+ act_a=nn.ReLU(),
+ conv_b=nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 2, 2],
+ padding=[0, 1, 1],
+ groups=1,
+ dilation=[1, 1, 1],
+ bias=False,
+ ),
+ norm_b=nn.BatchNorm3d(dim_inner),
+ act_b=nn.ReLU(),
+ conv_c=nn.Conv3d(
+ dim_inner,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=nn.BatchNorm3d(dim_out),
+ )
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in):
+ if input_tensor.shape[1] != dim_in:
+ with self.assertRaises(Exception):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+
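+ # With a stride-2 conv whose padding preserves size at stride 1 (kernel 3,
+ # padding 1), each strided dimension shrinks to ceil(n / 2) = (n - 1) // 2 + 1:
+ # conv_a halves the temporal dim and conv_b halves H and W, as encoded below.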
+ output_shape_gt = (
+ input_shape[0],
+ dim_out,
+ (input_shape[2] - 1) // 2 + 1,
+ (input_shape[3] - 1) // 2 + 1,
+ (input_shape[4] - 1) // 2 + 1,
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_separable_bottleneck_block_sum(self):
+ """
+ Test SeparableBottleneckBlock with different dimensions.
+ """
+ for dim_in, dim_inner, dim_out in itertools.product(
+ (4, 8, 16), (2, 4), (4, 8, 16)
+ ):
+ model = SeparableBottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_in,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[2, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=nn.BatchNorm3d(dim_inner),
+ act_a=nn.ReLU(),
+ conv_b=nn.ModuleList(
+ [
+ nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 2, 2],
+ padding=[0, 1, 1],
+ groups=1,
+ dilation=[1, 1, 1],
+ bias=False,
+ ),
+ nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 2, 2],
+ padding=[0, 1, 1],
+ groups=1,
+ dilation=[1, 1, 1],
+ bias=False,
+ ),
+ ]
+ ),
+ norm_b=nn.ModuleList(
+ [nn.BatchNorm3d(dim_inner), nn.BatchNorm3d(dim_inner)]
+ ),
+ act_b=nn.ModuleList([nn.ReLU(), nn.ReLU()]),
+ conv_c=nn.Conv3d(
+ dim_inner,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=nn.BatchNorm3d(dim_out),
+ reduce_method="sum",
+ )
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in):
+ if input_tensor.shape[1] != dim_in:
+ with self.assertRaises(Exception):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+
+ output_shape_gt = (
+ input_shape[0],
+ dim_out,
+ (input_shape[2] - 1) // 2 + 1,
+ (input_shape[3] - 1) // 2 + 1,
+ (input_shape[4] - 1) // 2 + 1,
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_separable_complex_bottleneck_block_cat(self):
+ """
+ Test SeparableBottleneckBlock with different dimensions.
+ """
+ for dim_in, dim_inner, dim_out in itertools.product(
+ (4, 8, 16), (2, 4), (4, 8, 16)
+ ):
+ model = SeparableBottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_in,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[2, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=nn.BatchNorm3d(dim_inner),
+ act_a=nn.ReLU(),
+ conv_b=nn.ModuleList(
+ [
+ nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 2, 2],
+ padding=[0, 1, 1],
+ groups=1,
+ dilation=[1, 1, 1],
+ bias=False,
+ ),
+ nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 2, 2],
+ padding=[0, 1, 1],
+ groups=1,
+ dilation=[1, 1, 1],
+ bias=False,
+ ),
+ ]
+ ),
+ norm_b=nn.ModuleList(
+ [nn.BatchNorm3d(dim_inner), nn.BatchNorm3d(dim_inner)]
+ ),
+ act_b=nn.ModuleList([nn.ReLU(), nn.ReLU()]),
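+ # With reduce_method="cat" the two conv_b branch outputs are concatenated
+ # along channels, so conv_c consumes dim_inner * 2 input channels (the "sum"
+ # variant above keeps dim_inner).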
+ conv_c=nn.Conv3d(
+ dim_inner * 2,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=nn.BatchNorm3d(dim_out),
+ reduce_method="cat",
+ )
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in):
+ if input_tensor.shape[1] != dim_in:
+ with self.assertRaises(Exception):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+
+ output_shape_gt = (
+ input_shape[0],
+ dim_out,
+ (input_shape[2] - 1) // 2 + 1,
+ (input_shape[3] - 1) // 2 + 1,
+ (input_shape[4] - 1) // 2 + 1,
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_acoustic_bottleneck_block_with_callable(self):
+ """
+ Test builder `create_acoustic_bottleneck_block` with callable inputs.
+ """
+ for (norm_model, act_model) in itertools.product(
+ (nn.BatchNorm3d,), (nn.ReLU, nn.Softmax, nn.Sigmoid)
+ ):
+ model = create_acoustic_bottleneck_block(
+ dim_in=32,
+ dim_inner=16,
+ dim_out=64,
+ conv_a_kernel_size=(3, 1, 1),
+ conv_a_stride=(1, 1, 1),
+ conv_a_padding=(1, 0, 0),
+ conv_b_kernel_size=(3, 3, 3),
+ conv_b_stride=(1, 1, 1),
+ conv_b_padding=(1, 1, 1),
+ conv_b_num_groups=1,
+ conv_b_dilation=(1, 1, 1),
+ norm=norm_model,
+ activation=act_model,
+ )
+ model_gt = SeparableBottleneckBlock(
+ conv_a=nn.Conv3d(
+ 32,
+ 16,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=norm_model(16),
+ act_a=act_model(),
+ conv_b=nn.ModuleList(
+ [
+ nn.Conv3d(
+ 16,
+ 16,
+ kernel_size=[1, 3, 3],
+ stride=[1, 1, 1],
+ padding=[0, 1, 1],
+ dilation=1,
+ bias=False,
+ ),
+ nn.Conv3d(
+ 16,
+ 16,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ dilation=1,
+ bias=False,
+ ),
+ ]
+ ),
+ norm_b=nn.ModuleList([norm_model(16), norm_model(16)]),
+ act_b=nn.ModuleList([act_model(), act_model()]),
+ conv_c=nn.Conv3d(
+ 16,
+ 64,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=norm_model(64),
+ )
+
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in=32):
+ with torch.no_grad():
+ if input_tensor.shape[1] != 32:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ def test_create_bottleneck_block_with_callable(self):
+ """
+ Test builder `create_bottleneck_block` with callable inputs.
+ """
+ for (norm_model, act_model) in itertools.product(
+ (nn.BatchNorm3d,), (nn.ReLU, nn.Softmax, nn.Sigmoid)
+ ):
+ model = create_bottleneck_block(
+ dim_in=32,
+ dim_inner=16,
+ dim_out=64,
+ conv_a_kernel_size=(3, 1, 1),
+ conv_a_stride=(1, 1, 1),
+ conv_a_padding=(1, 0, 0),
+ conv_b_kernel_size=(1, 3, 3),
+ conv_b_stride=(1, 1, 1),
+ conv_b_padding=(0, 1, 1),
+ conv_b_num_groups=1,
+ conv_b_dilation=(1, 1, 1),
+ norm=norm_model,
+ activation=act_model,
+ )
+ model_gt = BottleneckBlock(
+ conv_a=nn.Conv3d(
+ 32,
+ 16,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=norm_model(16),
+ act_a=act_model(),
+ conv_b=nn.Conv3d(
+ 16,
+ 16,
+ kernel_size=[1, 3, 3],
+ stride=[1, 1, 1],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=norm_model(16),
+ act_b=act_model(),
+ conv_c=nn.Conv3d(
+ 16,
+ 64,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=norm_model(64),
+ )
+
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in=32):
+ with torch.no_grad():
+ if input_tensor.shape[1] != 32:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ @staticmethod
+ def _get_inputs(dim_in: int = 3) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.

+ shapes = (
+ # Forward succeeded.
+ (1, dim_in, 3, 7, 7),
+ (1, dim_in, 5, 7, 7),
+ (1, dim_in, 7, 7, 7),
+ (2, dim_in, 3, 7, 7),
+ (4, dim_in, 3, 7, 7),
+ (8, dim_in, 3, 7, 7),
+ (2, dim_in, 3, 7, 14),
+ (2, dim_in, 3, 14, 7),
+ (2, dim_in, 3, 14, 14),
+ # Forward failed.
+ (8, dim_in * 2, 3, 7, 7),
+ (8, dim_in * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
+
+
+class TestResBottleneckBlock(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_res_block(self):
+ """
+ Test simple ResBlock with different inputs.
+ """
+ for dim_in, dim_inner, dim_out in itertools.product(
+ (4, 8, 16), (2, 4), (4, 8, 16)
+ ):
+ model = ResBlock(
+ branch1_conv=nn.Conv3d(
+ dim_in, dim_out, kernel_size=(1, 1, 1), stride=(1, 1, 1)
+ )
+ if dim_in != dim_out
+ else None,
+ branch1_norm=nn.BatchNorm3d(num_features=dim_out)
+ if dim_in != dim_out
+ else None,
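+ # The shortcut (branch1) only needs a 1x1x1 projection conv plus norm when
+ # dim_in != dim_out; otherwise the identity shortcut is used.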
+ branch2=BottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_in,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=nn.BatchNorm3d(dim_inner),
+ act_a=nn.ReLU(),
+ conv_b=nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 1, 1],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=nn.BatchNorm3d(dim_inner),
+ act_b=nn.ReLU(),
+ conv_c=nn.Conv3d(
+ dim_inner,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=nn.BatchNorm3d(dim_out),
+ ),
+ activation=nn.ReLU(),
+ branch_fusion=lambda x, y: x + y,
+ )
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in):
+ if input_tensor.shape[1] != dim_in:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+ output_shape_gt = (
+ input_shape[0],
+ dim_out,
+ input_shape[2],
+ input_shape[3],
+ input_shape[4],
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_res_block_with_callable(self):
+ """
+ Test builder `create_res_block` with callable inputs.
+ """
+ for (norm, activation) in itertools.product(
+ (nn.BatchNorm3d, None), (nn.ReLU, nn.Softmax, nn.Sigmoid, None)
+ ):
+ model = create_res_block(
+ dim_in=32,
+ dim_inner=16,
+ dim_out=64,
+ bottleneck=create_bottleneck_block,
+ conv_a_kernel_size=(3, 1, 1),
+ conv_a_stride=(1, 1, 1),
+ conv_a_padding=(1, 0, 0),
+ conv_b_kernel_size=(1, 3, 3),
+ conv_b_stride=(1, 2, 2),
+ conv_b_padding=(0, 1, 1),
+ conv_b_num_groups=1,
+ conv_b_dilation=(1, 1, 1),
+ norm=norm,
+ norm_eps=1e-5,
+ norm_momentum=0.1,
+ activation_bottleneck=activation,
+ activation_block=activation,
+ )
+ model_gt = ResBlock(
+ branch1_conv=nn.Conv3d(
+ 32, 64, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False
+ ),
+ branch1_norm=None if norm is None else norm(num_features=64),
+ branch2=BottleneckBlock(
+ conv_a=nn.Conv3d(
+ 32,
+ 16,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=None if norm is None else norm(16),
+ act_a=None if activation is None else activation(),
+ conv_b=nn.Conv3d(
+ 16,
+ 16,
+ kernel_size=[1, 3, 3],
+ stride=[1, 2, 2],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=None if norm is None else norm(16),
+ act_b=None if activation is None else activation(),
+ conv_c=nn.Conv3d(
+ 16,
+ 64,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=None if norm is None else norm(64),
+ ),
+ activation=None if activation is None else activation(),
+ branch_fusion=lambda x, y: x + y,
+ )
+
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestBottleneckBlock._get_inputs(dim_in=32):
+ with torch.no_grad():
+ if input_tensor.shape[1] != 32:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ @staticmethod
+ def _get_inputs(dim_in: int = 3) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, dim_in, 3, 7, 7),
+ (1, dim_in, 5, 7, 7),
+ (1, dim_in, 7, 7, 7),
+ (2, dim_in, 3, 7, 7),
+ (4, dim_in, 3, 7, 7),
+ (8, dim_in, 3, 7, 7),
+ (2, dim_in, 3, 7, 14),
+ (2, dim_in, 3, 14, 7),
+ (2, dim_in, 3, 14, 14),
+ # Forward failed.
+ (8, dim_in * 2, 3, 7, 7),
+ (8, dim_in * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
+
+
+class TestResStageTransform(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_res_stage(self):
+ """
+ Test simple ResStage with different inputs.
+ """
+ for dim_in, dim_inner, dim_out in itertools.product(
+ (4, 8, 16), (2, 4), (4, 8, 16)
+ ):
+ model = ResStage(
+ res_blocks=nn.ModuleList(
+ [
+ ResBlock(
+ branch1_conv=nn.Conv3d(
+ dim_in, dim_out, kernel_size=(1, 1, 1)
+ )
+ if dim_in != dim_out
+ else None,
+ branch1_norm=nn.BatchNorm3d(num_features=dim_out)
+ if dim_in != dim_out
+ else None,
+ branch2=BottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_in,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=nn.BatchNorm3d(dim_inner),
+ act_a=nn.ReLU(),
+ conv_b=nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 1, 1],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=nn.BatchNorm3d(dim_inner),
+ act_b=nn.ReLU(),
+ conv_c=nn.Conv3d(
+ dim_inner,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=nn.BatchNorm3d(dim_out),
+ ),
+ activation=nn.ReLU(),
+ branch_fusion=lambda x, y: x + y,
+ ),
+ ResBlock(
+ branch1_conv=None,
+ branch1_norm=None,
+ branch2=BottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_out,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=nn.BatchNorm3d(dim_inner),
+ act_a=nn.ReLU(),
+ conv_b=nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 1, 1],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=nn.BatchNorm3d(dim_inner),
+ act_b=nn.ReLU(),
+ conv_c=nn.Conv3d(
+ dim_inner,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=nn.BatchNorm3d(dim_out),
+ ),
+ activation=nn.ReLU(),
+ branch_fusion=lambda x, y: x + y,
+ ),
+ ]
+ )
+ )
+
+ # Test forwarding.
+ for tensor in TestResStageTransform._get_inputs(dim_in):
+ if tensor.shape[1] != dim_in:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ input_shape = tensor.shape
+ output_shape = out.shape
+ output_shape_gt = (
+ input_shape[0],
+ dim_out,
+ input_shape[2],
+ input_shape[3],
+ input_shape[4],
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_res_stage_with_callable(self):
+ """
+ Test builder `create_res_stage` with callable inputs.
+ """
+ dim_in, dim_inner, dim_out = 32, 16, 64
+ for (norm, activation) in itertools.product(
+ (nn.BatchNorm3d, None), (nn.ReLU, nn.Sigmoid, None)
+ ):
+ model = create_res_stage(
+ depth=2,
+ dim_in=dim_in,
+ dim_inner=dim_inner,
+ dim_out=dim_out,
+ bottleneck=create_bottleneck_block,
+ conv_a_kernel_size=(3, 1, 1),
+ conv_a_stride=(1, 1, 1),
+ conv_a_padding=(1, 0, 0),
+ conv_b_kernel_size=(1, 3, 3),
+ conv_b_stride=(1, 1, 1),
+ conv_b_padding=(0, 1, 1),
+ conv_b_num_groups=1,
+ conv_b_dilation=(1, 1, 1),
+ norm=norm,
+ norm_eps=1e-5,
+ norm_momentum=0.1,
+ activation=activation,
+ )
+ model_gt = ResStage(
+ res_blocks=nn.ModuleList(
+ [
+ ResBlock(
+ branch1_conv=nn.Conv3d(
+ dim_in, dim_out, kernel_size=(1, 1, 1), bias=False
+ )
+ if dim_in != dim_out
+ else None,
+ branch1_norm=None
+ if norm is None
+ else norm(num_features=dim_out)
+ if dim_in != dim_out
+ else None,
+ branch2=BottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_in,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=None if norm is None else norm(dim_inner),
+ act_a=None if activation is None else activation(),
+ conv_b=nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 1, 1],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=None if norm is None else norm(dim_inner),
+ act_b=None if activation is None else activation(),
+ conv_c=nn.Conv3d(
+ dim_inner,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=None if norm is None else norm(dim_out),
+ ),
+ activation=None if activation is None else activation(),
+ branch_fusion=lambda x, y: x + y,
+ ),
+ ResBlock(
+ branch1_conv=None,
+ branch1_norm=None,
+ branch2=BottleneckBlock(
+ conv_a=nn.Conv3d(
+ dim_out,
+ dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[1, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=None if norm is None else norm(dim_inner),
+ act_a=None if activation is None else activation(),
+ conv_b=nn.Conv3d(
+ dim_inner,
+ dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, 1, 1],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=None if norm is None else norm(dim_inner),
+ act_b=None if activation is None else activation(),
+ conv_c=nn.Conv3d(
+ dim_inner,
+ dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=None if norm is None else norm(dim_out),
+ ),
+ activation=None if activation is None else activation(),
+ branch_fusion=lambda x, y: x + y,
+ ),
+ ]
+ )
+ )
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for tensor in TestResStageTransform._get_inputs(dim_in=dim_in):
+ with torch.no_grad():
+ if tensor.shape[1] != 32:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+ out_gt = model_gt(tensor)
+
+ self.assertEqual(
+ out.shape,
+ out_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ out.shape, out_gt.shape
+ ),
+ )
+ self.assertTrue(np.allclose(out.numpy(), out_gt.numpy()))
+
+ @staticmethod
+ def _get_inputs(dim_in: int = 3) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, dim_in, 3, 7, 7),
+ (1, dim_in, 5, 7, 7),
+ (1, dim_in, 7, 7, 7),
+ (2, dim_in, 3, 7, 7),
+ (4, dim_in, 3, 7, 7),
+ (8, dim_in, 3, 7, 7),
+ (2, dim_in, 3, 7, 14),
+ (2, dim_in, 3, 14, 7),
+ (2, dim_in, 3, 14, 14),
+ # Forward failed.
+ (8, dim_in * 2, 3, 7, 7),
+ (8, dim_in * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
+
+
+class TestResNet(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def _build_resnet(
+ self,
+ input_channel,
+ input_clip_length,
+ input_crop_size,
+ model_depth,
+ norm,
+ activation,
+ ):
+ _MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
+ stem_dim_out = 8
+ model_num_class = 10
+ stages = []
+ # create the Stem for ResNet
+ stem = ResNetBasicStem(
+ conv=nn.Conv3d(
+ input_channel,
+ stem_dim_out,
+ kernel_size=[3, 7, 7],
+ stride=[1, 2, 2],
+ padding=[1, 3, 3],
+ bias=False,
+ ),
+ norm=None if norm is None else norm(stem_dim_out),
+ activation=None if activation is None else activation(),
+ pool=nn.MaxPool3d(
+ kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
+ ),
+ )
+ stages.append(stem)
+
+ # get the number of Blocks for each Stage
+ stage_depths = _MODEL_STAGE_DEPTH[model_depth]
+
+ stage_dim_in = stem_dim_out
+ stage_dim_out = stage_dim_in * 4
+ stage_spatial_stride = (2, 1, 1, 1)
+ stage_temporal_stride = (2, 1, 1, 1)
+
+ # create each Stage for ResNet
+ for i in range(len(stage_depths)):
+ stage_dim_inner = stage_dim_out // 4
+ depth = stage_depths[i]
+
+ block_dim_in = stage_dim_in
+ block_dim_inner = stage_dim_inner
+ block_dim_out = stage_dim_out
+
+ blocks = []
+ for j in range(depth):
+ spatial_stride = stage_spatial_stride[i] if j == 0 else 1
+ temporal_stride = stage_temporal_stride[i] if j == 0 else 1
+ # create each Block for the Stage
+ block = ResBlock(
+ branch1_conv=nn.Conv3d(
+ block_dim_in,
+ block_dim_out,
+ kernel_size=(1, 1, 1),
+ stride=(temporal_stride, spatial_stride, spatial_stride),
+ bias=False,
+ )
+ if block_dim_in != block_dim_out
+ else None,
+ branch1_norm=None
+ if norm is None
+ else norm(block_dim_out)
+ if block_dim_in != block_dim_out
+ else None,
+ branch2=BottleneckBlock(
+ conv_a=nn.Conv3d(
+ block_dim_in,
+ block_dim_inner,
+ kernel_size=[3, 1, 1],
+ stride=[temporal_stride, 1, 1],
+ padding=[1, 0, 0],
+ bias=False,
+ ),
+ norm_a=None if norm is None else norm(block_dim_inner),
+ act_a=None if activation is None else activation(),
+ conv_b=nn.Conv3d(
+ block_dim_inner,
+ block_dim_inner,
+ kernel_size=[1, 3, 3],
+ stride=[1, spatial_stride, spatial_stride],
+ padding=[0, 1, 1],
+ bias=False,
+ ),
+ norm_b=None if norm is None else norm(block_dim_inner),
+ act_b=None if activation is None else activation(),
+ conv_c=nn.Conv3d(
+ block_dim_inner,
+ block_dim_out,
+ kernel_size=[1, 1, 1],
+ stride=[1, 1, 1],
+ padding=[0, 0, 0],
+ bias=False,
+ ),
+ norm_c=None if norm is None else norm(block_dim_out),
+ ),
+ activation=None if activation is None else activation(),
+ branch_fusion=lambda x, y: x + y,
+ )
+
+ block_dim_in = block_dim_out
+ blocks.append(block)
+
+ stage = ResStage(nn.ModuleList(blocks))
+ stages.append(stage)
+
+ stage_dim_in = stage_dim_out
+ stage_dim_out = stage_dim_out * 2
+
+ # Create Head for ResNet
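+ # The stem conv (stride 2) and stem max pool (stride 2) together downsample
+ # space by 4 on top of the per-stage strides, hence 4 * np.prod(...) below;
+ # neither downsamples time.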
+ total_spatial_stride = 4 * np.prod(stage_spatial_stride)
+ total_temporal_stride = np.prod(stage_temporal_stride)
+ head_pool_kernel_size = (
+ input_clip_length // total_temporal_stride,
+ input_crop_size // total_spatial_stride,
+ input_crop_size // total_spatial_stride,
+ )
+
+ head = ResNetBasicHead(
+ proj=nn.Linear(stage_dim_in, model_num_class),
+ activation=nn.Softmax(),
+ pool=nn.AvgPool3d(kernel_size=head_pool_kernel_size, stride=[1, 1, 1]),
+ dropout=None,
+ output_pool=nn.AdaptiveAvgPool3d(1),
+ )
+ stages.append(head)
+
+ return (Net(blocks=nn.ModuleList(stages)), model_num_class)
+
+ def test_create_resnet(self):
+ """
+ Test simple ResNet with different inputs.
+ """
+ for input_channel, input_clip_length, input_crop_size in itertools.product(
+ (3, 2), (2, 4), (56, 64)
+ ):
+ model_depth = 50
+ model, num_class = self._build_resnet(
+ input_channel,
+ input_clip_length,
+ input_crop_size,
+ model_depth,
+ nn.BatchNorm3d,
+ nn.ReLU,
+ )
+
+ # Test forwarding.
+ for tensor in TestResNet._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ output_shape = out.shape
+ output_shape_gt = (tensor.shape[0], num_class)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_resnet_with_callable(self):
+ """
+ Test builder `create_resnet` with callable inputs.
+ """
+ for (norm, activation) in itertools.product(
+ (nn.BatchNorm3d, None), (nn.ReLU, nn.Sigmoid, None)
+ ):
+ input_channel = 3
+ input_clip_length = 4
+ input_crop_size = 56
+ model_depth = 50
+ stage_spatial_stride = (2, 1, 1, 1)
+ stage_temporal_stride = (2, 1, 1, 1)
+ model_gt, num_class = self._build_resnet(
+ input_channel,
+ input_clip_length,
+ input_crop_size,
+ model_depth,
+ norm,
+ activation,
+ )
+
+ total_spatial_stride = 4 * np.prod(stage_spatial_stride)
+ total_temporal_stride = np.prod(stage_temporal_stride)
+ head_pool_kernel_size = (
+ input_clip_length // total_temporal_stride,
+ input_crop_size // total_spatial_stride,
+ input_crop_size // total_spatial_stride,
+ )
+
+ model = create_resnet(
+ input_channel=input_channel,
+ model_depth=50,
+ model_num_class=num_class,
+ dropout_rate=0,
+ norm=norm,
+ activation=activation,
+ stem_dim_out=8,
+ stem_conv_kernel_size=(3, 7, 7),
+ stem_conv_stride=(1, 2, 2),
+ stem_pool=nn.MaxPool3d,
+ stem_pool_kernel_size=(1, 3, 3),
+ stem_pool_stride=(1, 2, 2),
+ stage_conv_a_kernel_size=((3, 1, 1),) * 4,
+ stage_conv_b_kernel_size=((1, 3, 3),) * 4,
+ stage_spatial_h_stride=stage_spatial_stride,
+ stage_spatial_w_stride=stage_spatial_stride,
+ stage_temporal_stride=stage_temporal_stride,
+ bottleneck=create_bottleneck_block,
+ head_pool=nn.AvgPool3d,
+ head_pool_kernel_size=head_pool_kernel_size,
+ head_output_size=(1, 1, 1),
+ head_activation=nn.Softmax,
+ )
+
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for tensor in TestResNet._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+ out_gt = model_gt(tensor)
+
+ self.assertEqual(
+ out.shape,
+ out_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ out.shape, out_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(out.numpy(), out_gt.numpy(), rtol=1e-1, atol=1e-1)
+ )
+
+ def test_create_acoustic_resnet_with_callable(self):
+ """
+ Test builder `create_acoustic_resnet` with callable inputs.
+ """
+ _input_channel = 1
+ for (norm, activation) in itertools.product(
+ (nn.BatchNorm3d, None), (nn.ReLU, nn.Sigmoid, None)
+ ):
+ model = create_acoustic_resnet(
+ input_channel=_input_channel,
+ stem_conv_kernel_size=(3, 3, 3),
+ model_depth=50,
+ model_num_class=400,
+ dropout_rate=0,
+ norm=norm,
+ activation=activation,
+ stem_dim_out=8,
+ stem_pool=None,
+ stem_pool_kernel_size=(1, 3, 3),
+ stem_pool_stride=(1, 2, 2),
+ stage_conv_a_kernel_size=(3, 1, 1),
+ stage_conv_b_kernel_size=(1, 3, 3),
+ stage_spatial_h_stride=(2, 1, 1, 1),
+ stage_spatial_w_stride=(2, 1, 1, 1),
+ stage_temporal_stride=(2, 1, 1, 1),
+ head_pool=nn.AvgPool3d,
+ head_output_size=(1, 1, 1),
+ head_activation=nn.Softmax,
+ )
+
+ # Test forwarding.
+ for tensor in TestResNet._get_acoustic_inputs(_input_channel, 8, 56):
+ with torch.no_grad():
+ if tensor.shape[1] != _input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor)
+ continue
+ model(tensor)
+
+ def test_load_hubconf(self):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ input_channel = 3
+ input_clip_length = 2
+ input_crop_size = 56
+ model = torch.hub.load(
+ repo_or_dir=path, source="local", model="slow_r50", pretrained=False
+ )
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ for tensor in TestResNet._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor)
+ continue
+ model(tensor)
+
+ def test_load_hubconf_detection(self):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ input_channel = 3
+ input_clip_length = 4
+ input_crop_size = 56
+ model = torch.hub.load(
+ repo_or_dir=path,
+ source="local",
+ model="slow_r50_detection",
+ pretrained=False,
+ )
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ bbox_test_inputs = torch.tensor([[0.0, 10, 15, 20, 25], [0.0, 11, 16, 21, 26]])
+ for tensor in TestResNet._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor, bbox_test_inputs)
+ continue
+ model(tensor, bbox_test_inputs)
+
+ def test_create_resnet_with_roi_head_with_callable(self):
+ input_channel = 3
+ input_clip_length = 4
+ input_crop_size = 56
+ model = create_resnet_with_roi_head()
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ bbox_test_inputs = torch.tensor([[0.0, 10, 15, 20, 25], [0.0, 11, 16, 21, 26]])
+ for tensor in TestResNet._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor.shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor, bbox_test_inputs)
+ continue
+ model(tensor, bbox_test_inputs)
+
+ @staticmethod
+ def _get_inputs(
+ channel: int = 3, clip_length: int = 8, crop_size: int = 224
+ ) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ (1, channel, clip_length, crop_size, crop_size),
+ (2, channel, clip_length, crop_size, crop_size),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
+
+ @staticmethod
+ def _get_acoustic_inputs(
+ channel: int = 1, clip_length: int = 130, freq_size: int = 80
+ ) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
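+ # Acoustic inputs are spectrogram-like clips shaped
+ # (batch, channel, time, 1, frequency), i.e. "videos" with height 1.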
+ shapes = (
+ (1, channel, clip_length, 1, freq_size),
+ (2, channel, clip_length, 1, freq_size),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_models_slowfast.py b/code/pytorchvideo/tests/test_models_slowfast.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f726c5f2beb3a279236c321444bcc9c28c3c70b
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_slowfast.py
@@ -0,0 +1,144 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import os
+import unittest
+from typing import Tuple
+
+import torch
+from pytorchvideo.models.slowfast import create_slowfast, create_slowfast_with_roi_head
+from pytorchvideo.transforms.functional import uniform_temporal_subsample_repeated
+from torch import nn
+
+
+class TestSlowFast(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_load_hubconf(self):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ for model_name in ["slowfast_r50", "slowfast_r101"]:
+ model = torch.hub.load(
+ repo_or_dir=path, source="local", model=model_name, pretrained=False
+ )
+ self.assertIsNotNone(model)
+
+ input_clip_length = 32
+ input_crop_size = 224
+ input_channel = 3
+ # Test forwarding.
+ for tensor in TestSlowFast._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor[0].shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor)
+ continue
+
+ def test_load_hubconf_detection(self):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ input_clip_length = 32
+ input_crop_size = 224
+ input_channel = 3
+ model = torch.hub.load(
+ repo_or_dir=path,
+ source="local",
+ model="slowfast_r50_detection",
+ pretrained=False,
+ )
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ bbox_test_inputs = torch.tensor([[0.0, 10, 15, 20, 25], [0.0, 11, 16, 21, 26]])
+ for tensor in TestSlowFast._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor[0].shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor, bbox_test_inputs)
+ continue
+ model(tensor, bbox_test_inputs)
+
+ def test_create_slowfast_with_roi_head_with_callable(self):
+ input_clip_length = 32
+ input_crop_size = 224
+ input_channel = 3
+ model = create_slowfast_with_roi_head()
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ bbox_test_inputs = torch.tensor([[0.0, 10, 15, 20, 25], [0.0, 11, 16, 21, 26]])
+ for tensor in TestSlowFast._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor[0].shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor, bbox_test_inputs)
+ continue
+ model(tensor, bbox_test_inputs)
+
+ def test_create_slowfast_with_callable(self):
+ """
+ Test builder `create_slowfast` with callable inputs.
+ """
+ for (norm, activation) in itertools.product(
+ (nn.BatchNorm3d, None), (nn.ReLU, nn.Sigmoid, None)
+ ):
+ input_clip_length = 32
+ input_crop_size = 224
+ input_channel = 3
+
+ model = create_slowfast(
+ slowfast_channel_reduction_ratio=8,
+ slowfast_conv_channel_fusion_ratio=2,
+ slowfast_fusion_conv_kernel_size=(7, 1, 1),
+ slowfast_fusion_conv_stride=(4, 1, 1),
+ input_channels=(input_channel,) * 2,
+ model_depth=18,
+ model_num_class=400,
+ dropout_rate=0,
+ norm=norm,
+ activation=activation,
+ )
+
+ # Test forwarding.
+ for tensor in TestSlowFast._get_inputs(
+ input_channel, input_clip_length, input_crop_size
+ ):
+ with torch.no_grad():
+ if tensor[0].shape[1] != input_channel:
+ with self.assertRaises(RuntimeError):
+ model(tensor)
+ continue
+
+ model(tensor)
+
+ @staticmethod
+ def _get_inputs(
+ channel: int = 3,
+ clip_length: int = 8,
+ crop_size: int = 224,
+ frame_ratios: Tuple[int] = (4, 1),
+ ) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = ((1, channel, clip_length, crop_size, crop_size),)
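+ # uniform_temporal_subsample_repeated turns each clip into a list of pathway
+ # tensors; with the default frame_ratios=(4, 1) the first (slow) pathway keeps
+ # clip_length // 4 frames and the second (fast) pathway keeps all frames.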
+ for shape in shapes:
+ yield uniform_temporal_subsample_repeated(
+ torch.rand(shape), frame_ratios=frame_ratios, temporal_dim=2
+ )
diff --git a/code/pytorchvideo/tests/test_models_stem.py b/code/pytorchvideo/tests/test_models_stem.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b6d1d16cf499fee257cdd755ad2f0ecbb1968a
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_stem.py
@@ -0,0 +1,303 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+
+import numpy as np
+import torch
+from pytorchvideo.layers.convolutions import ConvReduce3D
+from pytorchvideo.models.stem import (
+ create_acoustic_res_basic_stem,
+ create_res_basic_stem,
+ ResNetBasicStem,
+)
+from torch import nn
+
+
+class TestResNetBasicStem(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_simple_stem(self):
+ """
+ Test simple ResNetBasicStem (without pooling layer).
+ """
+ for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
+ model = ResNetBasicStem(
+ conv=nn.Conv3d(
+ input_dim,
+ output_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ norm=nn.BatchNorm3d(output_dim),
+ activation=nn.ReLU(),
+ pool=None,
+ )
+
+ # Test forwarding.
+ for tensor in TestResNetBasicStem._get_inputs(input_dim):
+ if tensor.shape[1] != input_dim:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(tensor)
+ continue
+ else:
+ output_tensor = model(tensor)
+
+ input_shape = tensor.shape
+ output_shape = output_tensor.shape
+ output_shape_gt = (
+ input_shape[0],
+ output_dim,
+ input_shape[2],
+ input_shape[3],
+ input_shape[4],
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_stem_with_conv_reduced_3d(self):
+ """
+ Test simple ResNetBasicStem with ConvReduce3D.
+ """
+ for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
+ model = ResNetBasicStem(
+ conv=ConvReduce3D(
+ in_channels=input_dim,
+ out_channels=output_dim,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=(False, False),
+ ),
+ norm=nn.BatchNorm3d(output_dim),
+ activation=nn.ReLU(),
+ pool=None,
+ )
+
+ # Test forwarding.
+ for tensor in TestResNetBasicStem._get_inputs(input_dim):
+ if tensor.shape[1] != input_dim:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(tensor)
+ continue
+ else:
+ output_tensor = model(tensor)
+
+ input_shape = tensor.shape
+ output_shape = output_tensor.shape
+ output_shape_gt = (
+ input_shape[0],
+ output_dim,
+ input_shape[2],
+ input_shape[3],
+ input_shape[4],
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_complex_stem(self):
+ """
+ Test complex ResNetBasicStem.
+ """
+ for input_dim, output_dim in itertools.product((2, 3), (4, 8, 16)):
+ model = ResNetBasicStem(
+ conv=nn.Conv3d(
+ input_dim,
+ output_dim,
+ kernel_size=[3, 7, 7],
+ stride=[1, 2, 2],
+ padding=[1, 3, 3],
+ bias=False,
+ ),
+ norm=nn.BatchNorm3d(output_dim),
+ activation=nn.ReLU(),
+ pool=nn.MaxPool3d(
+ kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
+ ),
+ )
+
+ # Test forwarding.
+ for input_tensor in TestResNetBasicStem._get_inputs(input_dim):
+ if input_tensor.shape[1] != input_dim:
+ with self.assertRaises(Exception):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+
+ input_shape = input_tensor.shape
+ output_shape = output_tensor.shape
+
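+ # The stem conv (stride (1, 2, 2)) and max pool (stride (1, 2, 2)) each halve
+ # H and W with ceil division, giving the nested (n - 1) // 2 + 1 terms below;
+ # the temporal dim is left untouched.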
+ output_shape_gt = (
+ input_shape[0],
+ output_dim,
+ input_shape[2],
+ (((input_shape[3] - 1) // 2 + 1) - 1) // 2 + 1,
+ (((input_shape[4] - 1) // 2 + 1) - 1) // 2 + 1,
+ )
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_create_stem_with_callable(self):
+ """
+ Test builder `create_res_basic_stem` with callable inputs.
+ """
+ for (pool, activation, norm) in itertools.product(
+ (nn.AvgPool3d, nn.MaxPool3d, None),
+ (nn.ReLU, nn.Softmax, nn.Sigmoid, None),
+ (nn.BatchNorm3d, None),
+ ):
+ model = create_res_basic_stem(
+ in_channels=3,
+ out_channels=64,
+ pool=pool,
+ activation=activation,
+ norm=norm,
+ )
+ model_gt = ResNetBasicStem(
+ conv=nn.Conv3d(
+ 3,
+ 64,
+ kernel_size=[3, 7, 7],
+ stride=[1, 2, 2],
+ padding=[1, 3, 3],
+ bias=False,
+ ),
+ norm=None if norm is None else norm(64),
+ activation=None if activation is None else activation(),
+ pool=None
+ if pool is None
+ else pool(kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]),
+ )
+
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestResNetBasicStem._get_inputs():
+ with torch.no_grad():
+ if input_tensor.shape[1] != 3:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ def test_create_acoustic_stem_with_callable(self):
+ """
+ Test builder `create_acoustic_res_basic_stem` with callable
+ inputs.
+ """
+ for (pool, activation, norm) in itertools.product(
+ (nn.AvgPool3d, nn.MaxPool3d, None),
+ (nn.ReLU, nn.Softmax, nn.Sigmoid, None),
+ (nn.BatchNorm3d, None),
+ ):
+ model = create_acoustic_res_basic_stem(
+ in_channels=3,
+ out_channels=64,
+ pool=pool,
+ activation=activation,
+ norm=norm,
+ )
+ model_gt = ResNetBasicStem(
+ conv=ConvReduce3D(
+ in_channels=3,
+ out_channels=64,
+ kernel_size=((3, 1, 1), (1, 7, 7)),
+ stride=((1, 1, 1), (1, 1, 1)),
+ padding=((1, 0, 0), (0, 3, 3)),
+ bias=(False, False),
+ ),
+ norm=None if norm is None else norm(64),
+ activation=None if activation is None else activation(),
+ pool=None
+ if pool is None
+ else pool(kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]),
+ )
+
+ model.load_state_dict(
+ model_gt.state_dict(), strict=True
+ ) # explicitly use strict mode.
+
+ # Test forwarding.
+ for input_tensor in TestResNetBasicStem._get_inputs():
+ with torch.no_grad():
+ if input_tensor.shape[1] != 3:
+ with self.assertRaises(RuntimeError):
+ output_tensor = model(input_tensor)
+ continue
+ else:
+ output_tensor = model(input_tensor)
+ output_tensor_gt = model_gt(input_tensor)
+ self.assertEqual(
+ output_tensor.shape,
+ output_tensor_gt.shape,
+ "Output shape {} is different from expected shape {}".format(
+ output_tensor.shape, output_tensor_gt.shape
+ ),
+ )
+ self.assertTrue(
+ np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
+ )
+
+ @staticmethod
+ def _get_inputs(input_dim: int = 3) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random tensors as test cases.
+ shapes = (
+ # Forward succeeded.
+ (1, input_dim, 3, 7, 7),
+ (1, input_dim, 5, 7, 7),
+ (1, input_dim, 7, 7, 7),
+ (2, input_dim, 3, 7, 7),
+ (4, input_dim, 3, 7, 7),
+ (8, input_dim, 3, 7, 7),
+ (2, input_dim, 3, 7, 14),
+ (2, input_dim, 3, 14, 7),
+ (2, input_dim, 3, 14, 14),
+ # Forward failed.
+ (8, input_dim * 2, 3, 7, 7),
+ (8, input_dim * 4, 5, 7, 7),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_models_vision_transformers.py b/code/pytorchvideo/tests/test_models_vision_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d6edea269d3c5dfb9d45261296252225ed646ab
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_vision_transformers.py
@@ -0,0 +1,183 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import itertools
+import unittest
+import warnings
+
+import torch
+from pytorchvideo.models.vision_transformers import (
+ create_multiscale_vision_transformers,
+)
+
+
+class TestVisionTransformers(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_mvit(self):
+ """
+ Test MViT.
+ """
+ # Test MViT with 3D case.
+ num_head = 100
+ batch_size = 1
+ fake_input = torch.rand(batch_size, 3, 4, 28, 28)
+ model = create_multiscale_vision_transformers(
+ spatial_size=28,
+ temporal_size=4,
+ patch_embed_dim=12,
+ depth=1,
+ head_num_classes=num_head,
+ pool_kv_stride_adaptive=[1, 2, 2],
+ )
+ output = model(fake_input)
+ gt_shape_tensor = torch.rand(batch_size, num_head)
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+ # Test MViT with 3D case with pool first.
+ num_head = 100
+ batch_size = 1
+ fake_input = torch.rand(batch_size, 3, 4, 28, 28)
+ model = create_multiscale_vision_transformers(
+ spatial_size=28,
+ temporal_size=4,
+ patch_embed_dim=12,
+ depth=1,
+ head_num_classes=num_head,
+ pool_first=True,
+ pool_q_stride_size=[[0, 1, 2, 2]],
+ )
+ output = model(fake_input)
+ gt_shape_tensor = torch.rand(batch_size, num_head)
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+
+ # Test MViT with 2D case for images.
+ conv_patch_kernel = (7, 7)
+ conv_patch_stride = (4, 4)
+ conv_patch_padding = (3, 3)
+ num_head = 100
+ batch_size = 1
+ fake_input = torch.rand(batch_size, 3, 28, 28)
+ model = create_multiscale_vision_transformers(
+ spatial_size=(28, 28),
+ temporal_size=1,
+ patch_embed_dim=12,
+ depth=1,
+ head_num_classes=num_head,
+ use_2d_patch=True,
+ conv_patch_embed_kernel=conv_patch_kernel,
+ conv_patch_embed_stride=conv_patch_stride,
+ conv_patch_embed_padding=conv_patch_padding,
+ )
+ output = model(fake_input)
+ gt_shape_tensor = torch.rand(batch_size, num_head)
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+
+ # Test MViT without patch_embed.
+ conv_patch_kernel = (7, 7)
+ conv_patch_stride = (4, 4)
+ conv_patch_padding = (3, 3)
+ num_head = 100
+ batch_size = 1
+ fake_input = torch.rand(batch_size, 8, 12)
+ model = create_multiscale_vision_transformers(
+ spatial_size=(8, 1),
+ temporal_size=1,
+ patch_embed_dim=12,
+ depth=1,
+ enable_patch_embed=False,
+ head_num_classes=num_head,
+ )
+ output = model(fake_input)
+ gt_shape_tensor = torch.rand(batch_size, num_head)
+ self.assertEqual(output.shape, gt_shape_tensor.shape)
+
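+ # The constructor calls below are expected to fail: use_2d_patch requires
+ # temporal_size == 1, pool_kv_stride_adaptive cannot be combined with an
+ # explicit pool_kv_stride_size, and an unknown norm name is not implemented.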
+ self.assertRaises(
+ AssertionError,
+ create_multiscale_vision_transformers,
+ spatial_size=28,
+ temporal_size=4,
+ use_2d_patch=True,
+ )
+
+ self.assertRaises(
+ AssertionError,
+ create_multiscale_vision_transformers,
+ spatial_size=28,
+ temporal_size=1,
+ pool_kv_stride_adaptive=[[2, 2, 2]],
+ pool_kv_stride_size=[[1, 1, 2, 2]],
+ )
+
+ self.assertRaises(
+ NotImplementedError,
+ create_multiscale_vision_transformers,
+ spatial_size=28,
+ temporal_size=1,
+ norm="fakenorm",
+ )
+
+ def test_mvit_is_torchscriptable(self):
+ batch_size = 2
+ num_head = 4
+ spatial_size = (28, 28)
+ temporal_size = 4
+ depth = 2
+ patch_embed_dim = 96
+
+ # The following binary settings are covered by `test_layers_attention.py`:
+ # `qkv_bias`, `depthwise_conv`, `separate_qkv`, `bias_on`, `pool_first`,
+ # `residual_pool`
+ true_false_opts = [
+ "cls_embed_on",
+ "sep_pos_embed",
+ "enable_patch_embed",
+ "enable_patch_embed_norm",
+ ]
+
+ # Loop over `2 ^ len(true_false_opts)` configurations
+ for true_false_settings in itertools.product(
+ *([[True, False]] * len(true_false_opts))
+ ):
+ named_tf_settings = dict(zip(true_false_opts, true_false_settings))
+
+ model = create_multiscale_vision_transformers(
+ spatial_size=spatial_size,
+ temporal_size=temporal_size,
+ depth=depth,
+ head_num_classes=num_head,
+ patch_embed_dim=patch_embed_dim,
+ pool_kv_stride_adaptive=[1, 2, 2],
+ **named_tf_settings,
+ create_scriptable_model=False,
+ ).eval()
+ ts_model = torch.jit.script(model)
+
+ input_shape = (
+ (3, temporal_size, spatial_size[0], spatial_size[1])
+ if named_tf_settings["enable_patch_embed"]
+ else (
+ temporal_size * spatial_size[0] * spatial_size[1],
+ patch_embed_dim,
+ )
+ )
+ fake_input = torch.rand(batch_size, *input_shape)
+
+ expected = model(fake_input)
+ actual = ts_model(fake_input)
+ torch.testing.assert_allclose(expected, actual)
+
+ def test_mvit_create_scriptable_model_is_deprecated(self):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ create_multiscale_vision_transformers(
+ spatial_size=28,
+ temporal_size=4,
+ norm="batchnorm",
+ depth=2,
+ head_num_classes=100,
+ create_scriptable_model=True,
+ )
+
+ assert len(w) == 1
+ assert issubclass(w[-1].category, DeprecationWarning)
diff --git a/code/pytorchvideo/tests/test_models_x3d.py b/code/pytorchvideo/tests/test_models_x3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0014c0b1dc3d0149d17a023b8a5ea452542d766
--- /dev/null
+++ b/code/pytorchvideo/tests/test_models_x3d.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+import unittest
+
+import torch
+from pytorchvideo.layers.swish import Swish
+from pytorchvideo.models.x3d import create_x3d, create_x3d_bottleneck_block
+from torch import nn
+
+
+class TestX3d(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_create_x3d(self):
+ """
+ To test different versions of X3D, set the input to:
+ X3D-XS: (4, 160, 2.0, 2.2, 2.25)
+ X3D-S: (13, 160, 2.0, 2.2, 2.25)
+ X3D-M: (16, 224, 2.0, 2.2, 2.25)
+ X3D-L: (16, 312, 2.0, 5.0, 2.25)
+
+ Each of the parameters corresponds to input_clip_length, input_crop_size,
+ width_factor, depth_factor and bottleneck_factor.
+ """
+ for (
+ input_clip_length,
+ input_crop_size,
+ width_factor,
+ depth_factor,
+ bottleneck_factor,
+ ) in [
+ (4, 160, 2.0, 2.2, 2.25),
+ ]:
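+ # Only the X3D-XS configuration is exercised here, presumably to keep the
+ # test lightweight; the other variants from the docstring can be dropped in.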
+ model = create_x3d(
+ input_clip_length=input_clip_length,
+ input_crop_size=input_crop_size,
+ model_num_class=400,
+ dropout_rate=0.5,
+ width_factor=width_factor,
+ depth_factor=depth_factor,
+ norm=nn.BatchNorm3d,
+ activation=nn.ReLU,
+ stem_dim_in=12,
+ stem_conv_kernel_size=(5, 3, 3),
+ stem_conv_stride=(1, 2, 2),
+ stage_conv_kernel_size=((3, 3, 3),) * 4,
+ stage_spatial_stride=(2, 2, 2, 2),
+ stage_temporal_stride=(1, 1, 1, 1),
+ bottleneck=create_x3d_bottleneck_block,
+ bottleneck_factor=bottleneck_factor,
+ se_ratio=0.0625,
+ inner_act=Swish,
+ head_dim_out=2048,
+ head_pool_act=nn.ReLU,
+ head_bn_lin5_on=False,
+ head_activation=nn.Softmax,
+ )
+
+ # Test forwarding.
+ for tensor in TestX3d._get_inputs(input_clip_length, input_crop_size):
+ if tensor.shape[1] != 3:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ output_shape = out.shape
+ output_shape_gt = (tensor.shape[0], 400)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ def test_load_hubconf(self):
+ path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ )
+ for (input_clip_length, input_crop_size, model_name) in [
+ (4, 160, "x3d_xs"),
+ (13, 160, "x3d_s"),
+ (16, 224, "x3d_m"),
+ ]:
+ model = torch.hub.load(
+ repo_or_dir=path,
+ source="local",
+ model=model_name,
+ pretrained=False,
+ head_output_with_global_average=True,
+ )
+ self.assertIsNotNone(model)
+
+ # Test forwarding.
+ for tensor in TestX3d._get_inputs(input_clip_length, input_crop_size):
+ if tensor.shape[1] != 3:
+ with self.assertRaises(RuntimeError):
+ out = model(tensor)
+ continue
+
+ out = model(tensor)
+
+ output_shape = out.shape
+ output_shape_gt = (tensor.shape[0], 400)
+
+ self.assertEqual(
+ output_shape,
+ output_shape_gt,
+ "Output shape {} is different from expected shape {}".format(
+ output_shape, output_shape_gt
+ ),
+ )
+
+ @staticmethod
+ def _get_inputs(clip_length: int = 4, crop_size: int = 160) -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ (1, 3, clip_length, crop_size, crop_size),
+ (2, 3, clip_length, crop_size, crop_size),
+ )
+ for shape in shapes:
+ yield torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_simclr.py b/code/pytorchvideo/tests/test_simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..962b54257eb00a0b1b03e6cac5d1bfdf1dd9a2bb
--- /dev/null
+++ b/code/pytorchvideo/tests/test_simclr.py
@@ -0,0 +1,38 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+
+import torch
+from pytorchvideo.models.simclr import SimCLR
+from torch import nn
+
+
+class TestSimCLR(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.set_rng_state(torch.manual_seed(42).get_state())
+
+ def test_simclr(self):
+ simclr = SimCLR(
+ backbone=nn.Linear(8, 4),
+ mlp=nn.Linear(4, 2),
+ temperature=0.07,
+ )
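+ # Forwarding two augmented views runs both through the backbone and MLP
+ # projector and computes a contrastive loss between the projections; this is
+ # a smoke test that only checks the call completes.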
+ for crop1, crop2 in TestSimCLR._get_inputs():
+ simclr(crop1, crop2)
+
+ @staticmethod
+ def _get_inputs() -> torch.tensor:
+ """
+ Provide different tensors as test cases.
+
+ Yield:
+ (torch.tensor): tensor as test case input.
+ """
+ # Prepare random inputs as test cases.
+ shapes = (
+ (1, 8),
+ (2, 8),
+ )
+ for shape in shapes:
+ yield torch.rand(shape), torch.rand(shape)
diff --git a/code/pytorchvideo/tests/test_transforms.py b/code/pytorchvideo/tests/test_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..46a07824b13c81726aa19206512a0a43e89233ce
--- /dev/null
+++ b/code/pytorchvideo/tests/test_transforms.py
@@ -0,0 +1,1229 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import unittest
+from collections import Counter
+from itertools import permutations
+
+import numpy as np
+import torch
+from pytorchvideo.data.utils import thwc_to_cthw
+from pytorchvideo.transforms import (
+ ApplyTransformToKey,
+ AugMix,
+ create_video_transform,
+ CutMix,
+ MixUp,
+ MixVideo,
+ Normalize,
+ OpSampler,
+ Permute,
+ RandAugment,
+ RandomResizedCrop,
+ RandomShortSideScale,
+ ShortSideScale,
+ UniformCropVideo,
+ UniformTemporalSubsample,
+)
+from pytorchvideo.transforms.functional import (
+ clip_boxes_to_image,
+ convert_to_one_hot,
+ div_255,
+ horizontal_flip_with_boxes,
+ random_crop_with_boxes,
+ random_short_side_scale_with_boxes,
+ short_side_scale,
+ short_side_scale_with_boxes,
+ uniform_crop,
+ uniform_crop_with_boxes,
+ uniform_temporal_subsample,
+ uniform_temporal_subsample_repeated,
+)
+from torchvision.transforms import Compose
+from torchvision.transforms._transforms_video import (
+ CenterCropVideo,
+ NormalizeVideo,
+ RandomCropVideo,
+ RandomHorizontalFlipVideo,
+)
+from utils import create_dummy_video_frames, create_random_bbox
+
+
+class TestTransforms(unittest.TestCase):
+ def test_compose_with_video_transforms(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ test_clip = {"video": video, "label": 0}
+
+ # Compose using torchvision and pytorchvideo transforms to ensure they interact
+ # correctly.
+ num_subsample = 10
+ transform = Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ [
+ UniformTemporalSubsample(num_subsample),
+ NormalizeVideo([video.mean()] * 3, [video.std()] * 3),
+ RandomShortSideScale(min_size=15, max_size=25),
+ RandomCropVideo(10),
+ RandomHorizontalFlipVideo(p=0.5),
+ ]
+ ),
+ )
+ ]
+ )
+
+ actual = transform(test_clip)
+ c, t, h, w = actual["video"].shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, 10)
+ self.assertEqual(w, 10)
+
+ def test_uniform_temporal_subsample(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ actual = uniform_temporal_subsample(video, video.shape[1])
+ self.assertTrue(actual.equal(video))
+
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ actual = uniform_temporal_subsample(video, video.shape[1] // 2)
+ self.assertTrue(actual.equal(video[:, [0, 2, 4, 6, 8, 10, 12, 14, 16, 19]]))
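+ # Sampling 10 of 20 frames picks evenly spaced indices over [0, 19] truncated
+ # to ints (0, 2, 4, ..., 16, 19), so the first and last frames are always
+ # kept; the assertion encodes exactly those indices.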
+
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ actual = uniform_temporal_subsample(video, 1)
+ self.assertTrue(actual.equal(video[:, 0:1]))
+
+ def test_short_side_scale_width_shorter_pytorch(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 20, 10)).to(
+ dtype=torch.float32
+ )
+ actual = short_side_scale(video, 5, backend="pytorch")
+ self.assertEqual(actual.shape, (3, 20, 10, 5))
+
+ def test_short_side_scale_height_shorter_pytorch(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
+ dtype=torch.float32
+ )
+ actual = short_side_scale(video, 5, backend="pytorch")
+ self.assertEqual(actual.shape, (3, 20, 5, 10))
+
+ def test_short_side_scale_equal_size_pytorch(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 10, 10)).to(
+ dtype=torch.float32
+ )
+ actual = short_side_scale(video, 10, backend="pytorch")
+ self.assertEqual(actual.shape, (3, 20, 10, 10))
+
+ def test_short_side_scale_width_shorter_opencv(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 20, 10)).to(
+ dtype=torch.float32
+ )
+ actual = short_side_scale(video, 5, backend="opencv")
+ self.assertEqual(actual.shape, (3, 20, 10, 5))
+
+ def test_short_side_scale_height_shorter_opencv(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
+ dtype=torch.float32
+ )
+ actual = short_side_scale(video, 5, backend="opencv")
+ self.assertEqual(actual.shape, (3, 20, 5, 10))
+
+ def test_short_side_scale_equal_size_opencv(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 10, 10)).to(
+ dtype=torch.float32
+ )
+ actual = short_side_scale(video, 10, backend="opencv")
+ self.assertEqual(actual.shape, (3, 20, 10, 10))
+
+ def test_random_short_side_scale_height_shorter_pytorch_with_boxes(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
+ dtype=torch.float32
+ )
+ boxes = create_random_bbox(7, 10, 20)
+ actual, scaled_boxes = random_short_side_scale_with_boxes(
+ video, min_size=4, max_size=8, backend="pytorch", boxes=boxes
+ )
+ self.assertEqual(actual.shape[0], 3)
+ self.assertEqual(actual.shape[1], 20)
+ self.assertTrue(actual.shape[2] <= 8 and actual.shape[2] >= 4)
+ self._check_boxes(7, actual.shape[2], actual.shape[3], boxes)
+
+ def test_short_side_scale_height_shorter_pytorch_with_boxes(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 10, 20)).to(
+ dtype=torch.float32
+ )
+ boxes = create_random_bbox(7, 10, 20)
+ actual, scaled_boxes = short_side_scale_with_boxes(
+ video,
+ boxes=boxes,
+ size=5,
+ backend="pytorch",
+ )
+ self.assertEqual(actual.shape, (3, 20, 5, 10))
+ self._check_boxes(7, 5, 10, boxes)
+
+ def test_torchscriptable_input_output(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+
+ # Test all the torchscriptable tensors.
+ for transform in [UniformTemporalSubsample(10), RandomShortSideScale(10, 20)]:
+
+ transform_script = torch.jit.script(transform)
+ self.assertTrue(isinstance(transform_script, torch.jit.ScriptModule))
+
+ # Seed before each transform to force determinism.
+ torch.manual_seed(0)
+ output = transform(video)
+ torch.manual_seed(0)
+ script_output = transform_script(video)
+ self.assertTrue(output.equal(script_output))
+
+ def test_uniform_temporal_subsample_repeated(self):
+ video = thwc_to_cthw(create_dummy_video_frames(32, 10, 10)).to(
+ dtype=torch.float32
+ )
+ actual = uniform_temporal_subsample_repeated(video, (1, 4))
+ expected_shape = ((3, 32, 10, 10), (3, 8, 10, 10))
+ for idx in range(len(actual)):
+ self.assertEqual(actual[idx].shape, expected_shape[idx])
+
+ def test_uniform_crop(self):
+ # For videos with height < width.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ # Left crop.
+ actual = uniform_crop(video, size=20, spatial_idx=0)
+ self.assertTrue(actual.equal(video[:, :, 5:25, :20]))
+ # Center crop.
+ actual = uniform_crop(video, size=20, spatial_idx=1)
+ self.assertTrue(actual.equal(video[:, :, 5:25, 10:30]))
+ # Right crop.
+ actual = uniform_crop(video, size=20, spatial_idx=2)
+ self.assertTrue(actual.equal(video[:, :, 5:25, 20:]))
+
+ # For videos with height > width.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 40, 30)).to(
+ dtype=torch.float32
+ )
+ # Top crop.
+ actual = uniform_crop(video, size=20, spatial_idx=0)
+ self.assertTrue(actual.equal(video[:, :, :20, 5:25]))
+ # Center crop.
+ actual = uniform_crop(video, size=20, spatial_idx=1)
+ self.assertTrue(actual.equal(video[:, :, 10:30, 5:25]))
+ # Bottom crop.
+ actual = uniform_crop(video, size=20, spatial_idx=2)
+ self.assertTrue(actual.equal(video[:, :, 20:, 5:25]))
+
+ def test_uniform_crop_with_boxes(self):
+ # For videos with height < width.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ boxes_inp = create_random_bbox(7, 30, 40)
+
+ # Left crop.
+ actual, boxes = uniform_crop_with_boxes(
+ video, size=20, spatial_idx=0, boxes=boxes_inp
+ )
+ self.assertTrue(actual.equal(video[:, :, 5:25, :20]))
+ self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
+ # Center crop.
+ actual, boxes = uniform_crop_with_boxes(
+ video, size=20, spatial_idx=1, boxes=boxes_inp
+ )
+ self.assertTrue(actual.equal(video[:, :, 5:25, 10:30]))
+ self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
+ # Right crop.
+ actual, boxes = uniform_crop_with_boxes(
+ video, size=20, spatial_idx=2, boxes=boxes_inp
+ )
+ self.assertTrue(actual.equal(video[:, :, 5:25, 20:]))
+ self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
+
+ # For videos with height > width.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 40, 30)).to(
+ dtype=torch.float32
+ )
+ # Top crop.
+ actual, boxes = uniform_crop_with_boxes(
+ video, size=20, spatial_idx=0, boxes=boxes_inp
+ )
+ self.assertTrue(actual.equal(video[:, :, :20, 5:25]))
+ self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
+ # Center crop.
+ actual, boxes = uniform_crop_with_boxes(
+ video, size=20, spatial_idx=1, boxes=boxes_inp
+ )
+ self.assertTrue(actual.equal(video[:, :, 10:30, 5:25]))
+ self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
+ # Bottom crop.
+ actual, boxes = uniform_crop_with_boxes(
+ video, size=20, spatial_idx=2, boxes=boxes_inp
+ )
+ self.assertTrue(actual.equal(video[:, :, 20:, 5:25]))
+ self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
+
+ def test_random_crop_with_boxes(self):
+ # For videos with height < width.
+ video = thwc_to_cthw(create_dummy_video_frames(15, 30, 40)).to(
+ dtype=torch.float32
+ )
+ boxes_inp = create_random_bbox(7, 30, 40)
+
+ actual, boxes = random_crop_with_boxes(video, size=20, boxes=boxes_inp)
+ self.assertEqual(actual.shape, (3, 15, 20, 20))
+ self._check_boxes(7, actual.shape[2], actual.shape[3], boxes)
+
+ def test_uniform_crop_transform(self):
+ video = thwc_to_cthw(create_dummy_video_frames(10, 30, 40)).to(
+ dtype=torch.float32
+ )
+ test_clip = {"video": video, "aug_index": 1, "label": 0}
+
+ transform = UniformCropVideo(20)
+
+ actual = transform(test_clip)
+ c, t, h, w = actual["video"].shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 10)
+ self.assertEqual(h, 20)
+ self.assertEqual(w, 20)
+ self.assertTrue(actual["video"].equal(video[:, :, 5:25, 10:30]))
+
+ def test_clip_boxes(self):
+ boxes_inp = create_random_bbox(7, 40, 80)
+ clipped_boxes = clip_boxes_to_image(boxes_inp, 20, 40)
+ self._check_boxes(7, 20, 40, clipped_boxes)
+
+ def test_horizontal_flip_with_boxes(self):
+ video = thwc_to_cthw(create_dummy_video_frames(10, 20, 40)).to(
+ dtype=torch.float32
+ )
+ boxes_inp = create_random_bbox(7, 20, 40)
+
+ actual, boxes = horizontal_flip_with_boxes(0.0, video, boxes_inp)
+ self.assertTrue(actual.equal(video))
+ self.assertTrue(boxes.equal(boxes_inp))
+
+ actual, boxes = horizontal_flip_with_boxes(1.0, video, boxes_inp)
+ self.assertEqual(actual.shape, video.shape)
+ self._check_boxes(7, actual.shape[-2], actual.shape[-1], boxes)
+ self.assertTrue(actual.flip((-1)).equal(video))
+
+ def test_normalize(self):
+ video = thwc_to_cthw(create_dummy_video_frames(10, 30, 40)).to(
+ dtype=torch.float32
+ )
+ transform = Normalize(video.mean(), video.std())
+
+ actual = transform(video)
+ self.assertAlmostEqual(actual.mean().item(), 0)
+ self.assertAlmostEqual(actual.std().item(), 1)
+
+ def test_center_crop(self):
+ video = thwc_to_cthw(create_dummy_video_frames(10, 30, 40)).to(
+ dtype=torch.float32
+ )
+ transform = CenterCropVideo(10)
+
+ actual = transform(video)
+ c, t, h, w = actual.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 10)
+ self.assertEqual(h, 10)
+ self.assertEqual(w, 10)
+ self.assertTrue(actual.equal(video[:, :, 10:20, 15:25]))
+
+ def test_convert_to_one_hot(self):
+ # Test without label smooth.
+ num_class = 5
+ num_samples = 10
+ labels = torch.arange(0, num_samples) % num_class
+ one_hot = convert_to_one_hot(labels, num_class)
+ self.assertEqual(one_hot.sum(), num_samples)
+ label_value = 1.0
+ for index in range(num_samples):
+ label = labels[index]
+
+ self.assertEqual(one_hot[index][label], label_value)
+
+ # Test with label smooth.
+ labels = torch.arange(0, num_samples) % num_class
+ label_smooth = 0.1
+ one_hot_smooth = convert_to_one_hot(
+ labels, num_class, label_smooth=label_smooth
+ )
+ self.assertEqual(one_hot_smooth.sum(), num_samples)
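+ # With smoothing eps, the true class gets 1 - eps + eps / num_class and every other class gets eps / num_class.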
+ label_value_smooth = 1 - label_smooth + label_smooth / num_class
+ for index in range(num_samples):
+ label = labels[index]
+ self.assertEqual(one_hot_smooth[index][label], label_value_smooth)
+
+ def test_OpSampler(self):
+ # Test with weights.
+ n_transform = 3
+ transform_list = [lambda x, i=i: x.fill_(i) for i in range(n_transform)]
+ transform_weight = [1] * n_transform
+ transform = OpSampler(transform_list, transform_weight)
+ input_tensor = torch.rand(1)
+ out_tensor = transform(input_tensor)
+ self.assertTrue(out_tensor.sum() in list(range(n_transform)))
+
+ # Test without weights.
+ input_tensor = torch.rand(1)
+ transform_no_weight = OpSampler(transform_list)
+ out_tensor = transform_no_weight(input_tensor)
+ self.assertTrue(out_tensor.sum() in list(range(n_transform)))
+
+ # Make sure each transform is sampled without replacement.
+ transform_op_values = [3, 5, 7]
+ all_possible_out = [15, 21, 35]
+
+ transform_list = [lambda x, i=i: x * i for i in transform_op_values]
+ test_time = 100
+ transform_no_replacement = OpSampler(transform_list, num_sample_op=2)
+ for _ in range(test_time):
+ input_tensor = torch.ones(1)
+ out_tensor = transform_no_replacement(input_tensor)
+ self.assertTrue(out_tensor.sum() in all_possible_out)
+
+ # Make sure each transform is sampled with replacement.
+ transform_op_values = [3, 5, 7]
+ possible_replacement_out = [9, 25, 49]
+ input_tensor = torch.ones(1)
+ transform_list = [lambda x, i=i: x * i for i in transform_op_values]
+ test_time = 100
+ transform_with_replacement = OpSampler(
+ transform_list, replacement=True, num_sample_op=2
+ )
+ replace_time = 0
+ for _ in range(test_time):
+ input_tensor = torch.ones(1)
+ out_tensor = transform_with_replacement(input_tensor)
+ if out_tensor.sum() in possible_replacement_out:
+ replace_time += 1
+ self.assertTrue(replace_time > 0)
+
+ # Test sampling distribution with non-uniform weights.
+ transform_op_values = [3.0, 5.0, 7.0]
+ input_tensor = torch.ones(1)
+ transform_list = [lambda x, i=i: x * i for i in transform_op_values]
+ test_time = 10000
+ weights = [10.0, 2.0, 1.0]
+ transform_no_replacement = OpSampler(transform_list, weights)
+ weight_counter = Counter()
+ for _ in range(test_time):
+ input_tensor = torch.ones(1)
+ out_tensor = transform_no_replacement(input_tensor)
+ weight_counter[out_tensor.sum().item()] += 1
+
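+ # The empirical sampling frequency of each op should approximate its normalized weight.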
+ for index, w in enumerate(weights):
+ gt_dis = w / sum(weights)
+ out_key = transform_op_values[index]
+ self.assertTrue(
+ np.allclose(weight_counter[out_key] / test_time, gt_dis, rtol=0.2)
+ )
+
+ def test_mixup(self):
+ # Test images.
+ batch_size = 2
+ h_size = 10
+ w_size = 10
+ c_size = 3
+ input_images = torch.rand(batch_size, c_size, h_size, w_size)
+ input_images[0, :].fill_(0)
+ input_images[1, :].fill_(1)
+ alpha = 1.0
+ label_smoothing = 0.0
+ num_classes = 5
+ transform_mixup = MixUp(
+ alpha=alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_images, mixed_labels = transform_mixup(input_images, labels)
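+ # With one all-zero and one all-one input, each mixed pair sums to 1 per pixel regardless
+ # of the sampled mixing weight, so the expected totals are deterministic.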
+ gt_image_sum = h_size * w_size * c_size
+ label_sum = batch_size
+
+ self.assertTrue(
+ np.allclose(mixed_images.sum().item(), gt_image_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ # Test videos.
+ batch_size = 2
+ h_size = 10
+ w_size = 10
+ c_size = 3
+ t_size = 2
+ input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
+ input_video[0, :].fill_(0)
+ input_video[1, :].fill_(1)
+ alpha = 1.0
+ label_smoothing = 0.0
+ num_classes = 5
+ transform_mixup = MixUp(
+ alpha=alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_videos, mixed_labels = transform_mixup(input_video, labels)
+ gt_video_sum = h_size * w_size * c_size * t_size
+ label_sum = batch_size
+
+ self.assertTrue(
+ np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ # Test videos with label smoothing.
+ input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
+ input_video[0, :].fill_(0)
+ input_video[1, :].fill_(1)
+ alpha = 1.0
+ label_smoothing = 0.2
+ num_classes = 5
+ transform_mixup = MixUp(
+ alpha=alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_videos, mixed_labels = transform_mixup(input_video, labels)
+ gt_video_sum = h_size * w_size * c_size * t_size
+ label_sum = batch_size
+ self.assertTrue(
+ np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ # Check the smoothing value is in label.
+ smooth_value = label_smoothing / num_classes
+ self.assertTrue(smooth_value in torch.unique(mixed_labels))
+
+ def test_cutmix(self):
+ torch.manual_seed(0)
+ # Test images.
+ batch_size = 2
+ h_size = 10
+ w_size = 10
+ c_size = 3
+ input_images = torch.rand(batch_size, c_size, h_size, w_size)
+ input_images[0, :].fill_(0)
+ input_images[1, :].fill_(1)
+ alpha = 1.0
+ label_smoothing = 0.0
+ num_classes = 5
+ transform_cutmix = CutMix(
+ alpha=alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_images, mixed_labels = transform_cutmix(input_images, labels)
+ gt_image_sum = h_size * w_size * c_size
+ label_sum = batch_size
+
+ self.assertTrue(
+ np.allclose(mixed_images.sum().item(), gt_image_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ # Test videos.
+ batch_size = 2
+ h_size = 10
+ w_size = 10
+ c_size = 3
+ t_size = 2
+ input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
+ input_video[0, :].fill_(0)
+ input_video[1, :].fill_(1)
+ alpha = 1.0
+ label_smoothing = 0.0
+ num_classes = 5
+ transform_cutmix = CutMix(
+ alpha=alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_videos, mixed_labels = transform_cutmix(input_video, labels)
+ gt_video_sum = h_size * w_size * c_size * t_size
+ label_sum = batch_size
+
+ self.assertTrue(
+ np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ # Test videos with label smoothing.
+ input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
+ input_video[0, :].fill_(0)
+ input_video[1, :].fill_(1)
+ alpha = 1.0
+ label_smoothing = 0.2
+ num_classes = 5
+ transform_cutmix = CutMix(
+ alpha=alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_videos, mixed_labels = transform_cutmix(input_video, labels)
+ gt_video_sum = h_size * w_size * c_size * t_size
+ label_sum = batch_size
+ self.assertTrue(
+ np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ # Check the smoothing value is in label.
+ smooth_value = label_smoothing / num_classes
+ self.assertTrue(smooth_value in torch.unique(mixed_labels))
+
+ # Check cutmixed video has both 0 and 1.
+ # Run 20 times to avoid rare cases where the random box is empty.
+ test_times = 20
+ seen_all_value1 = False
+ seen_all_value2 = False
+ for _ in range(test_times):
+ mixed_videos, mixed_labels = transform_cutmix(input_video, labels)
+ if 0 in mixed_videos[0, :] and 1 in mixed_videos[0, :]:
+ seen_all_value1 = True
+
+ if 0 in mixed_videos[1, :] and 1 in mixed_videos[1, :]:
+ seen_all_value2 = True
+
+ if seen_all_value1 and seen_all_value2:
+ break
+ self.assertTrue(seen_all_value1)
+ self.assertTrue(seen_all_value2)
+
+ def test_mixvideo(self):
+
+ self.assertRaises(AssertionError, MixVideo, cutmix_prob=2.0)
+
+ torch.manual_seed(0)
+ # Test images.
+ batch_size = 2
+ h_size = 10
+ w_size = 10
+ c_size = 3
+ input_images = torch.rand(batch_size, c_size, h_size, w_size)
+ input_images[0, :].fill_(0)
+ input_images[1, :].fill_(1)
+ mixup_alpha = 1.0
+ cutmix_alpha = 1.0
+ label_smoothing = 0.0
+ num_classes = 5
+ transform_mix = MixVideo(
+ mixup_alpha=mixup_alpha,
+ cutmix_alpha=cutmix_alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_images, mixed_labels = transform_mix(input_images, labels)
+ gt_image_sum = h_size * w_size * c_size
+ label_sum = batch_size
+
+ self.assertTrue(
+ np.allclose(mixed_images.sum().item(), gt_image_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ # Test videos.
+ batch_size = 2
+ h_size = 10
+ w_size = 10
+ c_size = 3
+ t_size = 2
+ input_video = torch.rand(batch_size, c_size, t_size, h_size, w_size)
+ input_video[0, :].fill_(0)
+ input_video[1, :].fill_(1)
+ mixup_alpha = 1.0
+ cutmix_alpha = 1.0
+ label_smoothing = 0.0
+ num_classes = 5
+ transform_mix = MixVideo(
+ mixup_alpha=mixup_alpha,
+ cutmix_alpha=cutmix_alpha,
+ label_smoothing=label_smoothing,
+ num_classes=num_classes,
+ )
+ labels = torch.arange(0, batch_size) % num_classes
+ mixed_videos, mixed_labels = transform_mix(input_video, labels)
+ gt_video_sum = h_size * w_size * c_size * t_size
+ label_sum = batch_size
+
+ self.assertTrue(
+ np.allclose(mixed_videos.sum().item(), gt_video_sum, rtol=0.001)
+ )
+ self.assertTrue(np.allclose(mixed_labels.sum().item(), label_sum, rtol=0.001))
+ self.assertEqual(mixed_labels.size(0), batch_size)
+ self.assertEqual(mixed_labels.size(1), num_classes)
+
+ def _check_boxes(self, num_boxes, height, width, boxes):
+ self.assertEqual(boxes.shape, (num_boxes, 4))
+ self.assertTrue(boxes[:, [0, 2]].min() >= 0 and boxes[:, [0, 2]].max() < width)
+ self.assertTrue(boxes[:, [1, 3]].min() >= 0 and boxes[:, [1, 3]].max() < height)
+
+ def test_randaug(self):
+ # Test default RandAugment.
+ t, c, h, w = 8, 3, 200, 200
+ test_time = 20
+ video_tensor = torch.rand(t, c, h, w)
+ video_rand_aug_fn = RandAugment()
+ for _ in range(test_time):
+ video_tensor_aug = video_rand_aug_fn(video_tensor)
+ self.assertTrue(video_tensor.size() == video_tensor_aug.size())
+ self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
+ # Make sure the video is in range.
+ self.assertTrue(video_tensor_aug.max().item() <= 1)
+ self.assertTrue(video_tensor_aug.min().item() >= 0)
+
+ # Test RandAugment with uniform sampling.
+ t, c, h, w = 8, 3, 200, 200
+ test_time = 20
+ video_tensor = torch.rand(t, c, h, w)
+ video_rand_aug_fn = RandAugment(sampling_type="uniform")
+ for _ in range(test_time):
+ video_tensor_aug = video_rand_aug_fn(video_tensor)
+ self.assertTrue(video_tensor.size() == video_tensor_aug.size())
+ self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
+ # Make sure the video is in range.
+ self.assertTrue(video_tensor_aug.max().item() <= 1)
+ self.assertTrue(video_tensor_aug.min().item() >= 0)
+
+ # Test if the default fill color is found.
+ # Test multiple times due to randomness.
+ t, c, h, w = 8, 3, 200, 200
+ test_time = 40
+ video_tensor = torch.ones(t, c, h, w)
+ video_rand_aug_fn = RandAugment(
+ num_layers=1,
+ prob=1,
+ sampling_type="gaussian",
+ )
+ found_fill_color = 0
+ for _ in range(test_time):
+ video_tensor_aug = video_rand_aug_fn(video_tensor)
+ if 0.5 in video_tensor_aug:
+ found_fill_color += 1
+ self.assertTrue(found_fill_color >= 1)
+
+ def test_random_resized_crop(self):
+ # Test default parameters.
+ crop_size = 10
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+
+ transform = RandomResizedCrop(
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=(0.08, 1.0),
+ aspect_ratio=(3.0 / 4.0, 4.0 / 3.0),
+ )
+
+ video_resized = transform(video)
+ c, t, h, w = video_resized.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 20)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ self.assertEqual(video_resized.dtype, torch.float32)
+
+ # Test reversed parameters.
+ crop_size = 29
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+
+ transform = RandomResizedCrop(
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=(1.8, 0.08),
+ aspect_ratio=(4.0 / 3.0, 3.0 / 4.0),
+ shift=True,
+ )
+
+ video_resized = transform(video)
+ c, t, h, w = video_resized.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 20)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ self.assertEqual(video_resized.dtype, torch.float32)
+
+ # Test one channel.
+ crop_size = 10
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+
+ transform = RandomResizedCrop(
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=(1.8, 1.2),
+ aspect_ratio=(4.0 / 3.0, 3.0 / 4.0),
+ )
+
+ video_resized = transform(video[0:1, :, :, :])
+ c, t, h, w = video_resized.shape
+ self.assertEqual(c, 1)
+ self.assertEqual(t, 20)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ self.assertEqual(video_resized.dtype, torch.float32)
+
+ # Test interpolation.
+ crop_size = 10
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+
+ transform = RandomResizedCrop(
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=(0.08, 1.0),
+ aspect_ratio=(3.0 / 4.0, 4.0 / 3.0),
+ interpolation="bicubic",
+ )
+
+ video_resized = transform(video)
+ c, t, h, w = video_resized.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 20)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ self.assertEqual(video_resized.dtype, torch.float32)
+
+ # Test log_uniform_ratio.
+ crop_size = 10
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+
+ transform = RandomResizedCrop(
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=(0.08, 1.0),
+ aspect_ratio=(3.0 / 4.0, 4.0 / 3.0),
+ log_uniform_ratio=False,
+ )
+
+ video_resized = transform(video)
+ c, t, h, w = video_resized.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 20)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ self.assertEqual(video_resized.dtype, torch.float32)
+
+ def test_augmix(self):
+ # Test default AugMix.
+ t, c, h, w = 8, 3, 200, 200
+ test_time = 20
+ video_tensor = torch.rand(t, c, h, w)
+ video_augmix_fn = AugMix()
+ for _ in range(test_time):
+ video_tensor_aug = video_augmix_fn(video_tensor)
+ self.assertTrue(video_tensor.size() == video_tensor_aug.size())
+ self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
+ # Make sure the video is in range.
+ self.assertTrue(video_tensor_aug.max().item() <= 1)
+ self.assertTrue(video_tensor_aug.min().item() >= 0)
+
+ # Test AugMix with non-default parameters.
+ t, c, h, w = 8, 3, 200, 200
+ test_time = 20
+ video_tensor = torch.rand(t, c, h, w)
+ video_augmix_fn = AugMix(magnitude=9, alpha=0.5, width=4, depth=3)
+ for _ in range(test_time):
+ video_tensor_aug = video_augmix_fn(video_tensor)
+ self.assertTrue(video_tensor.size() == video_tensor_aug.size())
+ self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
+ # Make sure the video is in range.
+ self.assertTrue(video_tensor_aug.max().item() <= 1)
+ self.assertTrue(video_tensor_aug.min().item() >= 0)
+
+ # Test AugMix with uint8 video.
+ t, c, h, w = 8, 3, 200, 200
+ test_time = 20
+ video_tensor = torch.randint(0, 255, (t, c, h, w)).type(torch.uint8)
+ video_augmix_fn = AugMix(transform_hparas={"fill": (128, 128, 128)})
+ for _ in range(test_time):
+ video_tensor_aug = video_augmix_fn(video_tensor)
+ self.assertTrue(video_tensor.size() == video_tensor_aug.size())
+ self.assertTrue(video_tensor.dtype == video_tensor_aug.dtype)
+ # Make sure the video is in range.
+ self.assertTrue(video_tensor_aug.max().item() <= 255)
+ self.assertTrue(video_tensor_aug.min().item() >= 0)
+
+ # Compare results of AugMix for uint8 and float.
+ t, c, h, w = 8, 3, 200, 200
+ test_time = 40
+ video_tensor_uint8 = torch.randint(0, 255, (t, c, h, w)).type(torch.uint8)
+ video_tensor_float = (video_tensor_uint8 / 255.0).type(torch.float32)
+ video_augmix_fn_uint8 = AugMix(
+ width=1, depth=1, transform_hparas={"fill": (128, 128, 128)}
+ )
+ video_augmix_fn_float = AugMix(width=1, depth=1)
+ for i in range(test_time):
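+ # Reset the RNG to the same state so the uint8 and float paths sample identical augmentations.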
+ torch.set_rng_state(torch.manual_seed(i).get_state())
+ video_tensor_uint8_aug = video_augmix_fn_uint8(video_tensor_uint8)
+ torch.set_rng_state(torch.manual_seed(i).get_state())
+ video_tensor_float_aug = video_augmix_fn_float(video_tensor_float)
+
+ self.assertTrue(
+ torch.mean(
+ torch.abs((video_tensor_uint8_aug / 255.0) - video_tensor_float_aug)
+ )
+ < 0.01
+ )
+
+ self.assertTrue(video_tensor_uint8.size() == video_tensor_uint8_aug.size())
+ self.assertTrue(video_tensor_uint8.dtype == video_tensor_uint8_aug.dtype)
+ self.assertTrue(video_tensor_float.size() == video_tensor_float_aug.size())
+ self.assertTrue(video_tensor_float.dtype == video_tensor_float_aug.dtype)
+ # Make sure the video is in range.
+ self.assertTrue(video_tensor_uint8_aug.max().item() <= 255)
+ self.assertTrue(video_tensor_uint8_aug.min().item() >= 0)
+ self.assertTrue(video_tensor_float_aug.max().item() <= 255)
+ self.assertTrue(video_tensor_float_aug.min().item() >= 0)
+
+ # Test asserts.
+ self.assertRaises(AssertionError, AugMix, magnitude=11)
+ self.assertRaises(AssertionError, AugMix, magnitude=1.1)
+ self.assertRaises(AssertionError, AugMix, alpha=-0.3)
+ self.assertRaises(AssertionError, AugMix, width=0)
+
+ def test_permute(self):
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+
+ for p in list(permutations(range(0, 4))):
+ self.assertTrue(video.permute(*p).equal(Permute(p)(video)))
+
+ def test_video_transform_factory(self):
+ # Test asserts/raises.
+ self.assertRaises(TypeError, create_video_transform, mode="val", crop_size="s")
+ self.assertRaises(
+ AssertionError,
+ create_video_transform,
+ mode="val",
+ crop_size=30,
+ min_size=10,
+ )
+ self.assertRaises(
+ AssertionError,
+ create_video_transform,
+ mode="val",
+ crop_size=(30, 40),
+ min_size=35,
+ )
+ self.assertRaises(
+ AssertionError, create_video_transform, mode="val", remove_key="key"
+ )
+ self.assertRaises(
+ AssertionError,
+ create_video_transform,
+ mode="val",
+ aug_paras={"magnitude": 10},
+ )
+ self.assertRaises(
+ NotImplementedError, create_video_transform, mode="train", aug_type="xyz"
+ )
+
+ # Test train mode.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ test_clip = {"video": video, "audio1": None, "audio2": None, "label": 0}
+
+ num_subsample = 10
+ crop_size = 10
+ transform = create_video_transform(
+ mode="train",
+ num_samples=num_subsample,
+ convert_to_float=False,
+ video_mean=[video.mean()] * 3,
+ video_std=[video.std()] * 3,
+ min_size=15,
+ crop_size=crop_size,
+ )
+ transform_dict = create_video_transform(
+ mode="train",
+ video_key="video",
+ remove_key=["audio1", "audio2"],
+ num_samples=num_subsample,
+ convert_to_float=False,
+ video_mean=[video.mean()] * 3,
+ video_std=[video.std()] * 3,
+ min_size=15,
+ crop_size=crop_size,
+ )
+ transform_frame = create_video_transform(
+ mode="train",
+ num_samples=None,
+ convert_to_float=False,
+ video_mean=[video.mean()] * 3,
+ video_std=[video.std()] * 3,
+ min_size=15,
+ crop_size=crop_size,
+ )
+
+ video_tensor_transformed = transform(video)
+ video_dict_transformed = transform_dict(test_clip)
+ video_frame_transformed = transform_frame(video[:, 0:1, :, :])
+ c, t, h, w = video_tensor_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ c, t, h, w = video_dict_transformed["video"].shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ self.assertFalse("audio1" in video_dict_transformed)
+ self.assertFalse("audio2" in video_dict_transformed)
+ c, t, h, w = video_frame_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 1)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+
+ # Test val mode.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40)).to(
+ dtype=torch.float32
+ )
+ test_clip = {"video": video, "audio": None, "label": 0}
+ test_clip2 = {"video": video, "audio": None, "label": 0}
+
+ num_subsample = 10
+ transform = create_video_transform(
+ mode="val",
+ num_samples=num_subsample,
+ convert_to_float=False,
+ video_mean=[video.mean()] * 3,
+ video_std=[video.std()] * 3,
+ min_size=15,
+ crop_size=crop_size,
+ )
+ transform_dict = create_video_transform(
+ mode="val",
+ video_key="video",
+ num_samples=num_subsample,
+ convert_to_float=False,
+ video_mean=[video.mean()] * 3,
+ video_std=[video.std()] * 3,
+ min_size=15,
+ crop_size=crop_size,
+ )
+ transform_comp = Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ [
+ UniformTemporalSubsample(num_subsample),
+ NormalizeVideo([video.mean()] * 3, [video.std()] * 3),
+ ShortSideScale(size=15),
+ CenterCropVideo(crop_size),
+ ]
+ ),
+ )
+ ]
+ )
+ transform_frame = create_video_transform(
+ mode="val",
+ num_samples=None,
+ convert_to_float=False,
+ video_mean=[video.mean()] * 3,
+ video_std=[video.std()] * 3,
+ min_size=15,
+ crop_size=crop_size,
+ )
+
+ video_tensor_transformed = transform(video)
+ video_dict_transformed = transform_dict(test_clip)
+ video_comp_transformed = transform_comp(test_clip2)
+ video_frame_transformed = transform_frame(video[:, 0:1, :, :])
+ self.assertTrue(video_tensor_transformed.equal(video_dict_transformed["video"]))
+ self.assertTrue(
+ video_dict_transformed["video"].equal(video_comp_transformed["video"])
+ )
+ torch.testing.assert_close(
+ video_frame_transformed, video_tensor_transformed[:, 0:1, :, :]
+ )
+ c, t, h, w = video_dict_transformed["video"].shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ self.assertTrue("audio" in video_dict_transformed)
+ c, t, h, w = video_frame_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, 1)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+
+ # Test uint8 video.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40))
+ test_clip = {"video": video, "audio": None, "label": 0}
+
+ transform_uint8 = create_video_transform(
+ mode="val",
+ num_samples=num_subsample,
+ convert_to_float=True,
+ min_size=15,
+ crop_size=crop_size,
+ )
+ transform_float32 = create_video_transform(
+ mode="val",
+ num_samples=num_subsample,
+ convert_to_float=False,
+ min_size=15,
+ crop_size=crop_size,
+ )
+
+ video_uint8_transformed = transform_uint8(video)
+ video_float32_transformed = transform_float32(
+ video.to(dtype=torch.float32) / 255.0
+ )
+ self.assertRaises(
+ AssertionError, transform_uint8, video.to(dtype=torch.float32)
+ )
+ self.assertTrue(video_uint8_transformed.equal(video_float32_transformed))
+ c, t, h, w = video_uint8_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ c, t, h, w = video_float32_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+
+ # Test augmentations.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40))
+
+ transform_randaug = create_video_transform(
+ mode="train",
+ num_samples=num_subsample,
+ min_size=15,
+ crop_size=crop_size,
+ aug_type="randaug",
+ )
+ transform_augmix = create_video_transform(
+ mode="train",
+ num_samples=num_subsample,
+ min_size=15,
+ crop_size=crop_size,
+ aug_type="augmix",
+ )
+ transform_randaug_paras = create_video_transform(
+ mode="train",
+ num_samples=num_subsample,
+ min_size=15,
+ crop_size=crop_size,
+ aug_type="randaug",
+ aug_paras={
+ "magnitude": 8,
+ "num_layers": 3,
+ "prob": 0.7,
+ "sampling_type": "uniform",
+ },
+ )
+ transform_augmix_paras = create_video_transform(
+ mode="train",
+ num_samples=num_subsample,
+ min_size=15,
+ crop_size=crop_size,
+ aug_type="augmix",
+ aug_paras={"magnitude": 5, "alpha": 0.5, "width": 2, "depth": 3},
+ )
+
+ video_randaug_transformed = transform_randaug(video)
+ video_augmix_transformed = transform_augmix(video)
+ video_randaug_paras_transformed = transform_randaug_paras(video)
+ video_augmix_paras_transformed = transform_augmix_paras(video)
+ c, t, h, w = video_randaug_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ c, t, h, w = video_augmix_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ c, t, h, w = video_randaug_paras_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+ c, t, h, w = video_augmix_paras_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+
+ # Test Inception-style cropping.
+ video = thwc_to_cthw(create_dummy_video_frames(20, 30, 40))
+
+ transform_inception = create_video_transform(
+ mode="train",
+ num_samples=num_subsample,
+ min_size=15,
+ crop_size=crop_size,
+ random_resized_crop_paras={},
+ )
+
+ video_inception_transformed = transform_inception(video)
+ c, t, h, w = video_inception_transformed.shape
+ self.assertEqual(c, 3)
+ self.assertEqual(t, num_subsample)
+ self.assertEqual(h, crop_size)
+ self.assertEqual(w, crop_size)
+
+ def test_div_255(self):
+ t, c, h, w = 8, 3, 200, 200
+ video_tensor = torch.rand(t, c, h, w)
+ output_tensor = div_255(video_tensor)
+ expect_tensor = video_tensor / 255
+
+ self.assertEqual(output_tensor.shape, video_tensor.shape)
+ self.assertTrue(bool(torch.all(torch.eq(output_tensor, expect_tensor))))
diff --git a/code/pytorchvideo/tests/test_uniform_clip_sampler.py b/code/pytorchvideo/tests/test_uniform_clip_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b02aa73ad4e4fcf675a92be3fa45e38fa9475dd
--- /dev/null
+++ b/code/pytorchvideo/tests/test_uniform_clip_sampler.py
@@ -0,0 +1,222 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import copy
+import unittest
+from typing import Optional
+
+import numpy as np
+from parameterized import parameterized
+from pytorchvideo.data.clip_sampling import UniformClipSampler
+
+
+def _num_clips(
+ duration_sec: float,
+ fps: float,
+ stride_frames: int,
+ window_size_frames: int,
+ backpad_last: bool = True,
+) -> int:
+ """
+ Utility to calculate the number of clips for a given duration, fps, stride & window_size
+ """
+ num_frames = round(duration_sec * fps)
+ N = num_frames - window_size_frames
+ if N < 0:
+ return 1
+ result = int(N / stride_frames + 1)
+ pad = backpad_last and N % stride_frames != 0
+ return result + pad
+
+
+class TestUniformClipSampler(unittest.TestCase):
+ @parameterized.expand(
+ [
+ (False, 30, 32, 16, 32 / 30, 1),
+ (True, 30, 32, 16, 32 / 30, 1),
+ (True, 30, 32, 16, 33 / 30, 2),
+ (False, 30, 32, 16, 34 / 30, 1),
+ (True, 30, 32, 16, 34 / 30, 2),
+ (False, 30, 32, 16, 47 / 30, 1),
+ (True, 30, 32, 16, 47 / 30, 2),
+ (False, 30, 32, 16, 48 / 30, 2),
+ (True, 30, 32, 16, 48 / 30, 2),
+ (False, 30, 32, 16, 72 / 30, 3),
+ (True, 30, 32, 16, 72 / 30, 4),
+ (False, 30, 32, 16, 109 / 30, 5),
+ (True, 30, 32, 16, 109 / 30, 6),
+ (False, 30, 32, 3, 35 / 30, 2),
+ (True, 30, 32, 3, 35 / 30, 2),
+ (False, 30, 32, 3, 36 / 30, 2),
+ (True, 30, 32, 3, 36 / 30, 3),
+ (True, 30, 32, 3, 35 / 30, 2),
+ (False, 30, 32, 3, 36 / 30, 2),
+ (True, 30, 32, 3, 36 / 30, 3),
+ # no stride => window size
+ (False, 30, 32, 32, 32 / 30, 1),
+ (True, 30, 32, 32, 32 / 30, 1),
+ (False, 30, 32, 32, 54 / 30, 1),
+ (True, 30, 32, 32, 54 / 30, 2),
+ (False, 30, 32, 32, 64 / 30, 2),
+ (True, 30, 32, 32, 64 / 30, 2),
+ # test None for stride
+ (False, 30, 32, None, 64 / 30, 2),
+ (True, 30, 32, None, 64 / 30, 2),
+ # stride = {1, 2}
+ (False, 30, 2, 1, 32 / 30, 31),
+ (True, 30, 2, 1, 32 / 30, 31),
+ # > half stride
+ (False, 30, 32, 24, 107 / 30, 4),
+ (True, 30, 32, 24, 107 / 30, 5),
+ (False, 30, 5, 1, 11 / 30, 7),
+ (True, 30, 5, 1, 11 / 30, 7),
+ # stride > window size
+ (False, 30, 1, 5, 11 / 30, 3),
+ (True, 30, 1, 5, 11 / 30, 3),
+ (True, 30, 1, 5, 1759 / 30, 353),
+ (False, 30, 3, 10, 132 / 30, 13),
+ (True, 30, 3, 10, 132 / 30, 14),
+ (False, 30, 6, 10, 111 / 30, 11),
+ (True, 30, 6, 10, 111 / 30, 12),
+ # stride <= window size
+ (False, 30, 10, 3, 132 / 30, 41),
+ (True, 30, 10, 3, 132 / 30, 42),
+ (False, 30, 10, 6, 111 / 30, 17),
+ (True, 30, 10, 6, 111 / 30, 18),
+ (True, 30, 1, 1, 132 / 30, 132),
+ ]
+ )
+ def test_uniform_clip_sampler(
+ self,
+ backpad_last: bool,
+ fps: int,
+ window_size: int,
+ stride_frames: Optional[int],
+ video_length: float,
+ expected_number_of_clips: int,
+ ):
+ """
+ Utility to test the uniform clip sampler
+ """
+ sampler = UniformClipSampler(
+ window_size / fps,
+ stride_frames / fps if stride_frames is not None else None,
+ backpad_last=backpad_last,
+ )
+ predicted_n_clips = _num_clips(
+ video_length,
+ fps,
+ stride_frames=stride_frames if stride_frames is not None else window_size,
+ window_size_frames=window_size,
+ backpad_last=backpad_last,
+ )
+ self.assertEqual(predicted_n_clips, expected_number_of_clips)
+
+ s_prime = stride_frames if stride_frames is not None else window_size
+ expected_start_end_times = [
+ ((i * s_prime) / fps, ((i * s_prime + window_size) / fps))
+ for i in range(expected_number_of_clips)
+ ]
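+ # If the final expected window would run past the end of the video, shift it back so it
+ # ends exactly at video_length (this is the back-padded last clip).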
+ if expected_start_end_times[-1][1] - video_length > 1e-6:
+ expected_start_end_times[-1] = (
+ video_length - window_size / fps,
+ video_length,
+ )
+
+ self.assertTrue(
+ (
+ expected_start_end_times[-1][0] + (s_prime / fps) > video_length
+ or expected_start_end_times[-1][-1] + (s_prime / fps) > video_length
+ )
+ )
+ if len(expected_start_end_times) >= 2:
+ self.assertNotAlmostEqual(
+ expected_start_end_times[-2][0], expected_start_end_times[-1][0]
+ )
+ self.assertNotAlmostEqual(
+ expected_start_end_times[-2][1], expected_start_end_times[-1][1]
+ )
+
+ start_end_times = []
+
+ last_clip_time = None
+ annotation = {}
+ while True:
+ clip = sampler(last_clip_time, video_length, annotation)
+ last_clip_time = copy.deepcopy(clip.clip_end_sec)
+ n_frames = (clip.clip_end_sec - clip.clip_start_sec) * fps
+ int_n_frames = int(np.round(float(n_frames)))
+ self.assertAlmostEqual(float(int_n_frames), float(n_frames))
+ self.assertEqual(int_n_frames, window_size)
+
+ start_end_times.append(
+ (float(clip.clip_start_sec), float(clip.clip_end_sec))
+ )
+ if clip.is_last_clip:
+ break
+
+ # just in case we get an infinite loop
+ if len(start_end_times) > 2 * expected_number_of_clips:
+ break
+
+ self.assertEqual(len(start_end_times), expected_number_of_clips)
+ for (start, end), (expected_start, expected_end) in zip(
+ start_end_times, expected_start_end_times
+ ):
+ self.assertAlmostEqual(float(start), expected_start)
+ self.assertAlmostEqual(float(end), expected_end)
+
+ @parameterized.expand(
+ [
+ (60 / 30, 30, 16, 32, True, 3),
+ (60 / 30, 30, 16, 32, True, 3),
+ (5 / 30, 30, 2, 3, True, 2),
+ (39 / 30, 30, 3, 32, True, 4),
+ (9 / 30, 30, 2, 2, True, 5),
+ (10 / 30, 30, 2, 2, True, 5),
+ (39 / 30, 30, 16, 32, True, 2),
+ (39 / 30, 30, 31, 32, True, 2),
+ (203 / 30, 30, 2, 32, True, 87),
+ (203 / 30, 30, 3, 32, True, 58),
+ (203 / 30, 30, 31, 32, True, 7),
+ (60 / 30, 30, 16, 32, False, 2),
+ (60 / 30, 30, 16, 32, False, 2),
+ (5 / 30, 30, 2, 3, False, 2),
+ (39 / 30, 30, 3, 32, False, 3),
+ (9 / 30, 30, 2, 2, False, 4),
+ (10 / 30, 30, 2, 2, False, 5),
+ (39 / 30, 30, 16, 32, False, 1),
+ (39 / 30, 30, 31, 32, False, 1),
+ (203 / 30, 30, 2, 32, False, 86),
+ (203 / 30, 30, 3, 32, False, 58),
+ (203 / 30, 30, 31, 32, False, 6),
+ (203 / 30, 30, 1, 32, False, 203 - 32 + 1),
+ (19 / 30, 30, 1, 32, False, 1),
+ (19 / 30, 30, 1, 32, True, 1),
+ (33 / 30, 30, 1, 32, False, 2),
+ (33 / 30, 30, 1, 32, True, 2),
+ (11 / 30, 30, 1, 5, False, 7),
+ (11 / 30, 30, 1, 5, True, 7),
+ (11 / 30, 30, 5, 1, False, 3),
+ (11 / 30, 30, 5, 1, True, 3),
+ (1759 / 30, 30, 5, 1, True, 353),
+ ]
+ )
+ def test_num_clips(
+ self,
+ duration_sec: float,
+ fps: int,
+ stride_frames: int,
+ window_size_frames: int,
+ backpad_last: bool,
+ expected: int,
+ ):
+ self.assertEqual(
+ _num_clips(
+ duration_sec, fps, stride_frames, window_size_frames, backpad_last
+ ),
+ expected,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/code/pytorchvideo/tests/utils.py b/code/pytorchvideo/tests/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..707ed75ed6ea1c61b29425579e03e02ada991727
--- /dev/null
+++ b/code/pytorchvideo/tests/utils.py
@@ -0,0 +1,278 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import contextlib
+import os
+import pathlib
+import tempfile
+
+import av
+import numpy as np
+import torch
+import torchvision.io as io
+import torchvision.transforms as transforms
+from pytorchvideo.data.dataset_manifest_utils import (
+ EncodedVideoInfo,
+ VideoFrameInfo,
+ VideoInfo,
+)
+from pytorchvideo.data.utils import thwc_to_cthw
+
+
+def create_dummy_video_frames(num_frames: int, height: int, width: int):
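+ # Each frame is a Gaussian blob whose center drifts over time, replicated to 3 channels (uint8, THWC).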
+ y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
+ data = []
+ for i in range(num_frames):
+ xc = float(i) / num_frames
+ yc = 1 - float(i) / (2 * num_frames)
+ d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
+ data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
+ return torch.stack(data, 0)
+
+
+def create_random_bbox(num_boxes: int, height: int, width: int):
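+ # Coordinates are (x1, y1, x2, y2): x1/y1 are drawn from the left/top half and x2/y2 from the
+ # right/bottom half, so every generated box stays inside the (height, width) frame.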
+ bboxes = torch.rand(num_boxes, 4)
+ bboxes[:, 0] *= float(width) / 2.0
+ bboxes[:, 2] = bboxes[:, 2] * float(width) / 2.0 + float(width) / 2.0
+ bboxes[:, 1] *= float(height) / 2.0
+ bboxes[:, 3] = bboxes[:, 3] * float(height) / 2.0 + float(height) / 2.0
+ return torch.floor(bboxes)
+
+
+@contextlib.contextmanager
+def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=None):
+ """
+ Creates a temporary lossless, mp4 video with synthetic content. Uses a context which
+ deletes the video after exit.
+ """
+ # Lossless options.
+ video_codec = "libx264rgb"
+ options = {"crf": "0"}
+ data = create_dummy_video_frames(num_frames, height, width)
+ with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".mp4") as f:
+ f.close()
+ io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
+ yield f.name, thwc_to_cthw(data).to(torch.float32)
+ os.unlink(f.name)
+
+
+@contextlib.contextmanager
+def temp_encoded_video_with_audio(
+ num_frames: int,
+ fps: int,
+ num_audio_samples: int,
+ audio_rate: int = 48000,
+ height=10,
+ width=10,
+ prefix=None,
+):
+ audio_data = torch.from_numpy(np.random.rand(1, num_audio_samples).astype("` (for simple layers) and `pytorchvideo/models/accelerator/` (for complex modules such as residual blocks). Inference of a model built with the corresponding efficient blocks on the target device is guaranteed to be efficient.\n",
+ "\n",
+ "Each efficient block module is an instance of nn.Module, and has two forms: **original form** (for training) and **deploy form** (for inference). When in original form, the efficient block module has exactly the same behavior as a corresponding vanilla nn.Module for both forward and backward operation. User can freely mix and match efficient blocks for the same target device and build up their own model. Once model is built and trained, user can convert each efficient block in model into deploy form. The conversion will do graph and kernel optimization on each efficient block, and efficient block in deploy form is arithmetically equivalent to original form but has much higher efficiency during inference. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zrK_kLiClgMB"
+ },
+ "source": [
+ "## Design, train and deploy a model composed of efficient blocks for mobile CPU\n",
+ "### Build a model\n",
+ "In this section, let's go through the process of design, train and deploy using a example toy model using efficient blocks under `pytorchvideo/layers/accelerator/mobile_cpu` and `pytorchvideo/models/accelerator/mobile_cpu`, which includes:\n",
+ "- One conv3d head layer with 5x1x1 kernel followed by ReLU activation;\n",
+ "- One residual block with squeeze-excite;\n",
+ "- One average pool and fully connected layer as final output.\n",
+ "\n",
+ "First, let's import efficient blocks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0jg4cZI5lgMC"
+ },
+ "outputs": [],
+ "source": [
+ "# Imports\n",
+ "import torch.nn as nn\n",
+ "from pytorchvideo.layers.accelerator.mobile_cpu.activation_functions import (\n",
+ " supported_act_functions,\n",
+ ")\n",
+ "from pytorchvideo.layers.accelerator.mobile_cpu.convolutions import (\n",
+ " Conv3d5x1x1BnAct,\n",
+ ")\n",
+ "from pytorchvideo.models.accelerator.mobile_cpu.residual_blocks import (\n",
+ " X3dBottleneckBlock,\n",
+ ")\n",
+ "from pytorchvideo.layers.accelerator.mobile_cpu.pool import AdaptiveAvgPool3dOutSize1\n",
+ "from pytorchvideo.layers.accelerator.mobile_cpu.fully_connected import FullyConnected\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MxKCY8TzlgMC"
+ },
+ "source": [
+ "Then we can build a model using those efficient blocks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FYNnTanxlgMD"
+ },
+ "outputs": [],
+ "source": [
+ "class MyNet(nn.Module):\n",
+ " def __init__(\n",
+ " self,\n",
+ " in_channel=3, # input channel of first 5x1x1 layer\n",
+ " residual_block_channel=24, # input channel of residual block\n",
+ " expansion_ratio=3, # expansion ratio of residual block\n",
+ " num_classes=4, # final output classes\n",
+ " ):\n",
+ " super().__init__()\n",
+ " # s1 - 5x1x1 conv3d layer\n",
+ " self.s1 = Conv3d5x1x1BnAct(\n",
+ " in_channel,\n",
+ " residual_block_channel,\n",
+ " bias=False,\n",
+ " groups=1,\n",
+ " use_bn=False,\n",
+ " )\n",
+ " # s2 - residual block\n",
+ " mid_channel = int(residual_block_channel * expansion_ratio)\n",
+ " self.s2 = X3dBottleneckBlock(\n",
+ " in_channels=residual_block_channel,\n",
+ " mid_channels=mid_channel,\n",
+ " out_channels=residual_block_channel,\n",
+ " use_residual=True,\n",
+ " spatial_stride=1,\n",
+ " se_ratio=0.0625,\n",
+ " act_functions=(\"relu\", \"swish\", \"relu\"),\n",
+ " use_bn=(True, True, True),\n",
+ " )\n",
+ " # Average pool and fully connected layer\n",
+ " self.avg_pool = AdaptiveAvgPool3dOutSize1()\n",
+ " self.projection = FullyConnected(residual_block_channel, num_classes, bias=True)\n",
+ " self.act = supported_act_functions['relu']()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.s1(x)\n",
+ " x = self.s2(x)\n",
+ " x = self.avg_pool(x)\n",
+ " # (N, C, T, H, W) -> (N, T, H, W, C).\n",
+ " x = x.permute((0, 2, 3, 4, 1))\n",
+ " x = self.projection(x)\n",
+ " # Performs fully convlutional inference.\n",
+ " if not self.training:\n",
+ " x = self.act(x)\n",
+ " x = x.mean([1, 2, 3])\n",
+ " x = x.view(x.shape[0], -1)\n",
+ "\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fB-_UEHilgMD"
+ },
+ "source": [
+ "We can instantiate MyNet and its efficient blocks will be in original form."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FvXjdqT1lgMD"
+ },
+ "outputs": [],
+ "source": [
+ "net_inst = MyNet()\n",
+ "print(net_inst)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-O6jd3umlgMF"
+ },
+ "source": [
+ "### Train model\n",
+ "Then we can train the model with your dataset/optimizer. Here we skip this training step, and just leave the weight as initial value."
+ ]
+ },
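+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal, illustrative training loop. It is intentionally commented out (we keep the\n",
+ "# initial weights) and assumes a hypothetical `train_loader` yielding (clips, labels)\n",
+ "# batches with clips shaped (N, 3, T, H, W); replace it with your own data pipeline.\n",
+ "#\n",
+ "# import torch\n",
+ "# import torch.nn.functional as F\n",
+ "# optimizer = torch.optim.SGD(net_inst.parameters(), lr=0.01, momentum=0.9)\n",
+ "# net_inst.train()\n",
+ "# for clips, labels in train_loader:\n",
+ "#     optimizer.zero_grad()\n",
+ "#     # In training mode MyNet returns (N, 1, 1, 1, num_classes); flatten it for the loss.\n",
+ "#     logits = net_inst(clips).view(clips.size(0), -1)\n",
+ "#     loss = F.cross_entropy(logits, labels)\n",
+ "#     loss.backward()\n",
+ "#     optimizer.step()"
+ ]
+ },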
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RdNV2EeMlgMF"
+ },
+ "source": [
+ "### Deploy model\n",
+ "Now the model is ready to deploy. First of all, let's convert the model into deploy form. In order to do that, we need to use `convert_to_deployable_form` utility and provide an example input tensor to the model. Note that once the model is converted into deploy form, the input size should be the same as the example input tensor size during conversion."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hA5ER4bLlgMF"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (\n",
+ " convert_to_deployable_form,\n",
+ ")\n",
+ "input_blob_size = (1, 3, 4, 6, 6)\n",
+ "input_tensor = torch.randn(input_blob_size)\n",
+ "net_inst_deploy = convert_to_deployable_form(net_inst, input_tensor)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6FC6knxWlgMG"
+ },
+ "source": [
+ "We can see that the network graph has been changed after conversion, which did kernel and graph optimization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WKXr2Pi1lgMG"
+ },
+ "outputs": [],
+ "source": [
+ "print(net_inst_deploy)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BlA-TZivlgMG"
+ },
+ "source": [
+ "Let's check whether the network after conversion is arithmetically equivalent. We expect the output to be very close before/after conversion, with some small difference due to numeric noise from floating point operation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "I8lsM5oulgMG"
+ },
+ "outputs": [],
+ "source": [
+ "net_inst.eval()\n",
+ "out_ref = net_inst(input_tensor)\n",
+ "out = net_inst_deploy(input_tensor)\n",
+ "\n",
+ "max_err = float(torch.max(torch.abs(out_ref - out)))\n",
+ "print(f\"max error is {max_err}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Yq4c5HeYlgMH"
+ },
+ "source": [
+ "Next we have two options: either deploy floating point model, or quantize model into int8 and then deploy.\n",
+ "\n",
+ "Let's first assume we want to deploy floating point model. In this case, all we need to do is to export jit trace and then apply `optimize_for_mobile` for final optimization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZPX9InColgMH"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.mobile_optimizer import (\n",
+ " optimize_for_mobile,\n",
+ ")\n",
+ "traced_model = torch.jit.trace(net_inst_deploy, input_tensor, strict=False)\n",
+ "traced_model_opt = optimize_for_mobile(traced_model)\n",
+ "# Here we can save the traced_model_opt to JIT file using traced_model_opt.save()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6jFmLo-algMI"
+ },
+ "source": [
+ "Alternatively, we may also want to deploy a quantized model. Efficient blocks are quantization-friendly by design - just wrap the model in deploy form with `QuantStub/DeQuantStub` and it is ready for Pytorch eager mode quantization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "syb-6y2glgMI"
+ },
+ "outputs": [],
+ "source": [
+ "# Wrapper class for adding QuantStub/DeQuantStub.\n",
+ "class quant_stub_wrapper(nn.Module):\n",
+ " def __init__(self, module_in):\n",
+ " super().__init__()\n",
+ " self.quant = torch.quantization.QuantStub()\n",
+ " self.model = module_in\n",
+ " self.dequant = torch.quantization.DeQuantStub()\n",
+ " def forward(self, x):\n",
+ " x = self.quant(x)\n",
+ " x = self.model(x)\n",
+ " x = self.dequant(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yUrtbvo_lgMI"
+ },
+ "outputs": [],
+ "source": [
+ "net_inst_quant_stub_wrapper = quant_stub_wrapper(net_inst_deploy)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qEX2FdcIlgMI"
+ },
+ "source": [
+ "Preparation step of quantization. Fusion has been done for efficient blocks automatically during `convert_to_deployable_form`, so we can just proceed to `torch.quantization.prepare`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "r6DfTh1ElgMI"
+ },
+ "outputs": [],
+ "source": [
+ "net_inst_quant_stub_wrapper.qconfig = torch.quantization.default_qconfig\n",
+ "net_inst_quant_stub_wrapper_prepared = torch.quantization.prepare(net_inst_quant_stub_wrapper)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "q-SkDlVflgMJ"
+ },
+ "source": [
+ "Calibration and quantization. After preparation we will do calibration of quantization by feeding calibration dataset (skipped here) and then do quantization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "E9Zh45yalgMJ"
+ },
+ "outputs": [],
+ "source": [
+ "# calibration is skipped here.\n",
+ "net_inst_quant_stub_wrapper_quantized = torch.quantization.convert(net_inst_quant_stub_wrapper_prepared)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "n1j11-5KlgMJ"
+ },
+ "outputs": [],
+ "source": [
+ "print(net_inst_quant_stub_wrapper_quantized)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7KjWPclrlgMJ"
+ },
+ "source": [
+ "Then we can export trace of int8 model and deploy on mobile devices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "D5YXI4kvlgMK"
+ },
+ "outputs": [],
+ "source": [
+ "traced_model_int8 = torch.jit.trace(net_inst_quant_stub_wrapper_quantized, input_tensor, strict=False)\n",
+ "traced_model_int8_opt = optimize_for_mobile(traced_model_int8)\n",
+ "# Here we can save the traced_model_opt to JIT file using traced_model_int8_opt.save()"
+ ]
+ }
+ ],
+ "metadata": {
+ "bento_stylesheets": {
+ "bento/extensions/flow/main.css": true,
+ "bento/extensions/kernel_selector/main.css": true,
+ "bento/extensions/kernel_ui/main.css": true,
+ "bento/extensions/new_kernel/main.css": true,
+ "bento/extensions/system_usage/main.css": true,
+ "bento/extensions/theme/main.css": true
+ },
+ "colab": {
+ "collapsed_sections": [],
+ "name": "Build your model with PytorchVideo Accelerator.ipynb",
+ "provenance": []
+ },
+ "disseminate_notebook_id": {
+ "notebook_id": "709466976415887"
+ },
+ "disseminate_notebook_info": {
+ "bento_version": "20210314-210430",
+ "description": "PTV tutorial",
+ "hide_code": false,
+ "hipster_group": "",
+ "kernel_build_info": {
+ "error": ""
+ },
+ "no_uii": true,
+ "notebook_number": "512478",
+ "others_can_edit": false,
+ "reviewers": "",
+ "revision_id": "482523946213747",
+ "tags": "",
+ "tasks": "",
+ "title": "Build your model with PytorchVideo Accelerator"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/code/pytorchvideo/tutorials/accelerator/Use_Model_Transmuter.ipynb b/code/pytorchvideo/tutorials/accelerator/Use_Model_Transmuter.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..bdb21095bbb9c32f1f0962961fe7e74f8ba36457
--- /dev/null
+++ b/code/pytorchvideo/tutorials/accelerator/Use_Model_Transmuter.ipynb
@@ -0,0 +1,279 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yOVmvnwW6ism"
+ },
+ "source": [
+ "## Introduction\n",
+ "Got your own model, but still want to fully leverage efficient blocks in PytorchVideo/Accelerator? No problem, model transmuter can help you.\n",
+ "Model transmuter is a utility in PytorchVideo/Accelerator that takes user defined model, and replace modules in user model with equivalent efficient block when possible.\n",
+ "In this tutorial, we will go through typical steps of using model transmuter, including:\n",
+ "- Use model transmuter to replace modules in user model with efficient blocks\n",
+ "- Convert model into deploy form and deploy\n",
+ "\n",
+ "Before we start, let's install PytorchVideo."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2_v3ehr3Bt1T"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install pytorchvideo"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1-RsOLo46iss"
+ },
+ "source": [
+ "## Use model transmuter to replace modules in user model with efficient blocks\n",
+ "First, let's assume user has following model to be transmuted:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ST7sgFdM6ist"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class user_model_residual_block(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.stem0 = nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))\n",
+ " self.stem1 = nn.Conv3d(3, 3, kernel_size=(5, 1, 1), padding=(2, 0, 0))\n",
+ " self.pw = nn.Conv3d(3, 6, kernel_size=1)\n",
+ " self.relu = nn.ReLU()\n",
+ " self.dw = nn.Conv3d(6, 6, kernel_size=3, padding=1, groups=6)\n",
+ " self.relu1 = nn.ReLU()\n",
+ " self.pwl = nn.Conv3d(6, 3, kernel_size=1)\n",
+ " self.relu2 = nn.ReLU()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " out = self.stem0(x)\n",
+ " out = self.stem1(out)\n",
+ " out = self.pw(out)\n",
+ " out = self.relu(out)\n",
+ " out = self.dw(out)\n",
+ " out = self.relu1(out)\n",
+ " out = self.pwl(out)\n",
+ " return self.relu2(out + x)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f6vbMoE46ist"
+ },
+ "source": [
+ "Then, let's use model transmuter by importing transmuter for targeting device. In this tutorial, we are using mobile cpu as example. Therefore we will import (1) model transmuter for mobile cpu and (2) top-level wrapper of model transmuter."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zi8KsCSh6isu"
+ },
+ "outputs": [],
+ "source": [
+ "import pytorchvideo.accelerator.deployment.mobile_cpu.transmuter # mobile cpu model transmuter\n",
+ "from pytorchvideo.accelerator.deployment.common.model_transmuter import transmute_model # top-level wrapper of model transmuter"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t4meNp416isu"
+ },
+ "source": [
+ "We instantiate one user_model_residual_block, and transmute it by calling `transmute_model` with argument of `target_device=\"mobile_cpu\"`. We can see that the some of modules in model has been replaced by printing it again. In general, model transmuter will replace one submodule if its equivalent efficient block is found, otherwise that submodule will be kept intact."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "N-YzZp_d6isu"
+ },
+ "outputs": [],
+ "source": [
+ "model_transmute = user_model_residual_block()\n",
+ "print(\"original model\")\n",
+ "print(model_transmute)\n",
+ "transmute_model(\n",
+ " model_transmute,\n",
+ " target_device=\"mobile_cpu\",\n",
+ ")\n",
+ "print(\"after transmute\")\n",
+ "print(model_transmute)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "74G3zWYF6isv"
+ },
+ "source": [
+ "## Convert model into deploy form and deploy\n",
+ "Now the model is ready to deploy. First of all, let's convert the model into deploy form. In order to do that, we need to use `convert_to_deployable_form` utility and provide an example input tensor to the model. `convert_to_deployable_form` will convert any instance of `EfficientBlockBase` (base class for efficient blocks in PytorchVideo/Accelerator) into deploy form, while leave other modules unchanged.\n",
+ "Note that once the model is converted into deploy form, the input size should be the same as the example input tensor size during conversion."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NCeIb59m6isw"
+ },
+ "outputs": [],
+ "source": [
+ "# Define example input tensor\n",
+ "input_blob_size = (1, 3, 4, 6, 6)\n",
+ "input_tensor = torch.randn(input_blob_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3y3GBWdF6isw"
+ },
+ "outputs": [],
+ "source": [
+ "from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (\n",
+ " convert_to_deployable_form,\n",
+ ")\n",
+ "model_transmute_deploy = convert_to_deployable_form(\n",
+ " model_transmute, input_tensor\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HLt0515O6isw"
+ },
+ "source": [
+ "We can observe further kernel graph change after conversion into deploy mode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7cd1NCew6isw"
+ },
+ "outputs": [],
+ "source": [
+ "print(model_transmute_deploy)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jCRJquGw6isx"
+ },
+ "source": [
+ "Currently model transmuter only supports fp32 operation, and it will support int8 with incoming torch.fx quantization mode. In this tutorial, we assume deploy transmuted model without quantization. In this case, all we need to do is to export jit trace and then apply `optimize_for_mobile` for final optimization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "i2Mr_Il26isx"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.mobile_optimizer import (\n",
+ " optimize_for_mobile,\n",
+ ")\n",
+ "traced_model = torch.jit.trace(model_transmute_deploy, input_tensor, strict=False)\n",
+ "traced_model_opt = optimize_for_mobile(traced_model)\n",
+ "# Here we can save the traced_model_opt to JIT file using traced_model_opt.save()"
+ ]
+ }
+ ],
+ "metadata": {
+ "bento_stylesheets": {
+ "bento/extensions/flow/main.css": true,
+ "bento/extensions/kernel_selector/main.css": true,
+ "bento/extensions/kernel_ui/main.css": true,
+ "bento/extensions/new_kernel/main.css": true,
+ "bento/extensions/system_usage/main.css": true,
+ "bento/extensions/theme/main.css": true
+ },
+ "colab": {
+ "collapsed_sections": [],
+ "name": "Use Model Transmuter.ipynb",
+ "provenance": []
+ },
+ "disseminate_notebook_id": {
+ "notebook_id": "2903671383210410"
+ },
+ "disseminate_notebook_info": {
+ "bento_version": "20210321-210352",
+ "description": "",
+ "hide_code": false,
+ "hipster_group": "",
+ "kernel_build_info": {
+ "error": ""
+ },
+ "no_uii": true,
+ "notebook_number": "520938",
+ "others_can_edit": false,
+ "reviewers": "",
+ "revision_id": "464970858270301",
+ "tags": "",
+ "tasks": "",
+ "title": "Use Model Transmuter"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/code/pytorchvideo/tutorials/accelerator/Use_PytorchVideo_Accelerator_Model_Zoo.ipynb b/code/pytorchvideo/tutorials/accelerator/Use_PytorchVideo_Accelerator_Model_Zoo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..89b2a740c457384ce0df1df25f554c4ec657da1d
--- /dev/null
+++ b/code/pytorchvideo/tutorials/accelerator/Use_PytorchVideo_Accelerator_Model_Zoo.ipynb
@@ -0,0 +1,345 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PV1MwvbCm8X1"
+ },
+ "source": [
+ "## Introduction\n",
+ "This tutorial goes through how to use model zoo provided by PytorchVideo/Accelerator. To use model zoo in PytorchVideo/Accelerator, we should generally follow several steps:\n",
+ "- Use model builder to build selected model; \n",
+ "- Load pretrain checkpoint;\n",
+ "- (Optional) Finetune;\n",
+ "- Deploy.\n",
+ "\n",
+ "Before we start, let's install PytorchVideo."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "h21XJwAKnB8q"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install pytorchvideo"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kppASAd8m8X4"
+ },
+ "source": [
+ "## Use model builder to build selected model\n",
+ "We use model builder in PytorchVideo/Accelerator model zoo to build pre-defined efficient model. Here we use EfficientX3D-XS (for mobile_cpu) as an example. For more available models and details, please refer to [this page].\n",
+ "\n",
+ "EfficientX3D-XS is an implementation of X3D-XS network as described in [X3D paper](https://arxiv.org/abs/2004.04730) using efficient blocks. It is arithmetically equivalent with X3D-XS, but our benchmark on mobile phone shows 4.6X latency reduction compared with vanilla implementation.\n",
+ "\n",
+ "In order to build EfficientX3D-XS, we simply do the following:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VwxiWAbQm8X5"
+ },
+ "outputs": [],
+ "source": [
+ "from pytorchvideo.models.accelerator.mobile_cpu.efficient_x3d import EfficientX3d\n",
+ "model_efficient_x3d_xs = EfficientX3d(expansion='XS', head_act='identity')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uuRnwhYzm8X5"
+ },
+ "source": [
+ "Note that now the efficient blocks in the model are in original form, so the model is good for further training."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RSYnB3p8m8X5"
+ },
+ "source": [
+ "## Load pretrain checkpoint and (optional) finetune\n",
+ "For each model in model zoo, we provide pretrain checkpoint state_dict for model in original form. See [this page] for details about checkpoints and where to download them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "X9toVl9xm8X6"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.hub import load_state_dict_from_url\n",
+ "checkpoint_path = 'https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/efficient_x3d_xs_original_form.pyth'\n",
+ "checkpoint = load_state_dict_from_url(checkpoint_path)\n",
+ "\n",
+ "model_efficient_x3d_xs.load_state_dict(checkpoint)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cwPUPjJom8X6"
+ },
+ "source": [
+ "Now the model is ready for fine-tune. "
+ ]
+ },
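+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next cell is only an illustrative sketch of a single fine-tuning step: the clips, labels and hyperparameters are placeholders. In practice you would plug in your own dataloader (e.g. a Kinetics-style dataset) and a full training loop."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative fine-tuning step (assumption: random clips/labels stand in for a real dataloader).\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "optimizer = torch.optim.SGD(model_efficient_x3d_xs.parameters(), lr=0.01, momentum=0.9)\n",
+ "dummy_clips = torch.randn(2, 3, 4, 160, 160)  # (batch, channel, time, height, width) for EfficientX3D-XS\n",
+ "dummy_labels = torch.randint(0, 400, (2,))    # Kinetics-400 class ids\n",
+ "\n",
+ "model_efficient_x3d_xs.train()\n",
+ "optimizer.zero_grad()\n",
+ "loss = F.cross_entropy(model_efficient_x3d_xs(dummy_clips), dummy_labels)\n",
+ "loss.backward()\n",
+ "optimizer.step()"
+ ]
+ },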
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jcD6nyVzm8X6"
+ },
+ "source": [
+ "## Deploy\n",
+ "Now the model is ready to deploy. First of all, let's convert the model into deploy form. In order to do that, we need to use `convert_to_deployable_form` utility and provide an example input tensor to the model. Note that once the model is converted into deploy form, the input size should be the same as the example input tensor size during conversion."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2SAavQBZm8X7"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (\n",
+ " convert_to_deployable_form,\n",
+ ")\n",
+ "input_blob_size = (1, 3, 4, 160, 160)\n",
+ "input_tensor = torch.randn(input_blob_size)\n",
+ "model_efficient_x3d_xs_deploy = convert_to_deployable_form(model_efficient_x3d_xs, input_tensor)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ToAwX-2Jm8X7"
+ },
+ "source": [
+ "We can see that the network graph has been changed after conversion, which did kernel and graph optimization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "EWMrKRpim8X7"
+ },
+ "outputs": [],
+ "source": [
+ "print(model_efficient_x3d_xs_deploy)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3HfFgDgCm8X8"
+ },
+ "source": [
+ "Next we have two options: either deploy floating point model, or quantize model into int8 and then deploy.\n",
+ "\n",
+ "Let's first assume we want to deploy floating point model. In this case, all we need to do is to export jit trace and then apply `optimize_for_mobile` for final optimization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "966SbScHm8X9"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.mobile_optimizer import (\n",
+ " optimize_for_mobile,\n",
+ ")\n",
+ "traced_model = torch.jit.trace(model_efficient_x3d_xs_deploy, input_tensor, strict=False)\n",
+ "traced_model_opt = optimize_for_mobile(traced_model)\n",
+ "# Here we can save the traced_model_opt to JIT file using traced_model_opt.save()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Yjaeep9Wm8X9"
+ },
+ "source": [
+ "Alternatively, we may also want to deploy a quantized model. Efficient blocks are quantization-friendly by design - just wrap the model in deploy form with `QuantStub/DeQuantStub` and it is ready for Pytorch eager mode quantization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-cD-OL4km8X9"
+ },
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "# Wrapper class for adding QuantStub/DeQuantStub.\n",
+ "class quant_stub_wrapper(nn.Module):\n",
+ " def __init__(self, module_in):\n",
+ " super().__init__()\n",
+ " self.quant = torch.quantization.QuantStub()\n",
+ " self.model = module_in\n",
+ " self.dequant = torch.quantization.DeQuantStub()\n",
+ " def forward(self, x):\n",
+ " x = self.quant(x)\n",
+ " x = self.model(x)\n",
+ " x = self.dequant(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "b_-0Kyeym8X-"
+ },
+ "outputs": [],
+ "source": [
+ "model_efficient_x3d_xs_deploy_quant_stub_wrapper = quant_stub_wrapper(model_efficient_x3d_xs_deploy)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "S_rv-Gxcm8YK"
+ },
+ "source": [
+ "Preparation step of quantization. Fusion has been done for efficient blocks automatically during `convert_to_deployable_form`, so we can just proceed to `torch.quantization.prepare`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-kLtF7tpm8YL"
+ },
+ "outputs": [],
+ "source": [
+ "model_efficient_x3d_xs_deploy_quant_stub_wrapper.qconfig = torch.quantization.default_qconfig\n",
+ "model_efficient_x3d_xs_deploy_quant_stub_wrapper_prepared = torch.quantization.prepare(model_efficient_x3d_xs_deploy_quant_stub_wrapper)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2W10VcNwm8YM"
+ },
+ "source": [
+ "Calibration and quantization. After preparation we will do calibration of quantization by feeding calibration dataset (skipped here) and then do quantization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zR2MrKv-m8YM"
+ },
+ "outputs": [],
+ "source": [
+ "# calibration is skipped here.\n",
+ "model_efficient_x3d_xs_deploy_quant_stub_wrapper_quantized = torch.quantization.convert(model_efficient_x3d_xs_deploy_quant_stub_wrapper_prepared)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "87eImwZCm8YM"
+ },
+ "source": [
+ "Then we can export trace of int8 model and deploy on mobile devices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kbN27xw_m8YM"
+ },
+ "outputs": [],
+ "source": [
+ "traced_model_int8 = torch.jit.trace(model_efficient_x3d_xs_deploy_quant_stub_wrapper_quantized, input_tensor, strict=False)\n",
+ "traced_model_int8_opt = optimize_for_mobile(traced_model_int8)\n",
+ "# Here we can save the traced_model_opt to JIT file using traced_model_int8_opt.save()"
+ ]
+ }
+ ],
+ "metadata": {
+ "bento_stylesheets": {
+ "bento/extensions/flow/main.css": true,
+ "bento/extensions/kernel_selector/main.css": true,
+ "bento/extensions/kernel_ui/main.css": true,
+ "bento/extensions/new_kernel/main.css": true,
+ "bento/extensions/system_usage/main.css": true,
+ "bento/extensions/theme/main.css": true
+ },
+ "colab": {
+ "collapsed_sections": [],
+ "name": "Use PytorchVideo Accelerator Model Zoo.ipynb",
+ "provenance": []
+ },
+ "disseminate_notebook_id": {
+ "notebook_id": "478609506614914"
+ },
+ "disseminate_notebook_info": {
+ "bento_version": "20210314-210430",
+ "description": "",
+ "hide_code": false,
+ "hipster_group": "",
+ "kernel_build_info": {
+ "error": ""
+ },
+ "no_uii": true,
+ "notebook_number": "514048",
+ "others_can_edit": false,
+ "reviewers": "",
+ "revision_id": "466653834533727",
+ "tags": "",
+ "tasks": "",
+ "title": "Using PytorchVideo Accelerator Model Zoo"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/code/pytorchvideo/tutorials/torchhub_inference_tutorial.ipynb b/code/pytorchvideo/tutorials/torchhub_inference_tutorial.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..490c547de7311c3e4c7faa5ab99dec1484cbd4a4
--- /dev/null
+++ b/code/pytorchvideo/tutorials/torchhub_inference_tutorial.ipynb
@@ -0,0 +1,327 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Torch Hub Inference Tutorial\n",
+ "\n",
+ "In this tutorial you'll learn:\n",
+ "- how to load a pretrained model using Torch Hub \n",
+ "- run inference to classify the action in a demo video"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install and Import modules"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If `torch`, `torchvision` and `pytorchvideo` are not installed, run the following cell:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " import torch\n",
+ "except ModuleNotFoundError:\n",
+ " !pip install torch torchvision\n",
+ " import os\n",
+ " import sys\n",
+ " import torch\n",
+ " \n",
+ "if torch.__version__=='1.6.0+cu101' and sys.platform.startswith('linux'):\n",
+ " !pip install pytorchvideo\n",
+ "else:\n",
+ " need_pytorchvideo=False\n",
+ " try:\n",
+ " # Running notebook locally\n",
+ " import pytorchvideo\n",
+ " except ModuleNotFoundError:\n",
+ " need_pytorchvideo=True\n",
+ " if need_pytorchvideo:\n",
+ " # Install from GitHub\n",
+ " !pip install \"git+https://github.com/facebookresearch/pytorchvideo.git\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json \n",
+ "from torchvision.transforms import Compose, Lambda\n",
+ "from torchvision.transforms._transforms_video import (\n",
+ " CenterCropVideo,\n",
+ " NormalizeVideo,\n",
+ ")\n",
+ "from pytorchvideo.data.encoded_video import EncodedVideo\n",
+ "from pytorchvideo.transforms import (\n",
+ " ApplyTransformToKey,\n",
+ " ShortSideScale,\n",
+ " UniformTemporalSubsample,\n",
+ " UniformCropVideo\n",
+ ") \n",
+ "from typing import Dict"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Setup \n",
+ "\n",
+ "Download the id to label mapping for the Kinetics 400 dataset on which the Torch Hub models were trained. \n",
+ "This will be used to get the category label names from the predicted class ids."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!wget https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"kinetics_classnames.json\", \"r\") as f:\n",
+ " kinetics_classnames = json.load(f)\n",
+ "\n",
+ "# Create an id to label name mapping\n",
+ "kinetics_id_to_classname = {}\n",
+ "for k, v in kinetics_classnames.items():\n",
+ " kinetics_id_to_classname[v] = str(k).replace('\"', \"\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load Model using Torch Hub API\n",
+ "\n",
+ "PyTorchVideo provides several pretrained models through Torch Hub. Available models are described in [model zoo documentation](https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md#kinetics-400). \n",
+ "\n",
+ "Here we are selecting the `slowfast_r50` model which was trained using a 8x8 setting on the Kinetics 400 dataset. \n",
+ "\n",
+ "\n",
+ "NOTE: to run on GPU in Google Colab, in the menu bar selet: Runtime -> Change runtime type -> Harware Accelerator -> GPU\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Device on which to run the model\n",
+ "# Set to cuda to load on GPU\n",
+ "device = \"cpu\"\n",
+ "\n",
+ "# Pick a pretrained model \n",
+ "model_name = \"slowfast_r50\"\n",
+ "model = torch.hub.load(\"facebookresearch/pytorchvideo:main\", model=model_name, pretrained=True)\n",
+ "\n",
+ "# Set to eval mode and move to desired device\n",
+ "model = model.to(device)\n",
+ "model = model.eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Define the transformations for the input required by the model\n",
+ "\n",
+ "Before passing the video into the model we need to apply some input transforms and sample a clip of the correct duration.\n",
+ "\n",
+ "NOTE: The input transforms are specific to the model. If you choose a different model than the example in this tutorial, please refer to the code provided in the Torch Hub documentation and copy over the relevant transforms:\n",
+ "- [SlowFast](https://pytorch.org/hub/facebookresearch_pytorchvideo_slowfast/)\n",
+ "- [X3D](https://pytorch.org/hub/facebookresearch_pytorchvideo_x3d/)\n",
+ "- [Slow](https://pytorch.org/hub/facebookresearch_pytorchvideo_resnet/)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "####################\n",
+ "# SlowFast transform\n",
+ "####################\n",
+ "\n",
+ "side_size = 256\n",
+ "mean = [0.45, 0.45, 0.45]\n",
+ "std = [0.225, 0.225, 0.225]\n",
+ "crop_size = 256\n",
+ "num_frames = 32\n",
+ "sampling_rate = 2\n",
+ "frames_per_second = 30\n",
+ "alpha = 4\n",
+ "\n",
+ "class PackPathway(torch.nn.Module):\n",
+ " \"\"\"\n",
+ " Transform for converting video frames as a list of tensors. \n",
+ " \"\"\"\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " \n",
+ " def forward(self, frames: torch.Tensor):\n",
+ " fast_pathway = frames\n",
+ " # Perform temporal sampling from the fast pathway.\n",
+ " slow_pathway = torch.index_select(\n",
+ " frames,\n",
+ " 1,\n",
+ " torch.linspace(\n",
+ " 0, frames.shape[1] - 1, frames.shape[1] // alpha\n",
+ " ).long(),\n",
+ " )\n",
+ " frame_list = [slow_pathway, fast_pathway]\n",
+ " return frame_list\n",
+ "\n",
+ "transform = ApplyTransformToKey(\n",
+ " key=\"video\",\n",
+ " transform=Compose(\n",
+ " [\n",
+ " UniformTemporalSubsample(num_frames),\n",
+ " Lambda(lambda x: x/255.0),\n",
+ " NormalizeVideo(mean, std),\n",
+ " ShortSideScale(\n",
+ " size=side_size\n",
+ " ),\n",
+ " CenterCropVideo(crop_size),\n",
+ " PackPathway()\n",
+ " ]\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "# The duration of the input clip is also specific to the model.\n",
+ "clip_duration = (num_frames * sampling_rate)/frames_per_second"
+ ]
+ },
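+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With the settings above, each sampled clip covers `num_frames * sampling_rate = 64` raw frames at 30 fps, so `clip_duration = 32 * 2 / 30 ≈ 2.13` seconds of video per model input."
+ ]
+ },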
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load an example video\n",
+ "We can test the classification of an example video from the kinetics validation set such as this [archery video](https://www.youtube.com/watch?v=3and4vWkW4s)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download the example video file\n",
+ "!wget https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4 "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load the example video\n",
+ "video_path = \"archery.mp4\" \n",
+ "\n",
+ "# Select the duration of the clip to load by specifying the start and end duration\n",
+ "# The start_sec should correspond to where the action occurs in the video\n",
+ "start_sec = 0\n",
+ "end_sec = start_sec + clip_duration \n",
+ "\n",
+ "# Initialize an EncodedVideo helper class\n",
+ "video = EncodedVideo.from_path(video_path)\n",
+ "\n",
+ "# Load the desired clip\n",
+ "video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)\n",
+ "\n",
+ "# Apply a transform to normalize the video input\n",
+ "video_data = transform(video_data)\n",
+ "\n",
+ "# Move the inputs to the desired device\n",
+ "inputs = video_data[\"video\"]\n",
+ "inputs = [i.to(device)[None, ...] for i in inputs]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Get model predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pass the input clip through the model \n",
+ "preds = model(inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the predicted classes \n",
+ "post_act = torch.nn.Softmax(dim=1)\n",
+ "preds = post_act(preds)\n",
+ "pred_classes = preds.topk(k=5).indices\n",
+ "\n",
+ "# Map the predicted classes to the label names\n",
+ "pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes[0]]\n",
+ "print(\"Predicted labels: %s\" % \", \".join(pred_class_names))"
+ ]
+ }
+ ],
+ "metadata": {
+ "bento_stylesheets": {
+ "bento/extensions/flow/main.css": true,
+ "bento/extensions/kernel_selector/main.css": true,
+ "bento/extensions/kernel_ui/main.css": true,
+ "bento/extensions/new_kernel/main.css": true,
+ "bento/extensions/system_usage/main.css": true,
+ "bento/extensions/theme/main.css": true
+ },
+ "kernelspec": {
+ "display_name": "pytorchvideo_etc (local)",
+ "language": "python",
+ "name": "pytorchvideo_etc_local"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/code/pytorchvideo/tutorials/video_classification_example/environment.yml b/code/pytorchvideo/tutorials/video_classification_example/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c9983b82bd636c264f4dbd360c9a6e79de58021f
--- /dev/null
+++ b/code/pytorchvideo/tutorials/video_classification_example/environment.yml
@@ -0,0 +1,12 @@
+# Conda environment file
+# Usage: `conda env update -f environment.yml`
+
+name: video_classification_example
+
+channels:
+ - conda-forge
+ - pytorch-nightly
+
+dependencies:
+ - pytorch-lightning
+ - submitit
diff --git a/code/pytorchvideo/tutorials/video_classification_example/slurm.py b/code/pytorchvideo/tutorials/video_classification_example/slurm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd10ceb300ca07935f97a3da8aa7b273fcd77b25
--- /dev/null
+++ b/code/pytorchvideo/tutorials/video_classification_example/slurm.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import os
+import pathlib
+import shutil
+
+import submitit
+
+
+def init_and_run(run_fn, run_config):
+ os.environ["RANK"] = os.environ["SLURM_LOCALID"]
+ os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
+ os.environ["NODE_RANK"] = os.environ["SLURM_LOCALID"]
+ os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
+ run_fn(run_config)
+
+
+def copy_and_run_with_config(run_fn, run_config, directory, **cluster_config):
+ working_directory = pathlib.Path(directory) / cluster_config["job_name"]
+ ignore_list = [
+ "lightning_logs",
+ "logs",
+ "checkpoints",
+ "experiments",
+ ".git",
+ "output",
+ "val.csv",
+ "train.csv",
+ ]
+ shutil.copytree(".", working_directory, ignore=lambda x, y: ignore_list)
+ os.chdir(working_directory)
+ print(f"Running at {working_directory}")
+
+ executor = submitit.SlurmExecutor(folder=working_directory)
+ executor.update_parameters(**cluster_config)
+ job = executor.submit(init_and_run, run_fn, run_config)
+ print(f"job_id: {job}")
diff --git a/code/pytorchvideo/tutorials/video_classification_example/train.py b/code/pytorchvideo/tutorials/video_classification_example/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7475702c2d9b63b62d9bdd7e7de8f86ac83b4e
--- /dev/null
+++ b/code/pytorchvideo/tutorials/video_classification_example/train.py
@@ -0,0 +1,465 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import argparse
+import itertools
+import logging
+import os
+
+import pytorch_lightning
+import pytorchvideo.data
+import pytorchvideo.models.resnet
+import torch
+import torch.nn.functional as F
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorchvideo.transforms import (
+ ApplyTransformToKey,
+ Normalize,
+ RandomShortSideScale,
+ RemoveKey,
+ ShortSideScale,
+ UniformTemporalSubsample,
+)
+from slurm import copy_and_run_with_config
+from torch.utils.data import DistributedSampler, RandomSampler
+from torchaudio.transforms import MelSpectrogram, Resample
+from torchvision.transforms import (
+ CenterCrop,
+ Compose,
+ Lambda,
+ RandomCrop,
+ RandomHorizontalFlip,
+)
+
+
+"""
+This video classification example demonstrates how PyTorchVideo models, datasets and
+transforms can be used with PyTorch Lightning module. Specifically it shows how a
+simple pipeline to train a Resnet on the Kinetics video dataset can be built.
+
+Don't worry if you don't have PyTorch Lightning experience. We'll provide an explanation
+of how the PyTorch Lightning module works to accompany the example.
+
+The code can be separated into three main components:
+1. VideoClassificationLightningModule (pytorch_lightning.LightningModule), this defines:
+ - how the model is constructed,
+ - the inner train or validation loop (i.e. computing loss/metrics from a minibatch)
+ - optimizer configuration
+
+2. KineticsDataModule (pytorch_lightning.LightningDataModule), this defines:
+ - how to fetch/prepare the dataset
+ - the train and val dataloaders for the associated dataset
+
+3. pytorch_lightning.Trainer, this is a concrete PyTorch Lightning class that provides
+ the training pipeline configuration and a fit(<lightning_module>, <data_module>)
+ function to start the training/validation loop.
+
+All three components are combined in the train() function. We'll explain the rest of the
+details inline.
+"""
+
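+# Example local invocation (flags and paths below are placeholders; see the argparse
+# options defined in main() plus the standard PyTorch Lightning Trainer flags):
+#
+#   python train.py --data_path /path/to/kinetics_csv_dir --video_path_prefix /path/to/videos \
+#       --batch_size 8 --workers 4 --gpus 1 --max_epochs 1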
+
+class VideoClassificationLightningModule(pytorch_lightning.LightningModule):
+ def __init__(self, args):
+ """
+ This LightningModule implementation constructs a PyTorchVideo ResNet,
+ defines the train and val loss to be trained with (cross_entropy), and
+ configures the optimizer.
+ """
+ self.args = args
+ super().__init__()
+ self.train_accuracy = pytorch_lightning.metrics.Accuracy()
+ self.val_accuracy = pytorch_lightning.metrics.Accuracy()
+
+ #############
+ # PTV Model #
+ #############
+
+ # Here we construct the PyTorchVideo model. For this example we're using a
+ # ResNet that works with Kinetics (e.g. 400 num_classes). For your application,
+ # this could be changed to any other PyTorchVideo model (e.g. for SlowFast use
+ # create_slowfast).
+ if self.args.arch == "video_resnet":
+ self.model = pytorchvideo.models.resnet.create_resnet(
+ input_channel=3,
+ model_num_class=400,
+ )
+ self.batch_key = "video"
+ elif self.args.arch == "audio_resnet":
+ self.model = pytorchvideo.models.resnet.create_acoustic_resnet(
+ input_channel=1,
+ model_num_class=400,
+ )
+ self.batch_key = "audio"
+ else:
+ raise Exception(f"{self.args.arch} not supported")
+
+ def on_train_epoch_start(self):
+ """
+ For distributed training we need to set the datasets video sampler epoch so
+ that shuffling is done correctly
+ """
+ epoch = self.trainer.current_epoch
+ if self.trainer.use_ddp:
+ self.trainer.datamodule.train_dataset.dataset.video_sampler.set_epoch(epoch)
+
+ def forward(self, x):
+ """
+ Forward defines the prediction/inference actions.
+ """
+ return self.model(x)
+
+ def training_step(self, batch, batch_idx):
+ """
+ This function is called in the inner loop of the training epoch. It must
+ return a loss that is used for loss.backwards() internally. The self.log(...)
+ function can be used to log any training metrics.
+
+ PyTorchVideo batches are dictionaries containing each modality or metadata of
+ the batch collated video clips. Kinetics contains the following notable keys:
+ {
+ 'video': <video_tensor>,
+ 'audio': <audio_tensor>,
+ 'label': <action_label>,
+ }
+
+ - "video" is a Tensor of shape (batch, channels, time, height, Width)
+ - "audio" is a Tensor of shape (batch, channels, time, 1, frequency)
+ - "label" is a Tensor of shape (batch, 1)
+
+ The PyTorchVideo models and transforms expect the same input shapes and
+ dictionary structure making this function just a matter of unwrapping the dict and
+ feeding it through the model/loss.
+ """
+ x = batch[self.batch_key]
+ y_hat = self.model(x)
+ loss = F.cross_entropy(y_hat, batch["label"])
+ acc = self.train_accuracy(F.softmax(y_hat, dim=-1), batch["label"])
+ self.log("train_loss", loss)
+ self.log(
+ "train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True
+ )
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ """
+ This function is called in the inner loop of the evaluation cycle. For this
+ simple example it's mostly the same as the training loop but with a different
+ metric name.
+ """
+ x = batch[self.batch_key]
+ y_hat = self.model(x)
+ loss = F.cross_entropy(y_hat, batch["label"])
+ acc = self.val_accuracy(F.softmax(y_hat, dim=-1), batch["label"])
+ self.log("val_loss", loss)
+ self.log(
+ "val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True
+ )
+ return loss
+
+ def configure_optimizers(self):
+ """
+ We use the SGD optimizer with per step cosine annealing scheduler.
+ """
+ optimizer = torch.optim.SGD(
+ self.parameters(),
+ lr=self.args.lr,
+ momentum=self.args.momentum,
+ weight_decay=self.args.weight_decay,
+ )
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer, self.args.max_epochs, last_epoch=-1
+ )
+ return [optimizer], [scheduler]
+
+
+class KineticsDataModule(pytorch_lightning.LightningDataModule):
+ """
+ This LightningDataModule implementation constructs a PyTorchVideo Kinetics dataset for both
+ the train and val partitions. It defines each partition's augmentation and
+ preprocessing transforms and configures the PyTorch DataLoaders.
+ """
+
+ def __init__(self, args):
+ self.args = args
+ super().__init__()
+
+ def _make_transforms(self, mode: str):
+ """
+ ##################
+ # PTV Transforms #
+ ##################
+
+ # Each PyTorchVideo dataset has a "transform" arg. This arg takes a
+ # Callable[[Dict], Any], and is used on the output Dict of the dataset to
+ # define any application specific processing or augmentation. Transforms can
+ # either be implemented by the user application or reused from any library
+ # that's domain specific to the modality. E.g. for video we recommend using
+ # TorchVision, for audio we recommend TorchAudio.
+ #
+ # To improve interoperation between domain transform libraries, PyTorchVideo
+ # provides a dictionary transform API that provides:
+ # - ApplyTransformToKey(key, transform) - applies a transform to specific modality
+ # - RemoveKey(key) - remove a specific modality from the clip
+ #
+ # In the case that the recommended libraries don't provide transforms that
+ # are common enough for PyTorchVideo use cases, PyTorchVideo will provide them in
+ # the same structure as the recommended library. E.g. TorchVision didn't
+ # have a RandomShortSideScale video transform so it's been added to PyTorchVideo.
+ """
+ if self.args.data_type == "video":
+ transform = [
+ self._video_transform(mode),
+ RemoveKey("audio"),
+ ]
+ elif self.args.data_type == "audio":
+ transform = [
+ self._audio_transform(),
+ RemoveKey("video"),
+ ]
+ else:
+ raise Exception(f"{self.args.data_type} not supported")
+
+ return Compose(transform)
+
+ def _video_transform(self, mode: str):
+ """
+ This function contains example transforms using both PyTorchVideo and TorchVision
+ in the same Callable. For 'train' mode, we use augmentations (prepended with
+ 'Random'); for 'val' mode we use the respective deterministic function.
+ """
+ args = self.args
+ return ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ [
+ UniformTemporalSubsample(args.video_num_subsampled),
+ Normalize(args.video_means, args.video_stds),
+ ]
+ + (
+ [
+ RandomShortSideScale(
+ min_size=args.video_min_short_side_scale,
+ max_size=args.video_max_short_side_scale,
+ ),
+ RandomCrop(args.video_crop_size),
+ RandomHorizontalFlip(p=args.video_horizontal_flip_p),
+ ]
+ if mode == "train"
+ else [
+ ShortSideScale(args.video_min_short_side_scale),
+ CenterCrop(args.video_crop_size),
+ ]
+ )
+ ),
+ )
+
+ def _audio_transform(self):
+ """
+ This function contains example transforms using both PyTorchVideo and TorchAudio
+ in the same Callable.
+ """
+ args = self.args
+ n_fft = int(
+ float(args.audio_resampled_rate) / 1000 * args.audio_mel_window_size
+ )
+ hop_length = int(
+ float(args.audio_resampled_rate) / 1000 * args.audio_mel_step_size
+ )
+ eps = 1e-10
+ return ApplyTransformToKey(
+ key="audio",
+ transform=Compose(
+ [
+ Resample(
+ orig_freq=args.audio_raw_sample_rate,
+ new_freq=args.audio_resampled_rate,
+ ),
+ MelSpectrogram(
+ sample_rate=args.audio_resampled_rate,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ n_mels=args.audio_num_mels,
+ center=False,
+ ),
+ Lambda(lambda x: x.clamp(min=eps)),
+ Lambda(torch.log),
+ UniformTemporalSubsample(args.audio_mel_num_subsample),
+ Lambda(lambda x: x.transpose(1, 0)), # (F, T) -> (T, F)
+ Lambda(
+ lambda x: x.view(1, x.size(0), 1, x.size(1))
+ ), # (T, F) -> (1, T, 1, F)
+ Normalize((args.audio_logmel_mean,), (args.audio_logmel_std,)),
+ ]
+ ),
+ )
+
+ def train_dataloader(self):
+ """
+ Defines the train DataLoader that the PyTorch Lightning Trainer trains/tests with.
+ """
+ sampler = DistributedSampler if self.trainer.use_ddp else RandomSampler
+ train_transform = self._make_transforms(mode="train")
+ self.train_dataset = LimitDataset(
+ pytorchvideo.data.Kinetics(
+ data_path=os.path.join(self.args.data_path, "train.csv"),
+ clip_sampler=pytorchvideo.data.make_clip_sampler(
+ "random", self.args.clip_duration
+ ),
+ video_path_prefix=self.args.video_path_prefix,
+ transform=train_transform,
+ video_sampler=sampler,
+ )
+ )
+ return torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=self.args.batch_size,
+ num_workers=self.args.workers,
+ )
+
+ def val_dataloader(self):
+ """
+ Defines the validation DataLoader that the PyTorch Lightning Trainer validates with.
+ """
+ sampler = DistributedSampler if self.trainer.use_ddp else RandomSampler
+ val_transform = self._make_transforms(mode="val")
+ self.val_dataset = pytorchvideo.data.Kinetics(
+ data_path=os.path.join(self.args.data_path, "val.csv"),
+ clip_sampler=pytorchvideo.data.make_clip_sampler(
+ "uniform", self.args.clip_duration
+ ),
+ video_path_prefix=self.args.video_path_prefix,
+ transform=val_transform,
+ video_sampler=sampler,
+ )
+ return torch.utils.data.DataLoader(
+ self.val_dataset,
+ batch_size=self.args.batch_size,
+ num_workers=self.args.workers,
+ )
+
+
+class LimitDataset(torch.utils.data.Dataset):
+ """
+ To ensure a constant number of samples are retrieved from the dataset we use this
+ LimitDataset wrapper. This is necessary because several of the underlying videos
+ may be corrupted while fetching or decoding, however, we always want the same
+ number of steps per epoch.
+ """
+
+ def __init__(self, dataset):
+ super().__init__()
+ self.dataset = dataset
+ self.dataset_iter = itertools.chain.from_iterable(
+ itertools.repeat(iter(dataset), 2)
+ )
+
+ def __getitem__(self, index):
+ return next(self.dataset_iter)
+
+ def __len__(self):
+ return self.dataset.num_videos
+
+
+def main():
+ """
+ To train the ResNet with the Kinetics dataset we construct the two modules above,
+ and pass them to the fit function of a pytorch_lightning.Trainer.
+
+ This example can be run either locally (with default parameters) or on a Slurm
+ cluster. To run on a Slurm cluster provide the --on_cluster argument.
+ """
+ setup_logger()
+
+ pytorch_lightning.trainer.seed_everything()
+ parser = argparse.ArgumentParser()
+
+ # Cluster parameters.
+ parser.add_argument("--on_cluster", action="store_true")
+ parser.add_argument("--job_name", default="ptv_video_classification", type=str)
+ parser.add_argument("--working_directory", default=".", type=str)
+ parser.add_argument("--partition", default="dev", type=str)
+
+ # Model parameters.
+ parser.add_argument("--lr", "--learning-rate", default=0.1, type=float)
+ parser.add_argument("--momentum", default=0.9, type=float)
+ parser.add_argument("--weight_decay", default=1e-4, type=float)
+ parser.add_argument(
+ "--arch",
+ default="video_resnet",
+ choices=["video_resnet", "audio_resnet"],
+ type=str,
+ )
+
+ # Data parameters.
+ parser.add_argument("--data_path", default=None, type=str, required=True)
+ parser.add_argument("--video_path_prefix", default="", type=str)
+ parser.add_argument("--workers", default=8, type=int)
+ parser.add_argument("--batch_size", default=32, type=int)
+ parser.add_argument("--clip_duration", default=2, type=float)
+ parser.add_argument(
+ "--data_type", default="video", choices=["video", "audio"], type=str
+ )
+ parser.add_argument("--video_num_subsampled", default=8, type=int)
+ parser.add_argument("--video_means", default=(0.45, 0.45, 0.45), type=tuple)
+ parser.add_argument("--video_stds", default=(0.225, 0.225, 0.225), type=tuple)
+ parser.add_argument("--video_crop_size", default=224, type=int)
+ parser.add_argument("--video_min_short_side_scale", default=256, type=int)
+ parser.add_argument("--video_max_short_side_scale", default=320, type=int)
+ parser.add_argument("--video_horizontal_flip_p", default=0.5, type=float)
+ parser.add_argument("--audio_raw_sample_rate", default=44100, type=int)
+ parser.add_argument("--audio_resampled_rate", default=16000, type=int)
+ parser.add_argument("--audio_mel_window_size", default=32, type=int)
+ parser.add_argument("--audio_mel_step_size", default=16, type=int)
+ parser.add_argument("--audio_num_mels", default=80, type=int)
+ parser.add_argument("--audio_mel_num_subsample", default=128, type=int)
+ parser.add_argument("--audio_logmel_mean", default=-7.03, type=float)
+ parser.add_argument("--audio_logmel_std", default=4.66, type=float)
+
+ # Trainer parameters.
+ parser = pytorch_lightning.Trainer.add_argparse_args(parser)
+ parser.set_defaults(
+ max_epochs=200,
+ callbacks=[LearningRateMonitor()],
+ replace_sampler_ddp=False,
+ )
+
+ # Build trainer, ResNet lightning-module and Kinetics data-module.
+ args = parser.parse_args()
+
+ if args.on_cluster:
+ copy_and_run_with_config(
+ train,
+ args,
+ args.working_directory,
+ job_name=args.job_name,
+ time="72:00:00",
+ partition=args.partition,
+ gpus_per_node=args.gpus,
+ ntasks_per_node=args.gpus,
+ cpus_per_task=10,
+ mem="470GB",
+ nodes=args.num_nodes,
+ constraint="volta32gb",
+ )
+ else: # local
+ train(args)
+
+
+def train(args):
+ trainer = pytorch_lightning.Trainer.from_argparse_args(args)
+ classification_module = VideoClassificationLightningModule(args)
+ data_module = KineticsDataModule(args)
+ trainer.fit(classification_module, data_module)
+
+
+def setup_logger():
+ ch = logging.StreamHandler()
+ formatter = logging.Formatter("\n%(asctime)s [%(levelname)s] %(name)s: %(message)s")
+ ch.setFormatter(formatter)
+ logger = logging.getLogger("pytorchvideo")
+ logger.setLevel(logging.DEBUG)
+ logger.addHandler(ch)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/code/pytorchvideo/tutorials/video_detection_example/video_detection_inference_tutorial.ipynb b/code/pytorchvideo/tutorials/video_detection_example/video_detection_inference_tutorial.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..18d628201e7bde327e362b195d8dbe3fae25547e
--- /dev/null
+++ b/code/pytorchvideo/tutorials/video_detection_example/video_detection_inference_tutorial.ipynb
@@ -0,0 +1,439 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "265ca5a2",
+ "metadata": {},
+ "source": [
+ "# Torch Hub Detection Inference Tutorial\n",
+ "\n",
+ "In this tutorial you'll learn:\n",
+ "- how to load a pretrained detection model using Torch Hub \n",
+ "- run inference to detect actions in a demo video"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e684f3e4",
+ "metadata": {},
+ "source": [
+ "## NOTE: \n",
+ "At the moment tutorial only works if ran on local clone from the directory `pytorchvideo/tutorials/video_detection_example`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1084c2f",
+ "metadata": {},
+ "source": [
+ "### Install and Import modules\n",
+ "If `torch`, `torchvision`, `cv2` and `pytorchvideo` are not installed, run the following cell:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "130e7aaf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " import torch\n",
+ "except ModuleNotFoundError:\n",
+ " !pip install torch torchvision\n",
+ " import os\n",
+ " import sys\n",
+ " import torch\n",
+ "\n",
+ "try:\n",
+ " import cv2\n",
+ "except ModuleNotFoundError:\n",
+ " !pip install opencv-python\n",
+ " \n",
+ "if torch.__version__=='1.6.0+cu101' and sys.platform.startswith('linux'):\n",
+ " !pip install pytorchvideo\n",
+ "else:\n",
+ " need_pytorchvideo=False\n",
+ " try:\n",
+ " # Running notebook locally\n",
+ " import pytorchvideo\n",
+ " except ModuleNotFoundError:\n",
+ " need_pytorchvideo=True\n",
+ " if need_pytorchvideo:\n",
+ " # Install from GitHub\n",
+ " !pip install \"git+https://github.com/facebookresearch/pytorchvideo.git\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "74d4dee2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "import numpy as np\n",
+ "\n",
+ "import cv2\n",
+ "import torch\n",
+ "\n",
+ "import detectron2\n",
+ "from detectron2.config import get_cfg\n",
+ "from detectron2 import model_zoo\n",
+ "from detectron2.engine import DefaultPredictor\n",
+ "\n",
+ "import pytorchvideo\n",
+ "from pytorchvideo.transforms.functional import (\n",
+ " uniform_temporal_subsample,\n",
+ " short_side_scale_with_boxes,\n",
+ " clip_boxes_to_image,\n",
+ ")\n",
+ "from torchvision.transforms._functional_video import normalize\n",
+ "from pytorchvideo.data.ava import AvaLabeledVideoFramePaths\n",
+ "from pytorchvideo.models.hub import slow_r50_detection # Another option is slowfast_r50_detection\n",
+ "\n",
+ "from visualization import VideoVisualizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b6c8faad",
+ "metadata": {},
+ "source": [
+ "## Load Model using Torch Hub API\n",
+ "PyTorchVideo provides several pretrained models through Torch Hub. Available models are described in [model zoo documentation.](https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md)\n",
+ "\n",
+ "Here we are selecting the slow_r50_detection model which was trained using a 4x16 setting on the Kinetics 400 dataset and \n",
+ "fine tuned on AVA V2.2 actions dataset.\n",
+ "\n",
+ "NOTE: to run on GPU in Google Colab, in the menu bar selet: Runtime -> Change runtime type -> Harware Accelerator -> GPU"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "6bb9a374",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = 'cuda' # or 'cpu'\n",
+ "video_model = slow_r50_detection(True) # Another option is slowfast_r50_detection\n",
+ "video_model = video_model.eval().to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f21c0ea",
+ "metadata": {},
+ "source": [
+ "## Load an off-the-shelf Detectron2 object detector\n",
+ "\n",
+ "We use the object detector to detect bounding boxes for the people. \n",
+ "These bounding boxes later feed into our video action detection model.\n",
+ "For more details, please refer to the Detectron2's object detection tutorials.\n",
+ "\n",
+ "To install Detectron2, please follow the instructions mentioned [here](https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "7a5d5f4b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cfg = get_cfg()\n",
+ "cfg.merge_from_file(model_zoo.get_config_file(\"COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml\"))\n",
+ "cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.55 # set threshold for this model\n",
+ "cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml\")\n",
+ "predictor = DefaultPredictor(cfg)\n",
+ "\n",
+ "# This method takes in an image and generates the bounding boxes for people in the image.\n",
+ "def get_person_bboxes(inp_img, predictor):\n",
+ " predictions = predictor(inp_img.cpu().detach().numpy())['instances'].to('cpu')\n",
+ " boxes = predictions.pred_boxes if predictions.has(\"pred_boxes\") else None\n",
+ " scores = predictions.scores if predictions.has(\"scores\") else None\n",
+ " classes = np.array(predictions.pred_classes.tolist() if predictions.has(\"pred_classes\") else None)\n",
+ " predicted_boxes = boxes[np.logical_and(classes==0, scores>0.75 )].tensor.cpu() # only person\n",
+ " return predicted_boxes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d8babcba",
+ "metadata": {},
+ "source": [
+ "## Define the transformations for the input required by the model\n",
+ "Before passing the video and bounding boxes into the model we need to apply some input transforms and sample a clip of the correct frame rate in the clip.\n",
+ "\n",
+ "Here, below we define a method that can pre-process the clip and bounding boxes. It generates inputs accordingly for both Slow (Resnet) and SlowFast models depending on the parameterization of the variable `slow_fast_alpha`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "9cb1ec3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def ava_inference_transform(\n",
+ " clip, \n",
+ " boxes,\n",
+ " num_frames = 4, #if using slowfast_r50_detection, change this to 32\n",
+ " crop_size = 256, \n",
+ " data_mean = [0.45, 0.45, 0.45], \n",
+ " data_std = [0.225, 0.225, 0.225],\n",
+ " slow_fast_alpha = None, #if using slowfast_r50_detection, change this to 4\n",
+ "):\n",
+ "\n",
+ " boxes = np.array(boxes)\n",
+ " ori_boxes = boxes.copy()\n",
+ "\n",
+ " # Image [0, 255] -> [0, 1].\n",
+ " clip = uniform_temporal_subsample(clip, num_frames)\n",
+ " clip = clip.float()\n",
+ " clip = clip / 255.0\n",
+ "\n",
+ " height, width = clip.shape[2], clip.shape[3]\n",
+ " # The format of boxes is [x1, y1, x2, y2]. The input boxes are in the\n",
+ " # range of [0, width] for x and [0,height] for y\n",
+ " boxes = clip_boxes_to_image(boxes, height, width)\n",
+ "\n",
+ " # Resize short side to crop_size. Non-local and STRG uses 256.\n",
+ " clip, boxes = short_side_scale_with_boxes(\n",
+ " clip,\n",
+ " size=crop_size,\n",
+ " boxes=boxes,\n",
+ " )\n",
+ " \n",
+ " # Normalize images by mean and std.\n",
+ " clip = normalize(\n",
+ " clip,\n",
+ " np.array(data_mean, dtype=np.float32),\n",
+ " np.array(data_std, dtype=np.float32),\n",
+ " )\n",
+ " \n",
+ " boxes = clip_boxes_to_image(\n",
+ " boxes, clip.shape[2], clip.shape[3]\n",
+ " )\n",
+ " \n",
+ " # Incase of slowfast, generate both pathways\n",
+ " if slow_fast_alpha is not None:\n",
+ " fast_pathway = clip\n",
+ " # Perform temporal sampling from the fast pathway.\n",
+ " slow_pathway = torch.index_select(\n",
+ " clip,\n",
+ " 1,\n",
+ " torch.linspace(\n",
+ " 0, clip.shape[1] - 1, clip.shape[1] // slow_fast_alpha\n",
+ " ).long(),\n",
+ " )\n",
+ " clip = [slow_pathway, fast_pathway]\n",
+ " \n",
+ " return clip, torch.from_numpy(boxes), ori_boxes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f26315c",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "\n",
+ "Download the id to label mapping for the AVA V2.2 dataset on which the Torch Hub models were finetuned. \n",
+ "This will be used to get the category label names from the predicted class ids.\n",
+ "\n",
+ "Create a visualizer to visualize and plot the results(labels + bounding boxes)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "6132a777",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!wget https://dl.fbaipublicfiles.com/pytorchvideo/data/class_names/ava_action_list.pbtxt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "39454172",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create an id to label name mapping\n",
+ "label_map, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map('ava_action_list.pbtxt')\n",
+ "# Create a video visualizer that can plot bounding boxes and visualize actions on bboxes.\n",
+ "video_visualizer = VideoVisualizer(81, label_map, top_k=3, mode=\"thres\",thres=0.5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7086f4a0",
+ "metadata": {},
+ "source": [
+ "## Load an example video\n",
+ "We get an opensourced video off the web from WikiMedia. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "f27c302c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!wget https://dl.fbaipublicfiles.com/pytorchvideo/projects/theatre.webm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "b8bcc454",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Completed loading encoded video.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load the video\n",
+ "encoded_vid = pytorchvideo.data.encoded_video.EncodedVideo.from_path('theatre.webm')\n",
+ "print('Completed loading encoded video.')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3edb57ca",
+ "metadata": {},
+ "source": [
+ "## Generate bounding boxes and action predictions for a 10 second clip in the video."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "500ebdfb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Video predictions are generated at an internal of 1 sec from 90 seconds to 100 seconds in the video.\n",
+ "time_stamp_range = range(90,100) # time stamps in video for which clip is sampled. \n",
+ "clip_duration = 1.0 # Duration of clip used for each inference step.\n",
+ "gif_imgs = []\n",
+ "\n",
+ "for time_stamp in time_stamp_range: \n",
+ " print(\"Generating predictions for time stamp: {} sec\".format(time_stamp))\n",
+ " \n",
+ " # Generate clip around the designated time stamps\n",
+ " inp_imgs = encoded_vid.get_clip(\n",
+ " time_stamp - clip_duration/2.0, # start second\n",
+ " time_stamp + clip_duration/2.0 # end second\n",
+ " )\n",
+ " inp_imgs = inp_imgs['video']\n",
+ " \n",
+ " # Generate people bbox predictions using Detectron2's off the self pre-trained predictor\n",
+ " # We use the the middle image in each clip to generate the bounding boxes.\n",
+ " inp_img = inp_imgs[:,inp_imgs.shape[1]//2,:,:]\n",
+ " inp_img = inp_img.permute(1,2,0)\n",
+ " \n",
+ " # Predicted boxes are of the form List[(x_1, y_1, x_2, y_2)]\n",
+ " predicted_boxes = get_person_bboxes(inp_img, predictor) \n",
+ " if len(predicted_boxes) == 0: \n",
+ " print(\"Skipping clip no frames detected at time stamp: \", time_stamp)\n",
+ " continue\n",
+ " \n",
+ " # Preprocess clip and bounding boxes for video action recognition.\n",
+ " inputs, inp_boxes, _ = ava_inference_transform(inp_imgs, predicted_boxes.numpy())\n",
+ " # Prepend data sample id for each bounding box. \n",
+ " # For more details refere to the RoIAlign in Detectron2\n",
+ " inp_boxes = torch.cat([torch.zeros(inp_boxes.shape[0],1), inp_boxes], dim=1)\n",
+ " \n",
+ " # Generate actions predictions for the bounding boxes in the clip.\n",
+ " # The model here takes in the pre-processed video clip and the detected bounding boxes.\n",
+ " if isinstance(inputs, list):\n",
+ " inputs = [inp.unsqueeze(0).to(device) for inp in inputs]\n",
+ " else:\n",
+ " inputs = inputs.unsqueeze(0).to(device)\n",
+ " preds = video_model(inputs, inp_boxes.to(device))\n",
+ "\n",
+ " preds= preds.to('cpu')\n",
+ " # The model is trained on AVA and AVA labels are 1 indexed so, prepend 0 to convert to 0 index.\n",
+ " preds = torch.cat([torch.zeros(preds.shape[0],1), preds], dim=1)\n",
+ " \n",
+ " # Plot predictions on the video and save for later visualization.\n",
+ " inp_imgs = inp_imgs.permute(1,2,3,0)\n",
+ " inp_imgs = inp_imgs/255.0\n",
+ " out_img_pred = video_visualizer.draw_clip_range(inp_imgs, preds, predicted_boxes)\n",
+ " gif_imgs += out_img_pred\n",
+ "\n",
+ "print(\"Finished generating predictions.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8da24031",
+ "metadata": {},
+ "source": [
+ "## Save predictions as video\n",
+ "The generated video consists of bounding boxes with predicted actions for each bounding box."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "c4ae73fe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "height, width = gif_imgs[0].shape[0], gif_imgs[0].shape[1]\n",
+ "\n",
+ "vide_save_path = 'output_detections.mp4'\n",
+ "video = cv2.VideoWriter(vide_save_path,cv2.VideoWriter_fourcc(*'DIVX'), 7, (width,height))\n",
+ "\n",
+ "for image in gif_imgs:\n",
+ " img = (255*image).astype(np.uint8)\n",
+ " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
+ " video.write(img)\n",
+ "video.release()\n",
+ "\n",
+ "print('Predictions are saved to the video file: ', vide_save_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d0d1e754",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/code/pytorchvideo/tutorials/video_detection_example/visualization.py b/code/pytorchvideo/tutorials/video_detection_example/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..c090076128b89979a4ddf06f7861a9364f7bec50
--- /dev/null
+++ b/code/pytorchvideo/tutorials/video_detection_example/visualization.py
@@ -0,0 +1,708 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# Note: This file has been borrowed from the facebookresearch/slowfast repo.
+# TODO: Migrate this into the core PyTorchVideo library.
+
+from __future__ import annotations
+
+import itertools
+import logging
+from types import SimpleNamespace
+from typing import Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from detectron2.utils.visualizer import Visualizer
+
+
+logger = logging.getLogger(__name__)
+
+
+def _create_text_labels(
+ classes: List[int],
+ scores: List[float],
+ class_names: List[str],
+ ground_truth: bool = False,
+) -> List[str]:
+ """
+ Create text labels.
+ Args:
+ classes (list[int]): a list of class ids for each example.
+ scores (list[float] or None): list of scores for each example.
+        class_names (dict): a dict mapping class id to class name.
+ ground_truth (bool): whether the labels are ground truth.
+ Returns:
+ labels (list[str]): formatted text labels.
+ """
+ try:
+ labels = [class_names.get(c, "n/a") for c in classes]
+ except IndexError:
+ logger.error("Class indices get out of range: {}".format(classes))
+ return None
+
+ if ground_truth:
+ labels = ["[{}] {}".format("GT", label) for label in labels]
+ elif scores is not None:
+ assert len(classes) == len(scores)
+ labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)]
+ return labels
+
+
+class ImgVisualizer(Visualizer):
+ def __init__(
+ self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs
+ ) -> None:
+ """
+ See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
+ for more details.
+ Args:
+ img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to
+ the height and width of the image respectively. C is the number of
+ color channels. The image is required to be in RGB format since that
+ is a requirement of the Matplotlib library. The image is also expected
+ to be in the range [0, 255].
+ meta (MetadataCatalog): image metadata.
+ See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90
+ """
+ super(ImgVisualizer, self).__init__(img_rgb, meta, **kwargs)
+
+ def draw_text(
+ self,
+ text: str,
+ position: List[int],
+ *,
+ font_size: Optional[int] = None,
+ color: str = "w",
+ horizontal_alignment: str = "center",
+ vertical_alignment: str = "bottom",
+ box_facecolor: str = "black",
+ alpha: float = 0.5,
+ ) -> None:
+ """
+ Draw text at the specified position.
+ Args:
+ text (str): the text to draw on image.
+ position (list of 2 ints): the x,y coordinate to place the text.
+ font_size (Optional[int]): font of the text. If not provided, a font size
+ proportional to the image width is calculated and used.
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
+ of formats that are accepted.
+ horizontal_alignment (str): see `matplotlib.text.Text`.
+ vertical_alignment (str): see `matplotlib.text.Text`.
+ box_facecolor (str): color of the box wrapped around the text. Refer to
+ `matplotlib.colors` for full list of formats that are accepted.
+ alpha (float): transparency level of the box.
+ """
+ if not font_size:
+ font_size = self._default_font_size
+ x, y = position
+ self.output.ax.text(
+ x,
+ y,
+ text,
+ size=font_size * self.output.scale,
+ family="monospace",
+ bbox={
+ "facecolor": box_facecolor,
+ "alpha": alpha,
+ "pad": 0.7,
+ "edgecolor": "none",
+ },
+ verticalalignment=vertical_alignment,
+ horizontalalignment=horizontal_alignment,
+ color=color,
+ zorder=10,
+ )
+
+ def draw_multiple_text(
+ self,
+ text_ls: List[str],
+ box_coordinate: torch.Tensor,
+ *,
+ top_corner: bool = True,
+ font_size: Optional[int] = None,
+ color: str = "w",
+ box_facecolors: str = "black",
+ alpha: float = 0.5,
+ ) -> None:
+ """
+ Draw a list of text labels for some bounding box on the image.
+ Args:
+ text_ls (list of strings): a list of text labels.
+ box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
+ coordinates of the box.
+ top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box.
+ Else, draw labels at (x_left, y_bottom).
+ font_size (Optional[int]): font of the text. If not provided, a font size
+ proportional to the image width is calculated and used.
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
+ of formats that are accepted.
+ box_facecolors (str): colors of the box wrapped around the text. Refer to
+ `matplotlib.colors` for full list of formats that are accepted.
+ alpha (float): transparency level of the box.
+ """
+ if not isinstance(box_facecolors, list):
+ box_facecolors = [box_facecolors] * len(text_ls)
+ assert len(box_facecolors) == len(
+ text_ls
+ ), "Number of colors provided is not equal to the number of text labels."
+ if not font_size:
+ font_size = self._default_font_size
+ text_box_width = font_size + font_size // 2
+ # If the texts does not fit in the assigned location,
+ # we split the text and draw it in another place.
+ if top_corner:
+ num_text_split = self._align_y_top(
+ box_coordinate, len(text_ls), text_box_width
+ )
+ y_corner = 1
+ else:
+ num_text_split = len(text_ls) - self._align_y_bottom(
+ box_coordinate, len(text_ls), text_box_width
+ )
+ y_corner = 3
+
+ text_color_sorted = sorted(
+ zip(text_ls, box_facecolors), key=lambda x: x[0], reverse=True
+ )
+ if len(text_color_sorted) != 0:
+ text_ls, box_facecolors = zip(*text_color_sorted)
+ else:
+ text_ls, box_facecolors = [], []
+ text_ls, box_facecolors = list(text_ls), list(box_facecolors)
+ self.draw_multiple_text_upward(
+ text_ls[:num_text_split][::-1],
+ box_coordinate,
+ y_corner=y_corner,
+ font_size=font_size,
+ color=color,
+ box_facecolors=box_facecolors[:num_text_split][::-1],
+ alpha=alpha,
+ )
+ self.draw_multiple_text_downward(
+ text_ls[num_text_split:],
+ box_coordinate,
+ y_corner=y_corner,
+ font_size=font_size,
+ color=color,
+ box_facecolors=box_facecolors[num_text_split:],
+ alpha=alpha,
+ )
+
+ def draw_multiple_text_upward(
+ self,
+ text_ls: List[str],
+ box_coordinate: torch.Tensor,
+ *,
+ y_corner: int = 1,
+ font_size: Optional[int] = None,
+ color: str = "w",
+ box_facecolors: str = "black",
+ alpha: float = 0.5,
+ ) -> None:
+ """
+ Draw a list of text labels for some bounding box on the image in upward direction.
+ The next text label will be on top of the previous one.
+ Args:
+ text_ls (list of strings): a list of text labels.
+ box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
+ coordinates of the box.
+ y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
+ the box to draw labels around.
+ font_size (Optional[int]): font of the text. If not provided, a font size
+ proportional to the image width is calculated and used.
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
+ of formats that are accepted.
+ box_facecolors (str or list of strs): colors of the box wrapped around the
+ text. Refer to `matplotlib.colors` for full list of formats that
+ are accepted.
+ alpha (float): transparency level of the box.
+ """
+ if not isinstance(box_facecolors, list):
+ box_facecolors = [box_facecolors] * len(text_ls)
+ assert len(box_facecolors) == len(
+ text_ls
+ ), "Number of colors provided is not equal to the number of text labels."
+
+ assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
+ if not font_size:
+ font_size = self._default_font_size
+
+ x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
+ y = box_coordinate[y_corner].item()
+ for i, text in enumerate(text_ls):
+ self.draw_text(
+ text,
+ (x, y),
+ font_size=font_size,
+ color=color,
+ horizontal_alignment=horizontal_alignment,
+ vertical_alignment="bottom",
+ box_facecolor=box_facecolors[i],
+ alpha=alpha,
+ )
+ y -= font_size + font_size // 2
+
+ def draw_multiple_text_downward(
+ self,
+ text_ls: List[str],
+ box_coordinate: torch.Tensor,
+ *,
+ y_corner: int = 1,
+ font_size: Optional[int] = None,
+ color: str = "w",
+ box_facecolors: str = "black",
+ alpha: float = 0.5,
+ ) -> None:
+ """
+ Draw a list of text labels for some bounding box on the image in downward direction.
+ The next text label will be below the previous one.
+ Args:
+ text_ls (list of strings): a list of text labels.
+ box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
+ coordinates of the box.
+ y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
+ the box to draw labels around.
+ font_size (Optional[int]): font of the text. If not provided, a font size
+ proportional to the image width is calculated and used.
+ color (str): color of the text. Refer to `matplotlib.colors` for full list
+ of formats that are accepted.
+ box_facecolors (str): colors of the box wrapped around the text. Refer to
+ `matplotlib.colors` for full list of formats that are accepted.
+ alpha (float): transparency level of the box.
+ """
+ if not isinstance(box_facecolors, list):
+ box_facecolors = [box_facecolors] * len(text_ls)
+ assert len(box_facecolors) == len(
+ text_ls
+ ), "Number of colors provided is not equal to the number of text labels."
+
+ assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
+ if not font_size:
+ font_size = self._default_font_size
+
+ x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
+ y = box_coordinate[y_corner].item()
+ for i, text in enumerate(text_ls):
+ self.draw_text(
+ text,
+ (x, y),
+ font_size=font_size,
+ color=color,
+ horizontal_alignment=horizontal_alignment,
+ vertical_alignment="top",
+ box_facecolor=box_facecolors[i],
+ alpha=alpha,
+ )
+ y += font_size + font_size // 2
+
+ def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]:
+ """
+ Choose an x-coordinate from the box to make sure the text label
+ does not go out of frames. By default, the left x-coordinate is
+ chosen and text is aligned left. If the box is too close to the
+ right side of the image, then the right x-coordinate is chosen
+ instead and the text is aligned right.
+ Args:
+ box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
+ coordinates of the box.
+ Returns:
+ x_coordinate (float): the chosen x-coordinate.
+ alignment (str): whether to align left or right.
+ """
+ # If the x-coordinate is greater than 5/6 of the image width,
+        # then we align text to the right of the box. This is
+ # chosen by heuristics.
+ if box_coordinate[0] > (self.output.width * 5) // 6:
+ return box_coordinate[2], "right"
+
+ return box_coordinate[0], "left"
+
+ def _align_y_top(
+ self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
+ ) -> int:
+ """
+ Calculate the number of text labels to plot on top of the box
+ without going out of frames.
+ Args:
+ box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
+ coordinates of the box.
+ num_text (int): the number of text labels to plot.
+ textbox_width (float): the width of the box wrapped around text label.
+ """
+ dist_to_top = box_coordinate[1]
+ num_text_top = dist_to_top // textbox_width
+
+ if isinstance(num_text_top, torch.Tensor):
+ num_text_top = int(num_text_top.item())
+
+ return min(num_text, num_text_top)
+
+ def _align_y_bottom(
+ self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
+ ) -> int:
+ """
+ Calculate the number of text labels to plot at the bottom of the box
+ without going out of frames.
+ Args:
+ box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
+ coordinates of the box.
+ num_text (int): the number of text labels to plot.
+ textbox_width (float): the width of the box wrapped around text label.
+ """
+ dist_to_bottom = self.output.height - box_coordinate[3]
+ num_text_bottom = dist_to_bottom // textbox_width
+
+ if isinstance(num_text_bottom, torch.Tensor):
+ num_text_bottom = int(num_text_bottom.item())
+
+ return min(num_text, num_text_bottom)
+
+
+class VideoVisualizer:
+ def __init__(
+ self,
+ num_classes: int,
+ class_names: Dict,
+ top_k: int = 1,
+ colormap: str = "rainbow",
+ thres: float = 0.7,
+ lower_thres: float = 0.3,
+ common_class_names: Optional[List[str]] = None,
+ mode: str = "top-k",
+ ) -> None:
+ """
+ Args:
+ num_classes (int): total number of classes.
+ class_names (dict): Dict mapping classID to name.
+ top_k (int): number of top predicted classes to plot.
+ colormap (str): the colormap to choose color for class labels from.
+ See https://matplotlib.org/tutorials/colors/colormaps.html
+ thres (float): threshold for picking predicted classes to visualize.
+            lower_thres (Optional[float]): If `common_class_names` is given,
+ this `lower_thres` will be applied to uncommon classes and
+ `thres` will be applied to classes in `common_class_names`.
+ common_class_names (Optional[list of str]): list of common class names
+ to apply `thres`. Class names not included in `common_class_names` will
+ have `lower_thres` as a threshold. If None, all classes will have
+ `thres` as a threshold. This is helpful for model trained on
+ highly imbalanced dataset.
+ mode (str): Supported modes are {"top-k", "thres"}.
+ This is used for choosing predictions for visualization.
+
+ """
+ assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode)
+ self.mode = mode
+ self.num_classes = num_classes
+ self.class_names = class_names
+ self.top_k = top_k
+ self.thres = thres
+ self.lower_thres = lower_thres
+
+ if mode == "thres":
+ self._get_thres_array(common_class_names=common_class_names)
+
+ self.color_map = plt.get_cmap(colormap)
+
+ def _get_color(self, class_id: int) -> List[float]:
+ """
+ Get color for a class id.
+ Args:
+ class_id (int): class id.
+ """
+ return self.color_map(class_id / self.num_classes)[:3]
+
+ def draw_one_frame(
+ self,
+ frame: Union[torch.Tensor, np.ndarray],
+ preds: Union[torch.Tensor, List[float]],
+ bboxes: Optional[torch.Tensor] = None,
+ alpha: float = 0.5,
+ text_alpha: float = 0.7,
+ ground_truth: bool = False,
+ ) -> np.ndarray:
+ """
+        Draw labels and bounding boxes for one image. By default, predicted
+ labels are drawn in the top left corner of the image or corresponding
+ bounding boxes. For ground truth labels (setting True for ground_truth flag),
+ labels will be drawn in the bottom left corner.
+ Args:
+ frame (array-like): a tensor or numpy array of shape (H, W, C),
+ where H and W correspond to
+ the height and width of the image respectively. C is the number of
+ color channels. The image is required to be in RGB format since that
+ is a requirement of the Matplotlib library. The image is also expected
+ to be in the range [0, 255].
+ preds (tensor or list): If ground_truth is False, provide a float tensor of
+ shape (num_boxes, num_classes) that contains all of the confidence
+ scores of the model. For recognition task, input shape can be (num_classes,).
+            To plot true labels (ground_truth is True), preds is a list of int32 class ids
+            with shape (num_boxes, true_class_ids) or (true_class_ids,).
+ bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
+ of the bounding boxes.
+ alpha (Optional[float]): transparency level of the bounding boxes.
+ text_alpha (Optional[float]): transparency level of the box wrapped around
+ text labels.
+            ground_truth (bool): whether the provided bounding boxes are ground-truth.
+ Returns:
+ An image with bounding box annotations and corresponding bbox
+ labels plotted on it.
+ """
+ if isinstance(preds, torch.Tensor):
+ if preds.ndim == 1:
+ preds = preds.unsqueeze(0)
+ n_instances = preds.shape[0]
+ elif isinstance(preds, list):
+ n_instances = len(preds)
+ else:
+ logger.error("Unsupported type of prediction input.")
+ return
+
+ if ground_truth:
+ top_scores, top_classes = [None] * n_instances, preds
+
+ elif self.mode == "top-k":
+ top_scores, top_classes = torch.topk(preds, k=self.top_k)
+ top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
+ elif self.mode == "thres":
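+            # In "thres" mode, self.thres is a per-class threshold tensor
+            # built by _get_thres_array() in the constructor.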
+ top_scores, top_classes = [], []
+ for pred in preds:
+ mask = pred >= self.thres
+ top_scores.append(pred[mask].tolist())
+ top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
+ top_classes.append(top_class)
+
+ # Create labels top k predicted classes with their scores.
+ text_labels = []
+ for i in range(n_instances):
+ text_labels.append(
+ _create_text_labels(
+ top_classes[i],
+ top_scores[i],
+ self.class_names,
+ ground_truth=ground_truth,
+ )
+ )
+ frame_visualizer = ImgVisualizer(frame, meta=None)
+ font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9)
+ top_corner = not ground_truth
+ if bboxes is not None:
+ assert len(preds) == len(
+ bboxes
+ ), "Encounter {} predictions and {} bounding boxes".format(
+ len(preds), len(bboxes)
+ )
+ for i, box in enumerate(bboxes):
+ text = text_labels[i]
+ pred_class = top_classes[i]
+ colors = [self._get_color(pred) for pred in pred_class]
+
+ box_color = "r" if ground_truth else "g"
+ line_style = "--" if ground_truth else "-."
+ frame_visualizer.draw_box(
+ box,
+ alpha=alpha,
+ edge_color=box_color,
+ line_style=line_style,
+ )
+ frame_visualizer.draw_multiple_text(
+ text,
+ box,
+ top_corner=top_corner,
+ font_size=font_size,
+ box_facecolors=colors,
+ alpha=text_alpha,
+ )
+ else:
+ text = text_labels[0]
+ pred_class = top_classes[0]
+ colors = [self._get_color(pred) for pred in pred_class]
+ frame_visualizer.draw_multiple_text(
+ text,
+ torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
+ top_corner=top_corner,
+ font_size=font_size,
+ box_facecolors=colors,
+ alpha=text_alpha,
+ )
+
+ return frame_visualizer.output.get_image()
+
+ def draw_clip_range(
+ self,
+ frames: Union[torch.Tensor, np.ndarray],
+ preds: Union[torch.Tensor, List[float]],
+ bboxes: Optional[torch.Tensor] = None,
+ text_alpha: float = 0.5,
+ ground_truth: bool = False,
+ keyframe_idx: Optional[int] = None,
+ draw_range: Optional[List[int]] = None,
+ repeat_frame: int = 1,
+ ) -> List[np.ndarray]:
+ """
+ Draw predicted labels or ground truth classes to clip.
+        Draw bounding boxes to clip if bboxes is provided. Boxes will gradually
+ fade in and out the clip, centered around the clip's central frame,
+ within the provided `draw_range`.
+ Args:
+ frames (array-like): video data in the shape (T, H, W, C).
+ preds (tensor): a tensor of shape (num_boxes, num_classes) that
+ contains all of the confidence scores of the model. For recognition
+ task or for ground_truth labels, input shape can be (num_classes,).
+ bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
+ of the bounding boxes.
+            text_alpha (float): transparency level of the box wrapped around text labels.
+            ground_truth (bool): whether the provided bounding boxes are ground-truth.
+ keyframe_idx (int): the index of keyframe in the clip.
+ draw_range (Optional[list[ints]): only draw frames in range
+ [start_idx, end_idx] inclusively in the clip. If None, draw on
+ the entire clip.
+ repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
+ time for slow-motion effect.
+ Returns:
+ A list of frames with bounding box annotations and corresponding
+            bbox labels plotted on them.
+ """
+ if draw_range is None:
+ draw_range = [0, len(frames) - 1]
+ if draw_range is not None:
+ draw_range[0] = max(0, draw_range[0])
+ left_frames = frames[: draw_range[0]]
+ right_frames = frames[draw_range[1] + 1 :]
+
+ draw_frames = frames[draw_range[0] : draw_range[1] + 1]
+ if keyframe_idx is None:
+ keyframe_idx = len(frames) // 2
+
+ img_ls = (
+ list(left_frames)
+ + self.draw_clip(
+ draw_frames,
+ preds,
+ bboxes=bboxes,
+ text_alpha=text_alpha,
+ ground_truth=ground_truth,
+ keyframe_idx=keyframe_idx - draw_range[0],
+ repeat_frame=repeat_frame,
+ )
+ + list(right_frames)
+ )
+
+ return img_ls
+
+ def draw_clip(
+ self,
+ frames: Union[torch.Tensor, np.ndarray],
+ preds: Union[torch.Tensor, List[float]],
+ bboxes: Optional[torch.Tensor] = None,
+ text_alpha: float = 0.5,
+ ground_truth: bool = False,
+ keyframe_idx: Optional[int] = None,
+ repeat_frame: int = 1,
+ ) -> List[np.ndarray]:
+ """
+        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
+ if bboxes is provided. Boxes will gradually fade in and out the clip, centered
+ around the clip's central frame.
+ Args:
+ frames (array-like): video data in the shape (T, H, W, C).
+ preds (tensor): a tensor of shape (num_boxes, num_classes) that contains
+ all of the confidence scores of the model. For recognition task or for
+ ground_truth labels, input shape can be (num_classes,).
+ bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
+ of the bounding boxes.
+            text_alpha (float): transparency level of the box wrapped around text labels.
+            ground_truth (bool): whether the provided bounding boxes are ground-truth.
+ keyframe_idx (int): the index of keyframe in the clip.
+ repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
+ time for slow-motion effect.
+ Returns:
+ A list of frames with bounding box annotations and corresponding
+ bbox labels plotted on them.
+ """
+ assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."
+
+ repeated_seq = range(0, len(frames))
+ repeated_seq = list(
+ itertools.chain.from_iterable(
+ itertools.repeat(x, repeat_frame) for x in repeated_seq
+ )
+ )
+
+ frames, adjusted = self._adjust_frames_type(frames)
+ if keyframe_idx is None:
+ half_left = len(repeated_seq) // 2
+ half_right = (len(repeated_seq) + 1) // 2
+ else:
+ mid = int((keyframe_idx / len(frames)) * len(repeated_seq))
+ half_left = mid
+ half_right = len(repeated_seq) - mid
+
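+        # Box opacity ramps up linearly until the keyframe position and then back down,
+        # so the drawn boxes fade in and out around the keyframe.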
+ alpha_ls = np.concatenate(
+ [
+ np.linspace(0, 1, num=half_left),
+ np.linspace(1, 0, num=half_right),
+ ]
+ )
+ text_alpha = text_alpha
+ frames = frames[repeated_seq]
+ img_ls = []
+ for alpha, frame in zip(alpha_ls, frames):
+ draw_img = self.draw_one_frame(
+ frame,
+ preds,
+ bboxes,
+ alpha=alpha,
+ text_alpha=text_alpha,
+ ground_truth=ground_truth,
+ )
+ if adjusted:
+ draw_img = draw_img.astype("float32") / 255
+
+ img_ls.append(draw_img)
+
+ return img_ls
+
+ def _adjust_frames_type(
+ self, frames: torch.Tensor
+ ) -> Tuple[List[np.ndarray], bool]:
+ """
+ Modify video data to have dtype of uint8 and values range in [0, 255].
+ Args:
+ frames (array-like): 4D array of shape (T, H, W, C).
+ Returns:
+            frames (ndarray): frames as a uint8 array with values in [0, 255].
+            adjusted (bool): whether the original frames needed to be adjusted.
+ """
+ assert (
+ frames is not None and len(frames) != 0
+ ), "Frames does not contain any values"
+ frames = np.array(frames)
+ assert np.array(frames).ndim == 4, "Frames must have 4 dimensions"
+ adjusted = False
+ if frames.dtype in [np.float32, np.float64]:
+ frames *= 255
+ frames = frames.astype(np.uint8)
+ adjusted = True
+
+ return frames, adjusted
+
+ def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
+ """
+        Compute a thresholds array for all classes based on `self.thres` and `self.lower_thres`.
+ Args:
+ common_class_names (Optional[list of str]): a list of common class names.
+ """
+ common_class_ids = []
+ if common_class_names is not None:
+ common_classes = set(common_class_names)
+
+ for key, name in self.class_names.items():
+ if name in common_classes:
+ common_class_ids.append(key)
+ else:
+ common_class_ids = list(range(self.num_classes))
+
+ thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
+ thres_array[common_class_ids] = self.thres
+ self.thres = torch.from_numpy(thres_array)
diff --git a/code/pytorchvideo/website/.dockerignore b/code/pytorchvideo/website/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..27d2dae2b493488b48bdb18b95af471821ece9bf
--- /dev/null
+++ b/code/pytorchvideo/website/.dockerignore
@@ -0,0 +1,2 @@
+*/node_modules
+*.log
diff --git a/code/pytorchvideo/website/.gitignore b/code/pytorchvideo/website/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5395ea795d62b04c6645f5746b9cc667d26a07bb
--- /dev/null
+++ b/code/pytorchvideo/website/.gitignore
@@ -0,0 +1,12 @@
+.DS_Store
+
+node_modules
+
+lib/core/metadata.js
+lib/core/MetadataBlog.js
+
+website/translated_docs
+website/build/
+website/yarn.lock
+website/node_modules
+website/i18n/*
diff --git a/code/pytorchvideo/website/docs/tutorial_accelerator_build_your_model.md b/code/pytorchvideo/website/docs/tutorial_accelerator_build_your_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..b62cba0c5ee1ec6d1f88d5490532f411d8883bbd
--- /dev/null
+++ b/code/pytorchvideo/website/docs/tutorial_accelerator_build_your_model.md
@@ -0,0 +1,440 @@
+---
+id: tutorial_accelerator_build_your_model
+title: Build your efficient model with PytorchVideo/Accelerator
+---
+
+
+## Introduction
+In this tutorial, we will go through:
+- Basics of efficient blocks in PytorchVideo/Accelerator;
+- Design, train and deploy a model composed of efficient blocks for mobile CPU.
+
+## Basics of efficient blocks in PytorchVideo/Accelerator
+Efficient blocks are building blocks with high efficiency. For a target device, we benchmark the efficiency of basic network components and provide a collection of efficient blocks under `pytorchvideo/layers/accelerator/` (for simple layers) and `pytorchvideo/models/accelerator/` (for complex modules such as residual blocks). Inference with a model built from the corresponding efficient blocks is guaranteed to be efficient on the target device.
+
+Each efficient block module is an instance of nn.Module and has two forms: **original form** (for training) and **deploy form** (for inference). In original form, an efficient block module behaves exactly like the corresponding vanilla nn.Module for both the forward and backward pass. Users can freely mix and match efficient blocks for the same target device to build up their own models. Once the model is built and trained, each efficient block in the model can be converted into deploy form. The conversion performs graph and kernel optimization on each efficient block; an efficient block in deploy form is arithmetically equivalent to its original form but is much more efficient during inference.
+
+## Design, train and deploy a model composed of efficient blocks for mobile CPU
+### Build a model
+In this section, let's go through the process of designing, training and deploying an example toy model built from efficient blocks under `pytorchvideo/layers/accelerator/mobile_cpu` and `pytorchvideo/models/accelerator/mobile_cpu`. The model includes:
+- One conv3d head layer with 5x1x1 kernel followed by ReLU activation;
+- One residual block with squeeze-excite;
+- One average pool and fully connected layer as final output.
+
+First, let's import efficient blocks.
+
+
+```python
+# Imports
+import torch.nn as nn
+from pytorchvideo.layers.accelerator.mobile_cpu.activation_functions import (
+ supported_act_functions,
+)
+from pytorchvideo.layers.accelerator.mobile_cpu.convolutions import (
+ Conv3d5x1x1BnAct,
+)
+from pytorchvideo.models.accelerator.mobile_cpu.residual_blocks import (
+ X3dBottleneckBlock,
+)
+from pytorchvideo.layers.accelerator.mobile_cpu.pool import AdaptiveAvgPool3dOutSize1
+from pytorchvideo.layers.accelerator.mobile_cpu.fully_connected import FullyConnected
+
+```
+
+Then we can build a model using those efficient blocks.
+
+
+```python
+class MyNet(nn.Module):
+ def __init__(
+ self,
+ in_channel=3, # input channel of first 5x1x1 layer
+ residual_block_channel=24, # input channel of residual block
+ expansion_ratio=3, # expansion ratio of residual block
+ num_classes=4, # final output classes
+ ):
+ super().__init__()
+ # s1 - 5x1x1 conv3d layer
+ self.s1 = Conv3d5x1x1BnAct(
+ in_channel,
+ residual_block_channel,
+ bias=False,
+ groups=1,
+ use_bn=False,
+ )
+ # s2 - residual block
+ mid_channel = int(residual_block_channel * expansion_ratio)
+ self.s2 = X3dBottleneckBlock(
+ in_channels=residual_block_channel,
+ mid_channels=mid_channel,
+ out_channels=residual_block_channel,
+ use_residual=True,
+ spatial_stride=1,
+ se_ratio=0.0625,
+ act_functions=("relu", "swish", "relu"),
+ use_bn=(True, True, True),
+ )
+ # Average pool and fully connected layer
+ self.avg_pool = AdaptiveAvgPool3dOutSize1()
+ self.projection = FullyConnected(residual_block_channel, num_classes, bias=True)
+ self.act = supported_act_functions['relu']()
+
+ def forward(self, x):
+ x = self.s1(x)
+ x = self.s2(x)
+ x = self.avg_pool(x)
+ # (N, C, T, H, W) -> (N, T, H, W, C).
+ x = x.permute((0, 2, 3, 4, 1))
+ x = self.projection(x)
+ # Performs fully convolutional inference.
+ if not self.training:
+ x = self.act(x)
+ x = x.mean([1, 2, 3])
+ x = x.view(x.shape[0], -1)
+
+ return x
+```
+
+We can instantiate MyNet and its efficient blocks will be in original form.
+
+
+```python
+net_inst = MyNet()
+print(net_inst)
+```
+
+ MyNet(
+ (s1): Conv3d5x1x1BnAct(
+ (kernel): Sequential(
+ (conv): Conv3d(3, 24, kernel_size=(5, 1, 1), stride=(1, 1, 1), padding=(2, 0, 0), bias=False)
+ (act): ReLU(
+ (act): ReLU(inplace=True)
+ )
+ )
+ )
+ (s2): X3dBottleneckBlock(
+ (_residual_add_func): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (final_act): ReLU(
+ (act): ReLU(inplace=True)
+ )
+ (layers): Sequential(
+ (conv_0): Conv3dPwBnAct(
+ (kernel): Sequential(
+ (conv): Conv3d(24, 72, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
+ (bn): BatchNorm3d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (act): ReLU(
+ (act): ReLU(inplace=True)
+ )
+ )
+ )
+ (conv_1): Conv3d3x3x3DwBnAct(
+ (kernel): Sequential(
+ (conv): Conv3d(72, 72, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=72, bias=False)
+ (bn): BatchNorm3d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (act): Identity(
+ (act): Identity()
+ )
+ )
+ )
+ (se): SqueezeExcitation(
+ (se): SqueezeExcitation(
+ (block): Sequential(
+ (0): Conv3d(72, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1))
+ (1): ReLU()
+ (2): Conv3d(8, 72, kernel_size=(1, 1, 1), stride=(1, 1, 1))
+ (3): Sigmoid()
+ )
+ )
+ )
+ (act_func_1): Swish(
+ (act): Swish()
+ )
+ (conv_2): Conv3dPwBnAct(
+ (kernel): Sequential(
+ (conv): Conv3d(72, 24, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
+ (bn): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+ (act): Identity(
+ (act): Identity()
+ )
+ )
+ )
+ )
+ )
+ (avg_pool): AdaptiveAvgPool3dOutSize1(
+ (pool): AdaptiveAvgPool3d(output_size=1)
+ )
+ (projection): FullyConnected(
+ (model): Linear(in_features=24, out_features=4, bias=True)
+ )
+ (act): ReLU(
+ (act): ReLU(inplace=True)
+ )
+ )
+
+
+### Train model
+We can now train the model with your dataset and optimizer. Here we skip the training step and just leave the weights at their initial values.
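+
+For reference, a minimal training loop on random placeholder data might look like the sketch below. The `dummy_clips`/`dummy_labels` tensors and the SGD settings are illustrative assumptions, not an actual training recipe.
+
+
+```python
+import torch
+import torch.nn as nn
+
+optimizer = torch.optim.SGD(net_inst.parameters(), lr=0.01, momentum=0.9)
+criterion = nn.CrossEntropyLoss()
+
+net_inst.train()
+for step in range(10):
+    # Placeholder batch: (N, C, T, H, W) clips and integer class labels in [0, num_classes).
+    dummy_clips = torch.randn(2, 3, 4, 6, 6)
+    dummy_labels = torch.randint(0, 4, (2,))
+
+    optimizer.zero_grad()
+    loss = criterion(net_inst(dummy_clips), dummy_labels)
+    loss.backward()
+    optimizer.step()
+```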
+
+### Deploy model
+Now the model is ready to deploy. First of all, let's convert the model into deploy form. In order to do that, we need to use the `convert_to_deployable_form` utility and provide an example input tensor to the model. Note that once the model is converted into deploy form, the input size should be the same as the example input tensor size during conversion.
+
+
+```python
+import torch
+from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (
+ convert_to_deployable_form,
+)
+input_blob_size = (1, 3, 4, 6, 6)
+input_tensor = torch.randn(input_blob_size)
+net_inst_deploy = convert_to_deployable_form(net_inst, input_tensor)
+
+```
+
+We can see that the network graph has changed after conversion, which performed kernel and graph optimization.
+
+
+```python
+print(net_inst_deploy)
+```
+
+ MyNet(
+ (s1): Conv3d5x1x1BnAct(
+ (kernel): Sequential(
+ (conv): _Conv3dTemporalKernel5Decomposed(
+ (_conv2d_0): Conv2d(3, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
+ (_conv2d_1): Conv2d(3, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
+ (_conv2d_2): Conv2d(3, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
+ (_conv2d_3): Conv2d(3, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
+ (_conv2d_4): Conv2d(3, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
+ (_add_funcs): ModuleList(
+ (0): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (1): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (2): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (3): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (4): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (5): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (6): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (7): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (8): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (9): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ )
+ (_cat_func): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ )
+ (act): ReLU(
+ (act): ReLU(inplace=True)
+ )
+ )
+ )
+ (s2): X3dBottleneckBlock(
+ (_residual_add_func): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (final_act): ReLU(
+ (act): ReLU(inplace=True)
+ )
+ (layers): Sequential(
+ (conv_0): Conv3dPwBnAct(
+ (kernel): Sequential(
+ (0): _Reshape()
+ (1): Sequential(
+ (conv): ConvReLU2d(
+ (0): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1))
+ (1): ReLU(inplace=True)
+ )
+ (bn): Identity()
+ (act): ReLU(
+ (act): Identity()
+ )
+ )
+ (2): _Reshape()
+ )
+ )
+ (conv_1): Conv3d3x3x3DwBnAct(
+ (kernel): Sequential(
+ (conv): _Conv3dTemporalKernel3Decomposed(
+ (_conv2d_3_3_0): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)
+ (_conv2d_3_3_2): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72, bias=False)
+ (_conv2d_3_3_1): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72)
+ (_add_funcs): ModuleList(
+ (0): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (1): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (2): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (3): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (4): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ (5): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ )
+ (_cat_func): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ )
+ (bn): Identity()
+ (act): Identity(
+ (act): Identity()
+ )
+ )
+ )
+ (se): SqueezeExcitation(
+ (se): _SkipConnectMul(
+ (layer): Sequential(
+ (0): AdaptiveAvgPool3d(output_size=1)
+ (1): _Reshape()
+ (2): Linear(in_features=72, out_features=8, bias=True)
+ (3): ReLU()
+ (4): Linear(in_features=8, out_features=72, bias=True)
+ (5): Sigmoid()
+ (6): _Reshape()
+ )
+ (mul_func): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ )
+ )
+ (act_func_1): Swish(
+ (act): _NaiveSwish(
+ (mul_func): FloatFunctional(
+ (activation_post_process): Identity()
+ )
+ )
+ )
+ (conv_2): Conv3dPwBnAct(
+ (kernel): Sequential(
+ (0): _Reshape()
+ (1): Sequential(
+ (conv): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1))
+ (bn): Identity()
+ (act): Identity(
+ (act): Identity()
+ )
+ )
+ (2): _Reshape()
+ )
+ )
+ )
+ )
+ (avg_pool): AdaptiveAvgPool3dOutSize1(
+ (pool): AvgPool3d(kernel_size=(4, 6, 6), stride=(4, 6, 6), padding=0)
+ )
+ (projection): FullyConnected(
+ (model): Linear(in_features=24, out_features=4, bias=True)
+ )
+ (act): ReLU(
+ (act): ReLU(inplace=True)
+ )
+ )
+
+
+Let's check whether the network after conversion is arithmetically equivalent. We expect the output to be very close before/after conversion, with some small difference due to numeric noise from floating point operations.
+
+
+```python
+net_inst.eval()
+out_ref = net_inst(input_tensor)
+out = net_inst_deploy(input_tensor)
+
+max_err = float(torch.max(torch.abs(out_ref - out)))
+print(f"max error is {max_err}")
+```
+
+ max error is 2.9802322387695312e-08
+
+
+Next we have two options: either deploy the floating point model, or quantize the model into int8 and then deploy.
+
+Let's first assume we want to deploy the floating point model. In this case, all we need to do is export a JIT trace and then apply `optimize_for_mobile` for the final optimization.
+
+
+```python
+from torch.utils.mobile_optimizer import (
+ optimize_for_mobile,
+)
+traced_model = torch.jit.trace(net_inst_deploy, input_tensor, strict=False)
+traced_model_opt = optimize_for_mobile(traced_model)
+# Here we can save the traced_model_opt to JIT file using traced_model_opt.save()
+```
+
+Alternatively, we may also want to deploy a quantized model. Efficient blocks are quantization-friendly by design - just wrap the model in deploy form with `QuantStub/DeQuantStub` and it is ready for Pytorch eager mode quantization.
+
+
+```python
+# Wrapper class for adding QuantStub/DeQuantStub.
+class quant_stub_wrapper(nn.Module):
+ def __init__(self, module_in):
+ super().__init__()
+ self.quant = torch.quantization.QuantStub()
+ self.model = module_in
+ self.dequant = torch.quantization.DeQuantStub()
+ def forward(self, x):
+ x = self.quant(x)
+ x = self.model(x)
+ x = self.dequant(x)
+ return x
+```
+
+
+```python
+net_inst_quant_stub_wrapper = quant_stub_wrapper(net_inst_deploy)
+```
+
+Preparation step of quantization. Fusion has already been done for efficient blocks automatically during `convert_to_deployable_form`, so we can just proceed to `torch.quantization.prepare`. In this tutorial we assume the model is quantized for mobile devices (e.g., using `qnnpack` as backend); please switch to the proper backend and set the corresponding quantization `qconfig` for your case.
+
+
+```python
+net_inst_quant_stub_wrapper.qconfig = torch.quantization.get_default_qconfig("qnnpack")
+torch.backends.quantized.engine = "qnnpack"
+net_inst_quant_stub_wrapper_prepared = torch.quantization.prepare(net_inst_quant_stub_wrapper)
+```
+
+Calibration and quantization. After preparation, we calibrate the prepared model by feeding it a calibration dataset (skipped here) and then convert it to a quantized model.
+
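+For completeness, a calibration pass could look like the sketch below; the random `calibration_clips` are stand-ins for real calibration data, and the next cell simply skips this step.
+
+
+```python
+# Hypothetical calibration pass: run a few representative clips through the prepared
+# model so that the observers can record activation ranges before conversion.
+calibration_clips = [torch.randn(input_blob_size) for _ in range(8)]
+
+net_inst_quant_stub_wrapper_prepared.eval()
+with torch.no_grad():
+    for clip in calibration_clips:
+        net_inst_quant_stub_wrapper_prepared(clip)
+```
+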
+
+```python
+# calibration is skipped here.
+net_inst_quant_stub_wrapper_quantized = torch.quantization.convert(net_inst_quant_stub_wrapper_prepared)
+```
+
+
+Then we can export trace of int8 model and deploy on mobile devices.
+
+
+```python
+traced_model_int8 = torch.jit.trace(net_inst_quant_stub_wrapper_quantized, input_tensor, strict=False)
+traced_model_int8_opt = optimize_for_mobile(traced_model_int8)
+# Here we can save the traced_model_opt to JIT file using traced_model_int8_opt.save()
+```
+
diff --git a/code/pytorchvideo/website/docs/tutorial_accelerator_use_accelerator_model_zoo.md b/code/pytorchvideo/website/docs/tutorial_accelerator_use_accelerator_model_zoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..9df08ca8b3fa19384fdafb6a6364d327de4cc493
--- /dev/null
+++ b/code/pytorchvideo/website/docs/tutorial_accelerator_use_accelerator_model_zoo.md
@@ -0,0 +1,118 @@
+---
+id: tutorial_accelerator_use_accelerator_model_zoo
+title: Use PytorchVideo/Accelerator Model Zoo
+---
+
+
+## Introduction
+This tutorial goes through how to use the model zoo provided by PytorchVideo/Accelerator. To use the model zoo, we generally follow several steps:
+- Use model builder to build selected model;
+- Load pretrain checkpoint;
+- (Optional) Finetune;
+- Deploy.
+
+## Use model builder to build selected model
+We use the model builder in the PytorchVideo/Accelerator model zoo to build a pre-defined efficient model. Here we use EfficientX3D-XS (for mobile_cpu) as an example. For more available models and details, please refer to [this page].
+
+EfficientX3D-XS is an implementation of the X3D-XS network as described in the [X3D paper](https://arxiv.org/abs/2004.04730) using efficient blocks. It is arithmetically equivalent to X3D-XS, but our benchmark on a mobile phone shows a 4.6X latency reduction compared with the vanilla implementation.
+
+In order to build EfficientX3D-XS, we simply do the following:
+
+
+```python
+from pytorchvideo.models.accelerator.mobile_cpu.efficient_x3d import EfficientX3d
+model_efficient_x3d_xs = EfficientX3d(expansion='XS', head_act='identity')
+```
+
+Note that now the efficient blocks in the model are in original form, so the model is good for further training.
+
+## Load pretrain checkpoint and (optional) finetune
+For each model in the model zoo, we provide a pretrained checkpoint state_dict for the model in original form. See [this page] for details about the checkpoints and where to download them.
+
+
+```python
+from torch.hub import load_state_dict_from_url
+checkpoint_path = 'https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/efficient_x3d_xs_original_form.pyth'
+checkpoint = load_state_dict_from_url(checkpoint_path)
+
+model_efficient_x3d_xs.load_state_dict(checkpoint)
+```
+
+Now the model is ready for fine-tuning.
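+
+As a rough sketch (not an official recipe), a single fine-tuning step on placeholder data could look like the following; the random tensors stand in for real Kinetics clips and labels, and we assume the model returns per-class logits of shape `(N, 400)`.
+
+
+```python
+import torch
+import torch.nn as nn
+
+optimizer = torch.optim.SGD(model_efficient_x3d_xs.parameters(), lr=0.001, momentum=0.9)
+criterion = nn.CrossEntropyLoss()
+
+model_efficient_x3d_xs.train()
+# Placeholder batch: two (C, T, H, W) = (3, 4, 160, 160) clips with random Kinetics labels.
+clips = torch.randn(2, 3, 4, 160, 160)
+labels = torch.randint(0, 400, (2,))
+
+optimizer.zero_grad()
+loss = criterion(model_efficient_x3d_xs(clips), labels)
+loss.backward()
+optimizer.step()
+```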
+
+## Deploy
+Now the model is ready to deploy. First of all, let's convert the model into deploy form. In order to do that, we need to use the `convert_to_deployable_form` utility and provide an example input tensor to the model. Note that once the model is converted into deploy form, the input size should be the same as the example input tensor size during conversion.
+
+
+```python
+import torch
+from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (
+ convert_to_deployable_form,
+)
+input_blob_size = (1, 3, 4, 160, 160)
+input_tensor = torch.randn(input_blob_size)
+model_efficient_x3d_xs_deploy = convert_to_deployable_form(model_efficient_x3d_xs, input_tensor)
+```
+
+Next we have two options: either deploy the floating point model, or quantize the model into int8 and then deploy.
+
+Let's first assume we want to deploy the floating point model. In this case, all we need to do is export a JIT trace and then apply `optimize_for_mobile` for the final optimization.
+
+
+```python
+from torch.utils.mobile_optimizer import (
+ optimize_for_mobile,
+)
+traced_model = torch.jit.trace(model_efficient_x3d_xs_deploy, input_tensor, strict=False)
+traced_model_opt = optimize_for_mobile(traced_model)
+# Here we can save the traced_model_opt to JIT file using traced_model_opt.save()
+```
+
+Alternatively, we may also want to deploy a quantized model. Efficient blocks are quantization-friendly by design - just wrap the model in deploy form with `QuantStub/DeQuantStub` and it is ready for Pytorch eager mode quantization.
+
+
+```python
+# Wrapper class for adding QuantStub/DeQuantStub.
+class quant_stub_wrapper(nn.Module):
+ def __init__(self, module_in):
+ super().__init__()
+ self.quant = torch.quantization.QuantStub()
+ self.model = module_in
+ self.dequant = torch.quantization.DeQuantStub()
+ def forward(self, x):
+ x = self.quant(x)
+ x = self.model(x)
+ x = self.dequant(x)
+ return x
+```
+
+
+```python
+model_efficient_x3d_xs_deploy_quant_stub_wrapper = quant_stub_wrapper(model_efficient_x3d_xs_deploy)
+```
+
+Preparation step of quantization. Fusion has been done for efficient blocks automatically during `convert_to_deployable_form`, so we can just proceed to `torch.quantization.prepare`.
+
+
+```python
+model_efficient_x3d_xs_deploy_quant_stub_wrapper.qconfig = torch.quantization.default_qconfig
+model_efficient_x3d_xs_deploy_quant_stub_wrapper_prepared = torch.quantization.prepare(model_efficient_x3d_xs_deploy_quant_stub_wrapper)
+```
+
+Calibration and quantization. After preparation, we calibrate the prepared model by feeding it a calibration dataset (skipped here) and then convert it to a quantized model.
+
+
+```python
+# calibration is skipped here.
+model_efficient_x3d_xs_deploy_quant_stub_wrapper_quantized = torch.quantization.convert(model_efficient_x3d_xs_deploy_quant_stub_wrapper_prepared)
+```
+
+Then we can export trace of int8 model and deploy on mobile devices.
+
+
+```python
+traced_model_int8 = torch.jit.trace(model_efficient_x3d_xs_deploy_quant_stub_wrapper_quantized, input_tensor, strict=False)
+traced_model_int8_opt = optimize_for_mobile(traced_model_int8)
+# Here we can save the traced_model_opt to JIT file using traced_model_int8_opt.save()
+```
+
diff --git a/code/pytorchvideo/website/docs/tutorial_accelerator_use_model_transmuter.md b/code/pytorchvideo/website/docs/tutorial_accelerator_use_model_transmuter.md
new file mode 100644
index 0000000000000000000000000000000000000000..45e74d12bb06a6c36e665d06f7e1c0922f3a0bfe
--- /dev/null
+++ b/code/pytorchvideo/website/docs/tutorial_accelerator_use_model_transmuter.md
@@ -0,0 +1,98 @@
+---
+id: tutorial_accelerator_use_model_transmuter
+title: Accelerate your model with model transmuter in PytorchVideo/Accelerator
+---
+
+
+## Introduction
+Got your own model, but still want to fully leverage efficient blocks in PytorchVideo/Accelerator? No problem, model transmuter can help you.
+Model transmuter is a utility in PytorchVideo/Accelerator that takes a user-defined model and replaces modules in it with equivalent efficient blocks when possible.
+In this tutorial, we will go through typical steps of using model transmuter, including:
+- Use model transmuter to replace modules in user model with efficient blocks
+- Convert model into deploy form and deploy
+
+## Use model transmuter to replace modules in user model with efficient blocks
+First, let's assume the user has the following model to be transmuted:
+
+
+```python
+import torch
+import torch.nn as nn
+
+class user_model_residual_block(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.stem0 = nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
+ self.stem1 = nn.Conv3d(3, 3, kernel_size=(5, 1, 1), padding=(2, 0, 0))
+ self.pw = nn.Conv3d(3, 6, kernel_size=1)
+ self.relu = nn.ReLU()
+ self.dw = nn.Conv3d(6, 6, kernel_size=3, padding=1, groups=6)
+ self.relu1 = nn.ReLU()
+ self.pwl = nn.Conv3d(6, 3, kernel_size=1)
+ self.relu2 = nn.ReLU()
+
+ def forward(self, x):
+ out = self.stem0(x)
+ out = self.stem1(out)
+ out = self.pw(out)
+ out = self.relu(out)
+ out = self.dw(out)
+ out = self.relu1(out)
+ out = self.pwl(out)
+ return self.relu2(out + x)
+```
+
+Then, let's use the model transmuter by importing the transmuter for the target device. In this tutorial, we use mobile CPU as the example. Therefore we will import (1) the model transmuter for mobile CPU and (2) the top-level wrapper of the model transmuter.
+
+
+```python
+import pytorchvideo.accelerator.deployment.mobile_cpu.transmuter # mobile cpu model transmuter
+from pytorchvideo.accelerator.deployment.common.model_transmuter import transmute_model # top-level wrapper of model transmuter
+```
+
+We instantiate one user_model_residual_block, and transmute it by calling `transmute_model` with the argument `target_device="mobile_cpu"`
+
+
+```python
+model_transmute = user_model_residual_block()
+transmute_model(
+ model_transmute,
+ target_device="mobile_cpu",
+)
+```
+
+If we print the model, we will find that some of the modules in the model have been replaced. In general, the model transmuter replaces a submodule if an equivalent efficient block is found; otherwise that submodule is kept intact.
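+
+For instance, printing the transmuted model makes the substitutions easy to spot:
+
+
+```python
+# Submodules with an efficient-block equivalent have been replaced;
+# the remaining submodules are kept intact.
+print(model_transmute)
+```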
+
+
+## Convert model into deploy form and deploy
+Now the model is ready to deploy. First of all, let's convert the model into deploy form. In order to do that, we need to use the `convert_to_deployable_form` utility and provide an example input tensor to the model. `convert_to_deployable_form` will convert any instance of `EfficientBlockBase` (the base class for efficient blocks in PytorchVideo/Accelerator) into deploy form, while leaving other modules unchanged.
+Note that once the model is converted into deploy form, the input size should be the same as the example input tensor size during conversion.
+
+
+```python
+# Define example input tensor
+input_blob_size = (1, 3, 4, 6, 6)
+input_tensor = torch.randn(input_blob_size)
+```
+
+
+```python
+from pytorchvideo.accelerator.deployment.mobile_cpu.utils.model_conversion import (
+ convert_to_deployable_form,
+)
+model_transmute_deploy = convert_to_deployable_form(
+ model_transmute, input_tensor
+)
+```
+
+Currently the model transmuter only supports fp32 operation; int8 support will come with the upcoming torch.fx quantization mode. In this tutorial, we assume the transmuted model is deployed without quantization. In this case, all we need to do is export a JIT trace and then apply `optimize_for_mobile` for the final optimization.
+
+
+```python
+from torch.utils.mobile_optimizer import (
+ optimize_for_mobile,
+)
+traced_model = torch.jit.trace(model_transmute_deploy, input_tensor, strict=False)
+traced_model_opt = optimize_for_mobile(traced_model)
+# Here we can save the traced_model_opt to JIT file using traced_model_opt.save()
+```
diff --git a/code/pytorchvideo/website/docs/tutorial_classification.md b/code/pytorchvideo/website/docs/tutorial_classification.md
new file mode 100644
index 0000000000000000000000000000000000000000..881d67208a6b3a58b762b6236cc535f296ae9021
--- /dev/null
+++ b/code/pytorchvideo/website/docs/tutorial_classification.md
@@ -0,0 +1,227 @@
+---
+id: tutorial_classification
+title: Training a PyTorchVideo classification model
+---
+
+# Introduction
+
+In this tutorial we will show how to build a simple video classification training pipeline using PyTorchVideo models, datasets and transforms. We'll be using a 3D ResNet [1] for the model, Kinetics [2] for the dataset and a standard video transform augmentation recipe. As PyTorchVideo doesn't contain training code, we'll use [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) - a lightweight PyTorch training framework - to help out. Don't worry if you don't have Lightning experience, we'll explain what's needed as we go along.
+
+[1] He, Kaiming, et al. Deep Residual Learning for Image Recognition. ArXiv:1512.03385, 2015.
+
+[2] W. Kay, et al. The kinetics human action video dataset. arXiv preprint arXiv:1705.06950, 2017.
+
+# Dataset
+
+To start off with, let's prepare the data and set up the PyTorchVideo Kinetics data loader using a [``pytorch_lightning.LightningDataModule``](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.datamodule.html#pytorch_lightning.core.datamodule.LightningDataModule). A ``LightningDataModule`` is a wrapper that defines the train, val and test data partitions; we'll use it to wrap the PyTorchVideo Kinetics dataset below.
+
+To prepare the Kinetics dataset, you'll need the list of videos found on the Kinetics website [here](https://deepmind.com/research/open-source/kinetics) (any of the Kinetics versions will work). You'll then need the official [download script](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics) to download the videos. Once downloaded, point the ``pytorchvideo.data.Kinetics`` ``data_path`` arg to the folder of classes (each class folder contains the videos) and the data loader will work. Note that for our model-zoo, we also downsample the Kinetics videos to 256 on the short size to speed up training, see more details in the [data preparation docs](https://pytorchvideo.readthedocs.io/en/latest/data_preparation.html).
+
+The PyTorchVideo Kinetics dataset is just an alias for the general [``pytorchvideo.data.LabeledVideoDataset``](http://pytorchvideo.org/api/data/encoded_video.html#pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset) class. If you look at its constructor, you'll notice that most args are what you'd expect (e.g. path to data). However, there are a few args that are more specific to PyTorchVideo datasets:
+- video_sampler - defining the order to sample a video at each iteration. The default is "random".
+- clip_sampler - defining how to sample a clip from the chosen video at each iteration. For a train partition it is typical to use a "random" clip sampler (i.e. take a random clip of the specified duration from the video). For testing, typically you'll use "uniform" (i.e. uniformly sample all clips of the specified duration from the video) to ensure the entire video is sampled in each epoch.
+- transform - this provides a way to apply user defined data preprocessing or augmentation before batch collating by the PyTorch data loader. We'll show an example using this later.
+
+
+```python
+import os
+import pytorch_lightning
+import pytorchvideo.data
+import torch.utils.data
+
+class KineticsDataModule(pytorch_lightning.LightningDataModule):
+
+ # Dataset configuration
+  _DATA_PATH = "<path_to_kinetics_data_dir>"  # Placeholder: set to your Kinetics root folder
+ _CLIP_DURATION = 2 # Duration of sampled clip for each video
+ _BATCH_SIZE = 8
+ _NUM_WORKERS = 8 # Number of parallel processes fetching data
+
+ def train_dataloader(self):
+ """
+ Create the Kinetics train partition from the list of video labels
+ in {self._DATA_PATH}/train
+ """
+ train_dataset = pytorchvideo.data.Kinetics(
+ data_path=os.path.join(self._DATA_PATH, "train"),
+ clip_sampler=pytorchvideo.data.make_clip_sampler("random", self._CLIP_DURATION),
+ decode_audio=False,
+ )
+ return torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=self._BATCH_SIZE,
+ num_workers=self._NUM_WORKERS,
+ )
+
+ def val_dataloader(self):
+ """
+ Create the Kinetics validation partition from the list of video labels
+ in {self._DATA_PATH}/val
+ """
+ val_dataset = pytorchvideo.data.Kinetics(
+ data_path=os.path.join(self._DATA_PATH, "val"),
+ clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", self._CLIP_DURATION),
+ decode_audio=False,
+ )
+ return torch.utils.data.DataLoader(
+ val_dataset,
+ batch_size=self._BATCH_SIZE,
+ num_workers=self._NUM_WORKERS,
+ )
+```
+
+# Transforms
+
+As mentioned above, PyTorchVideo datasets take a "transform" callable arg that defines custom processing (e.g. augmentations, normalization) that's applied to each clip. The callable arg takes a clip dictionary defining the different modalities and metadata. ``pytorchvideo.data.Kinetics`` clips have the following dictionary format:
+
+```python
+{
+    'video': <video_tensor>,       # Shape: (C, T, H, W)
+    'audio': <audio_tensor>,       # Shape: (S)
+    'label': <action_label>,       # Integer defining class annotation
+    'video_name': <video_path>,    # Video file path stem
+    'video_index': <video_index>,  # Index of the video used by the sampler
+    'clip_index': <clip_index>     # Index of the clip sampled within the video
+}
+```
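+
+A transform is simply a callable that takes this dictionary and returns a (possibly modified) dictionary. As a minimal hand-written sketch (the 8-frame subsample count is an arbitrary choice for illustration):
+
+```python
+import torch
+
+def my_transform(clip_dict):
+    # Drop a modality we don't need for this example.
+    clip_dict.pop("audio", None)
+    # Keep 8 evenly spaced frames from the (C, T, H, W) video tensor.
+    video = clip_dict["video"]
+    indices = torch.linspace(0, video.shape[1] - 1, 8).long()
+    clip_dict["video"] = torch.index_select(video, 1, indices)
+    return clip_dict
+```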
+
+PyTorchVideo provides several transforms, which you can see in the [docs](https://pytorchvideo.readthedocs.io/en/latest/transforms.html). Notably, PyTorchVideo provides dictionary transforms that make it easy to interoperate with other domain-specific libraries. For example, [``pytorchvideo.transforms.ApplyTransformToKey(key, transform)``](https://pytorchvideo.readthedocs.io/en/latest/api/transforms/transforms.html) can be used to apply a domain-specific transform to a specific dictionary key. For video tensors we use the same tensor shape as TorchVision, and for audio we use TorchAudio tensor shapes, making it easy to apply their transforms alongside PyTorchVideo ones.
+
+Below we revise the ``LightningDataModule`` from the last section to include transforms coming from both TorchVision and PyTorchVideo. For brevity we'll just show the ``KineticsDataModule.train_dataloader`` method. The validation dataset transforms would be the same, just without the augmentations (``RandomShortSideScale``, ``RandomCrop``, ``RandomHorizontalFlip``).
+
+```python
+from pytorchvideo.transforms import (
+ ApplyTransformToKey,
+ Normalize,
+ RandomShortSideScale,
+ RemoveKey,
+ ShortSideScale,
+ UniformTemporalSubsample
+)
+
+from torchvision.transforms import (
+ Compose,
+ Lambda,
+ RandomCrop,
+ RandomHorizontalFlip
+)
+
+class KineticsDataModule(pytorch_lightning.LightningDataModule):
+
+    # ...
+
+    def train_dataloader(self):
+        """
+        Create the Kinetics train partition from the list of video labels
+        in {self._DATA_PATH}/train.csv. Add transform that subsamples and
+        normalizes the video before applying the scale, crop and flip augmentations.
+        """
+        train_transform = Compose(
+            [
+                ApplyTransformToKey(
+                    key="video",
+                    transform=Compose(
+                        [
+                            UniformTemporalSubsample(8),
+                            Lambda(lambda x: x / 255.0),
+                            Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
+                            RandomShortSideScale(min_size=256, max_size=320),
+                            RandomCrop(244),
+                            RandomHorizontalFlip(p=0.5),
+                        ]
+                    ),
+                ),
+            ]
+        )
+        train_dataset = pytorchvideo.data.Kinetics(
+            data_path=os.path.join(self._DATA_PATH, "train.csv"),
+            clip_sampler=pytorchvideo.data.make_clip_sampler("random", self._CLIP_DURATION),
+            transform=train_transform
+        )
+        return torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=self._BATCH_SIZE,
+            num_workers=self._NUM_WORKERS,
+        )
+
+    # ...
+
+```
+
+# Model
+
+All PyTorchVideo models and layers can be built with simple, reproducible factory functions. We call this the "flat" model interface because the args don't require hierarchical configs. An example that builds a default ResNet can be found below. See the [docs](https://pytorchvideo.readthedocs.io/en/latest/_modules/pytorchvideo/models/resnet.html#create_bottleneck_block) for more configuration options.
+
+```python
+import pytorchvideo.models.resnet
+import torch.nn as nn
+
+def make_kinetics_resnet():
+    return pytorchvideo.models.resnet.create_resnet(
+        input_channel=3,      # RGB input from Kinetics
+        model_depth=50,       # For the tutorial let's just use a 50 layer network
+        model_num_class=400,  # Kinetics has 400 classes so we need our final head to align
+        norm=nn.BatchNorm3d,
+        activation=nn.ReLU,
+    )
+```
+
+# Putting it all together
+
+To put everything together, let's create a [``pytorch_lightning.LightningModule``](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html). This defines the train and validation step code (i.e. the code inside the training and evaluation loops), and the optimizer.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class VideoClassificationLightningModule(pytorch_lightning.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.model = make_kinetics_resnet()
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        # The model expects a video tensor of shape (B, C, T, H, W), which is the
+        # format provided by the dataset
+        y_hat = self.model(batch["video"])
+
+        # Compute cross entropy loss; loss.backward() will be called behind the scenes
+        # by PyTorch Lightning after being returned from this method.
+        loss = F.cross_entropy(y_hat, batch["label"])
+
+        # Log the train loss to Tensorboard
+        self.log("train_loss", loss.item())
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        y_hat = self.model(batch["video"])
+        loss = F.cross_entropy(y_hat, batch["label"])
+        self.log("val_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        """
+        Set up the Adam optimizer. Note that this function can also return a
+        learning rate scheduler, which is often useful when training video models.
+        """
+        return torch.optim.Adam(self.parameters(), lr=1e-1)
+```
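+
+As noted in the ``configure_optimizers`` docstring, you can also return a learning rate scheduler alongside the optimizer. A minimal sketch (the cosine schedule and the ``T_max`` value are illustrative choices, not part of the original tutorial):
+
+```python
+# A drop-in replacement for the configure_optimizers method above.
+def configure_optimizers(self):
+    optimizer = torch.optim.Adam(self.parameters(), lr=1e-1)
+    # Anneal the learning rate with a cosine schedule over an assumed 100 epochs.
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
+    return [optimizer], [scheduler]
+```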
+
+Our ``VideoClassificationLightningModule`` and ``KineticsDataModule`` are ready to be trained together using the [``pytorch_lightning.Trainer``](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html)! The trainer class has many arguments to define the training environment (e.g. num_gpus, distributed_backend). To keep things simple we'll just use the default local CPU training, but note that this would likely take weeks to train, so you'll probably want more performant settings based on your environment (see the sketch after the example below).
+
+```python
+def train():
+    classification_module = VideoClassificationLightningModule()
+    data_module = KineticsDataModule()
+    trainer = pytorch_lightning.Trainer()
+    trainer.fit(classification_module, data_module)
+```
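+
+If you do have GPUs available, a sketch of more performant settings might look like the following. The exact flag names depend on your PyTorch Lightning version (newer releases use ``accelerator="gpu"``/``devices=...`` instead of ``gpus=...``), so check the Trainer docs for your install; the values below are illustrative, not prescriptive.
+
+```python
+def train():
+    classification_module = VideoClassificationLightningModule()
+    data_module = KineticsDataModule()
+    trainer = pytorch_lightning.Trainer(
+        gpus=8,             # assumption: 8 local GPUs (older Lightning API)
+        accelerator="ddp",  # distributed data parallel across the GPUs
+        max_epochs=200,     # illustrative; tune for your setup
+        precision=16,       # mixed precision to speed up training
+    )
+    trainer.fit(classification_module, data_module)
+```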
+
+# Conclusion
+
+In this tutorial we showed how to train a 3D ResNet on Kinetics using PyTorch Lightning. You can see the final code from the tutorial (including a few extra bells and whistles) in the PyTorchVideo projects directory.
+
+To learn more about PyTorchVideo, check out the rest of the [documentation](https://pytorchvideo.readthedocs.io/en/latest/index.html) and [tutorials](https://pytorchvideo.org/docs/tutorial_overview).
diff --git a/code/pytorchvideo/website/docs/tutorial_overview.md b/code/pytorchvideo/website/docs/tutorial_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..288781a7a43feee352c1756a7d33831e8bae3a6d
--- /dev/null
+++ b/code/pytorchvideo/website/docs/tutorial_overview.md
@@ -0,0 +1,9 @@
+---
+id: tutorial_overview
+title: Tutorials
+sidebar_label: Overview
+---
+
+PyTorchVideo tutorials are designed to help you get acquainted with the library and also give you an idea of how to incorporate different PyTorchVideo components into your own video-research workflow. Through examples, the tutorials also show how PyTorchVideo makes it easy to address some of the common deep learning video use cases.
+
+PyTorchVideo is built on PyTorch. If you are new to PyTorch, the easiest way to get started is with the [PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#sphx-glr-beginner-blitz-tensor-tutorial-py) tutorial.
diff --git a/code/pytorchvideo/website/docs/tutorial_torchhub_detection_inference.md b/code/pytorchvideo/website/docs/tutorial_torchhub_detection_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..3e5f6655bb35ebe933bbf7ea220c30bbbd2e661c
--- /dev/null
+++ b/code/pytorchvideo/website/docs/tutorial_torchhub_detection_inference.md
@@ -0,0 +1,245 @@
+---
+id: tutorial_torchhub_detection_inference
+title: Running a pre-trained PyTorchVideo detection model using Torch Hub
+---
+
+# Introduction
+
+PyTorchVideo provides several pretrained models through [Torch Hub](https://pytorch.org/hub/). In this tutorial we will show how to load a pre-trained video action detection model in PyTorchVideo and run it on a test video. The PyTorchVideo Torch Hub models were trained on the Kinetics 400 dataset and finetuned specifically for detection on the AVA v2.2 dataset. Available models are described in the [model zoo documentation](https://pytorchvideo.readthedocs.io/en/latest/model_zoo.html).
+
+NOTE: Currently, this tutorial only works if run from a local clone, from the directory `pytorchvideo/tutorials/video_detection_example`.
+
+This tutorial assumes that you have installed [Detectron2](https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md) and [opencv-python](https://pypi.org/project/opencv-python/) on your machine.
+
+# Imports
+```python
+from functools import partial
+import numpy as np
+
+import cv2
+import torch
+
+import detectron2
+from detectron2.config import get_cfg
+from detectron2 import model_zoo
+from detectron2.engine import DefaultPredictor
+
+import pytorchvideo
+from pytorchvideo.transforms.functional import (
+ uniform_temporal_subsample,
+ short_side_scale_with_boxes,
+ clip_boxes_to_image,
+)
+from torchvision.transforms._functional_video import normalize
+from pytorchvideo.data.ava import AvaLabeledVideoFramePaths
+from pytorchvideo.models.hub import slow_r50_detection # Another option is slowfast_r50_detection
+
+from visualization import VideoVisualizer
+```
+
+# Load Model using Torch Hub API
+PyTorchVideo provides several pretrained models through Torch Hub. Available models are described in the [model zoo documentation](https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md).
+
+Here we select the slow_r50_detection model, which was trained using a 4x16 setting on the Kinetics 400 dataset and fine-tuned on the AVA v2.2 actions dataset.
+
+NOTE: to run on GPU in Google Colab, in the menu bar select: Runtime -> Change runtime type -> Hardware Accelerator -> GPU
+
+```python
+device = 'cuda' # or 'cpu'
+video_model = slow_r50_detection(True) # Another option is slowfast_r50_detection
+video_model = video_model.eval().to(device)
+```
+
+# Load an off-the-shelf Detectron2 object detector
+
+We use the object detector to detect bounding boxes for people.
+These bounding boxes are later fed into our video action detection model.
+For more details, please refer to Detectron2's object detection tutorials.
+
+To install Detectron2, please follow the instructions mentioned [here](https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md)
+
+```python
+cfg = get_cfg()
+cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
+cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.55 # set threshold for this model
+cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
+predictor = DefaultPredictor(cfg)
+
+# This method takes in an image and generates the bounding boxes for people in the image.
+def get_person_bboxes(inp_img, predictor):
+    predictions = predictor(inp_img.cpu().detach().numpy())['instances'].to('cpu')
+    boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
+    scores = predictions.scores if predictions.has("scores") else None
+    classes = np.array(predictions.pred_classes.tolist() if predictions.has("pred_classes") else None)
+    predicted_boxes = boxes[np.logical_and(classes == 0, scores > 0.75)].tensor.cpu()  # only persons
+    return predicted_boxes
+```
+
+# Define the transformations for the input required by the model
+Before passing the video and bounding boxes into the model we need to apply some input transforms and sample a clip with the correct number of frames.
+
+Below we define a method that pre-processes the clip and bounding boxes. It generates inputs for both Slow (ResNet) and SlowFast models, depending on how the `slow_fast_alpha` parameter is set.
+
+```python
+def ava_inference_transform(
+    clip,
+    boxes,
+    num_frames = 4,  # if using slowfast_r50_detection, change this to 32
+    crop_size = 256,
+    data_mean = [0.45, 0.45, 0.45],
+    data_std = [0.225, 0.225, 0.225],
+    slow_fast_alpha = None,  # if using slowfast_r50_detection, change this to 4
+):
+
+    boxes = np.array(boxes)
+    ori_boxes = boxes.copy()
+
+    # Temporally subsample the clip and scale the image from [0, 255] to [0, 1].
+    clip = uniform_temporal_subsample(clip, num_frames)
+    clip = clip.float()
+    clip = clip / 255.0
+
+    height, width = clip.shape[2], clip.shape[3]
+    # The format of boxes is [x1, y1, x2, y2]. The input boxes are in the
+    # range of [0, width] for x and [0, height] for y.
+    boxes = clip_boxes_to_image(boxes, height, width)
+
+    # Resize the short side to crop_size. Non-local and STRG use 256.
+    clip, boxes = short_side_scale_with_boxes(
+        clip,
+        size=crop_size,
+        boxes=boxes,
+    )
+
+    # Normalize images by mean and std.
+    clip = normalize(
+        clip,
+        np.array(data_mean, dtype=np.float32),
+        np.array(data_std, dtype=np.float32),
+    )
+
+    boxes = clip_boxes_to_image(
+        boxes, clip.shape[2], clip.shape[3]
+    )
+
+    # In case of SlowFast, generate both pathways.
+    if slow_fast_alpha is not None:
+        fast_pathway = clip
+        # Perform temporal sampling from the fast pathway.
+        slow_pathway = torch.index_select(
+            clip,
+            1,
+            torch.linspace(
+                0, clip.shape[1] - 1, clip.shape[1] // slow_fast_alpha
+            ).long(),
+        )
+        clip = [slow_pathway, fast_pathway]
+
+    return clip, torch.from_numpy(boxes), ori_boxes
+```
+
+# Setup
+
+Download the id-to-label mapping for the AVA v2.2 dataset on which the Torch Hub models were fine-tuned.
+This will be used to get the category label names from the predicted class ids.
+
+Create a visualizer to visualize and plot the results (labels + bounding boxes).
+
+```python
+# Download the action text to id mapping
+!wget https://dl.fbaipublicfiles.com/pytorchvideo/data/class_names/ava_action_list.pbtxt
+
+# Create an id to label name mapping
+label_map, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map('ava_action_list.pbtxt')
+# Create a video visualizer that can plot bounding boxes and visualize actions on bboxes.
+video_visualizer = VideoVisualizer(81, label_map, top_k=3, mode="thres",thres=0.5)
+```
+
+# Load an example video
+We use an openly licensed video from Wikimedia.
+```python
+# Download the demo video.
+!wget https://dl.fbaipublicfiles.com/pytorchvideo/projects/theatre.webm
+
+# Load the video
+encoded_vid = pytorchvideo.data.encoded_video.EncodedVideo.from_path('theatre.webm')
+print('Completed loading encoded video.')
+```
+
+# Get model predictions
+
+Generate bounding boxes and action predictions for a 10 second clip in the video.
+
+```python
+# Video predictions are generated at an interval of 1 sec from 90 seconds to 100 seconds in the video.
+time_stamp_range = range(90,100) # time stamps in video for which clip is sampled.
+clip_duration = 1.0 # Duration of clip used for each inference step.
+gif_imgs = []
+
+for time_stamp in time_stamp_range:
+    print("Generating predictions for time stamp: {} sec".format(time_stamp))
+
+    # Generate a clip around the designated time stamp
+    inp_imgs = encoded_vid.get_clip(
+        time_stamp - clip_duration/2.0,  # start second
+        time_stamp + clip_duration/2.0   # end second
+    )
+    inp_imgs = inp_imgs['video']
+
+    # Generate people bbox predictions using Detectron2's off-the-shelf pre-trained predictor.
+    # We use the middle image in each clip to generate the bounding boxes.
+    inp_img = inp_imgs[:, inp_imgs.shape[1]//2, :, :]
+    inp_img = inp_img.permute(1, 2, 0)
+
+    # Predicted boxes are of the form List[(x_1, y_1, x_2, y_2)]
+    predicted_boxes = get_person_bboxes(inp_img, predictor)
+    if len(predicted_boxes) == 0:
+        print("Skipping clip, no people detected at time stamp: ", time_stamp)
+        continue
+
+    # Preprocess clip and bounding boxes for video action recognition.
+    inputs, inp_boxes, _ = ava_inference_transform(inp_imgs, predicted_boxes.numpy())
+    # Prepend data sample id for each bounding box.
+    # For more details refer to the RoIAlign in Detectron2.
+    inp_boxes = torch.cat([torch.zeros(inp_boxes.shape[0], 1), inp_boxes], dim=1)
+
+    # Generate action predictions for the bounding boxes in the clip.
+    # The model here takes in the pre-processed video clip and the detected bounding boxes.
+    preds = video_model(inputs.unsqueeze(0).to(device), inp_boxes.to(device))
+
+    preds = preds.to('cpu')
+    # The model is trained on AVA and AVA labels are 1 indexed, so prepend 0 to convert to 0 index.
+    preds = torch.cat([torch.zeros(preds.shape[0], 1), preds], dim=1)
+
+    # Plot predictions on the video and save for later visualization.
+    inp_imgs = inp_imgs.permute(1, 2, 3, 0)
+    inp_imgs = inp_imgs/255.0
+    out_img_pred = video_visualizer.draw_clip_range(inp_imgs, preds, predicted_boxes)
+    gif_imgs += out_img_pred
+
+print("Finished generating predictions.")
+```
+
+We now save the output video, with the predicted bounding boxes and their action labels drawn on each frame.
+
+```python
+height, width = gif_imgs[0].shape[0], gif_imgs[0].shape[1]
+
+video_save_path = 'output.mp4'
+video = cv2.VideoWriter(video_save_path, cv2.VideoWriter_fourcc(*'DIVX'), 7, (width, height))
+
+for image in gif_imgs:
+    img = (255*image).astype(np.uint8)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    video.write(img)
+video.release()
+
+print('Predictions are saved to the video file: ', video_save_path)
+```
+
+# Conclusion
+
+In this tutorial we showed how to load and run a pretrained PyTorchVideo detection model on a test video. You can run this tutorial as a notebook in the PyTorchVideo tutorials directory.
+
+To learn more about PyTorchVideo, check out the rest of the [documentation](https://pytorchvideo.readthedocs.io/en/latest/index.html) and [tutorials](https://pytorchvideo.org/docs/tutorial_overview).
diff --git a/code/pytorchvideo/website/docs/tutorial_torchhub_inference.md b/code/pytorchvideo/website/docs/tutorial_torchhub_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..3805e2466140ea12b528254b97abad9d708d9536
--- /dev/null
+++ b/code/pytorchvideo/website/docs/tutorial_torchhub_inference.md
@@ -0,0 +1,193 @@
+---
+id: tutorial_torchhub_inference
+title: Running a pre-trained PyTorchVideo classification model using Torch Hub
+---
+
+# Introduction
+
+PyTorchVideo provides several pretrained models through [Torch Hub](https://pytorch.org/hub/). In this tutorial we will show how to load a pre-trained video classification model in PyTorchVideo and run it on a test video. The PyTorchVideo Torch Hub models were trained on the Kinetics 400 [1] dataset. Available models are described in the [model zoo documentation](https://pytorchvideo.readthedocs.io/en/latest/model_zoo.html).
+
+[1] W. Kay, et al. The kinetics human action video dataset. arXiv preprint arXiv:1705.06950, 2017.
+
+
+# Imports
+
+```python
+import torch
+import json
+from torchvision.transforms import Compose, Lambda
+from torchvision.transforms._transforms_video import (
+ CenterCropVideo,
+ NormalizeVideo,
+)
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.transforms import (
+ ApplyTransformToKey,
+ ShortSideScale,
+ UniformTemporalSubsample,
+ UniformCropVideo
+)
+from typing import Dict
+```
+
+# Load Model
+
+Let's select the `slowfast_r50` model which was trained on the Kinetics 400 dataset.
+
+```python
+# Device on which to run the model
+# Set to cuda to load on GPU
+device = "cpu"
+
+# Pick a pretrained model and load the pretrained weights
+model_name = "slowfast_r50"
+model = torch.hub.load("facebookresearch/pytorchvideo:main", model=model_name, pretrained=True)
+
+# Set to eval mode and move to desired device
+model = model.to(device)
+model = model.eval()
+```
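+
+If you'd like to see which other model entrypoints are exposed through Torch Hub, you can list them (a sketch; this fetches the repo, so it needs network access):
+
+```python
+# List the PyTorchVideo models available through Torch Hub.
+print(torch.hub.list("facebookresearch/pytorchvideo:main"))
+```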
+
+# Setup Labels
+
+Next let's download the id-to-label mapping for the Kinetics 400 dataset on which the Torch Hub models were trained. This will be used to get the category label names from the predicted class ids.
+
+```python
+!wget https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json
+```
+
+```python
+with open("kinetics_classnames.json", "r") as f:
+    kinetics_classnames = json.load(f)
+
+# Create an id to label name mapping
+kinetics_id_to_classname = {}
+for k, v in kinetics_classnames.items():
+    kinetics_id_to_classname[v] = str(k).replace('"', "")
+```
+
+# Input Transform
+
+Before passing the video into the model we need to apply some input transforms and sample a clip of the correct duration.
+
+NOTE: The input transforms are specific to the model. If you choose a different model than the example in this tutorial, please refer to the code provided in the Torch Hub documentation and copy over the relevant transforms:
+
+ - [SlowFast](https://pytorch.org/hub/facebookresearch_pytorchvideo_slowfast/)
+ - [X3D](https://pytorch.org/hub/facebookresearch_pytorchvideo_x3d/)
+ - [Slow](https://pytorch.org/hub/facebookresearch_pytorchvideo_resnet/)
+
+```python
+####################
+# SlowFast transform
+####################
+
+side_size = 256
+mean = [0.45, 0.45, 0.45]
+std = [0.225, 0.225, 0.225]
+crop_size = 256
+num_frames = 32
+sampling_rate = 2
+frames_per_second = 30
+alpha = 4
+
+class PackPathway(torch.nn.Module):
+    """
+    Transform for converting video frames into a list of tensors (one per pathway).
+    """
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, frames: torch.Tensor):
+        fast_pathway = frames
+        # Perform temporal sampling from the fast pathway.
+        slow_pathway = torch.index_select(
+            frames,
+            1,
+            torch.linspace(
+                0, frames.shape[1] - 1, frames.shape[1] // alpha
+            ).long(),
+        )
+        frame_list = [slow_pathway, fast_pathway]
+        return frame_list
+
+transform = ApplyTransformToKey(
+    key="video",
+    transform=Compose(
+        [
+            UniformTemporalSubsample(num_frames),
+            Lambda(lambda x: x/255.0),
+            NormalizeVideo(mean, std),
+            ShortSideScale(
+                size=side_size
+            ),
+            CenterCropVideo(crop_size),
+            PackPathway()
+        ]
+    ),
+)
+
+# The duration of the input clip is also specific to the model.
+clip_duration = (num_frames * sampling_rate)/frames_per_second  # (32 * 2) / 30 ≈ 2.13 seconds
+```
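+
+As a quick sanity check of ``PackPathway`` (a sketch; the dummy tensor is random data, not a real clip), 32 frames with ``alpha = 4`` give a slow pathway with 8 frames and a fast pathway with all 32:
+
+```python
+dummy_frames = torch.randn(3, 32, 256, 256)  # (C, T, H, W)
+slow, fast = PackPathway()(dummy_frames)
+print(slow.shape, fast.shape)  # torch.Size([3, 8, 256, 256]) torch.Size([3, 32, 256, 256])
+```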
+
+# Load an example video
+We can now test the model with an example video from the Kinetics validation set such as this [archery video](https://www.youtube.com/watch?v=3and4vWkW4s).
+
+We will load the video and apply the input transform.
+
+
+```python
+# Download the example video file
+!wget https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4
+```
+
+```python
+# Load the example video
+video_path = "archery.mp4"
+
+# Select the duration of the clip to load by specifying the start and end duration
+# The start_sec should correspond to where the action occurs in the video
+start_sec = 0
+end_sec = start_sec + clip_duration
+
+# Initialize an EncodedVideo helper class
+video = EncodedVideo.from_path(video_path)
+
+# Load the desired clip
+video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
+
+# Apply a transform to normalize the video input
+video_data = transform(video_data)
+
+# Move the inputs to the desired device
+inputs = video_data["video"]
+inputs = [i.to(device)[None, ...] for i in inputs]
+```
+
+### Get model predictions
+
+Now we are ready to pass the input into the model and classify the action.
+
+```python
+# Pass the input clip through the model
+preds = model(inputs)
+```
+
+Let's look at the top 5 predictions:
+
+```python
+# Get the predicted classes
+post_act = torch.nn.Softmax(dim=1)
+preds = post_act(preds)
+pred_classes = preds.topk(k=5).indices
+
+# Map the predicted classes to the label names
+pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes[0]]
+print("Predicted labels: %s" % ", ".join(pred_class_names))
+```
+
+# Conclusion
+
+In this tutorial we showed how to load and run a pretrained PyTorchVideo model on a test video. You can run this tutorial as a notebook in the PyTorchVideo tutorials directory.
+
+To learn more about PyTorchVideo, check out the rest of the [documentation](https://pytorchvideo.readthedocs.io/en/latest/index.html) and [tutorials](https://pytorchvideo.org/docs/tutorial_overview).
diff --git a/code/pytorchvideo/website/website/README.md b/code/pytorchvideo/website/website/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2e7802e3a08ae45ad2895aee27c1630c1416a96f
--- /dev/null
+++ b/code/pytorchvideo/website/website/README.md
@@ -0,0 +1,216 @@
+This website was created with [Docusaurus](https://docusaurus.io/).
+
+# Building the PyTorchVideo website
+
+## Install
+
+1. Make sure all the dependencies for the website are installed:
+
+```sh
+# Install dependencies
+$ yarn
+
+or
+
+$ npm install docusaurus-init
+```
+
+2. Run your dev server:
+
+```sh
+# Start the site
+$ yarn start
+
+or
+$ ./node_modules/docusaurus/lib/start-server.js
+```
+
+
+## Edit the landing page
+
+To change the content of the landing page modify: `website/pages/en/index.js`.
+
+
+---------------------------------------------------------
+
+## Docusaurus docs
+
+- [Get Started in 5 Minutes](#get-started-in-5-minutes)
+- [Directory Structure](#directory-structure)
+- [Editing Content](#editing-content)
+- [Adding Content](#adding-content)
+- [Full Documentation](#full-documentation)
+
+
+## Directory Structure
+
+Your project file structure should look something like this
+
+```
+my-docusaurus/
+  docs/
+    doc-1.md
+    doc-2.md
+    doc-3.md
+  website/
+    blog/
+      2016-3-11-oldest-post.md
+      2017-10-24-newest-post.md
+    core/
+    node_modules/
+    pages/
+    static/
+      css/
+      img/
+    package.json
+    sidebars.json
+    siteConfig.js
+
+# Editing Content
+
+## Editing an existing docs page
+
+Edit docs by navigating to `docs/` and editing the corresponding document:
+
+`docs/doc-to-be-edited.md`
+
+```markdown
+---
+id: page-needs-edit
+title: This Doc Needs To Be Edited
+---
+
+Edit me...
+```
+
+For more information about docs, click [here](https://docusaurus.io/docs/en/navigation)
+
+## Editing an existing blog post
+
+Edit blog posts by navigating to `website/blog` and editing the corresponding post:
+
+`website/blog/post-to-be-edited.md`
+
+```markdown
+---
+id: post-needs-edit
+title: This Blog Post Needs To Be Edited
+---
+
+Edit me...
+```
+
+For more information about blog posts, click [here](https://docusaurus.io/docs/en/adding-blog)
+
+# Adding Content
+
+## Adding a new docs page to an existing sidebar
+
+1. Create the doc as a new markdown file in `/docs`, example `docs/newly-created-doc.md`:
+
+```md
+---
+id: newly-created-doc
+title: This Doc Needs To Be Edited
+---
+
+My new content here..
+```
+
+1. Refer to that doc's ID in an existing sidebar in `website/sidebars.json`:
+
+```javascript
+// Add newly-created-doc to the Getting Started category of docs
+{
+ "docs": {
+ "Getting Started": [
+ "quick-start",
+ "newly-created-doc" // new doc here
+ ],
+ ...
+ },
+ ...
+}
+```
+
+For more information about adding new docs, click [here](https://docusaurus.io/docs/en/navigation)
+
+## Adding a new blog post
+
+1. Make sure there is a header link to your blog in `website/siteConfig.js`:
+
+`website/siteConfig.js`
+
+```javascript
+headerLinks: [
+ ...
+ { blog: true, label: 'Blog' },
+ ...
+]
+```
+
+2. Create the blog post with the format `YYYY-MM-DD-My-Blog-Post-Title.md` in `website/blog`:
+
+`website/blog/2018-05-21-New-Blog-Post.md`
+
+```markdown
+---
+author: Frank Li
+authorURL: https://twitter.com/foobarbaz
+authorFBID: 503283835
+title: New Blog Post
+---
+
+Lorem Ipsum...
+```
+
+For more information about blog posts, click [here](https://docusaurus.io/docs/en/adding-blog)
+
+## Adding items to your site's top navigation bar
+
+1. Add links to docs, custom pages or external links by editing the headerLinks field of `website/siteConfig.js`:
+
+`website/siteConfig.js`
+
+```javascript
+{
+ headerLinks: [
+ ...
+ /* you can add docs */
+ { doc: 'my-examples', label: 'Examples' },
+ /* you can add custom pages */
+ { page: 'help', label: 'Help' },
+ /* you can add external links */
+ { href: 'https://github.com/facebook/docusaurus', label: 'GitHub' },
+ ...
+ ],
+ ...
+}
+```
+
+For more information about the navigation bar, click [here](https://docusaurus.io/docs/en/navigation)
+
+## Adding custom pages
+
+1. Docusaurus uses React components to build pages. The components are saved as .js files in `website/pages/en`:
+1. If you want your page to show up in your navigation header, you will need to update `website/siteConfig.js` to add to the `headerLinks` element:
+
+`website/siteConfig.js`
+
+```javascript
+{
+ headerLinks: [
+ ...
+ { page: 'my-new-custom-page', label: 'My New Custom Page' },
+ ...
+ ],
+ ...
+}
+```
+
+For more information about custom pages, click [here](https://docusaurus.io/docs/en/custom-pages).
+
+# Full Documentation
+
+Full documentation can be found on the [website](https://docusaurus.io/).
diff --git a/code/pytorchvideo/website/website/core/Footer.js b/code/pytorchvideo/website/website/core/Footer.js
new file mode 100644
index 0000000000000000000000000000000000000000..ee5fc65e20391317129bbda01f42f00f48b98913
--- /dev/null
+++ b/code/pytorchvideo/website/website/core/Footer.js
@@ -0,0 +1,91 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+
+const PropTypes = require("prop-types");
+const React = require('react');
+
+function SocialFooter(props) {
+ const repoUrl = `https://github.com/${props.config.organizationName}/${props.config.projectName}`;
+ return (
+
+ );
+}
+
+SocialFooter.propTypes = {
+ config: PropTypes.object
+};
+
+class Footer extends React.Component {
+ docUrl(doc, language) {
+ const baseUrl = this.props.config.baseUrl;
+ const docsUrl = this.props.config.docsUrl;
+ const docsPart = `${docsUrl ? `${docsUrl}/` : ''}`;
+ const langPart = `${language ? `${language}/` : ''}`;
+ return `${baseUrl}${docsPart}${langPart}${doc}`;
+ }
+
+ pageUrl(doc, language) {
+ const baseUrl = this.props.config.baseUrl;
+ return baseUrl + (language ? `${language}/` : '') + doc;
+ }
+
+ render() {
+ const repoUrl = `https://github.com/${this.props.config.organizationName}/${this.props.config.projectName}`;
+ return (
+
+ );
+ }
+}
+
+module.exports = Footer;
\ No newline at end of file
diff --git a/code/pytorchvideo/website/website/package.json b/code/pytorchvideo/website/website/package.json
new file mode 100644
index 0000000000000000000000000000000000000000..a92c8b6ee5189970d3fee7cce4f3a1b226e39039
--- /dev/null
+++ b/code/pytorchvideo/website/website/package.json
@@ -0,0 +1,14 @@
+{
+ "scripts": {
+ "examples": "docusaurus-examples",
+ "start": "docusaurus-start",
+ "build": "docusaurus-build",
+ "publish-gh-pages": "docusaurus-publish",
+ "write-translations": "docusaurus-write-translations",
+ "version": "docusaurus-version",
+ "rename-version": "docusaurus-rename-version"
+ },
+ "devDependencies": {
+ "docusaurus": "^1.14.6"
+ }
+}
diff --git a/code/pytorchvideo/website/website/pages/en/index.js b/code/pytorchvideo/website/website/pages/en/index.js
new file mode 100644
index 0000000000000000000000000000000000000000..58fb13510046b432a53aec5342c32fb0161b59d7
--- /dev/null
+++ b/code/pytorchvideo/website/website/pages/en/index.js
@@ -0,0 +1,270 @@
+/**
+ * Copyright (c) 2021-present, Facebook, Inc.
+**/
+
+const React = require('react');
+
+const CompLibrary = require('../../core/CompLibrary.js');
+
+const MarkdownBlock = CompLibrary.MarkdownBlock; /* Used to read markdown */
+const Container = CompLibrary.Container;
+const GridBlock = CompLibrary.GridBlock;
+const bash = (...args) => `~~~bash\n${String.raw(...args)}\n~~~`;
+class HomeSplash extends React.Component {
+ render() {
+ const {siteConfig, language = ''} = this.props;
+ const {baseUrl, docsUrl} = siteConfig;
+ const docsPart = `${docsUrl ? `${docsUrl}/` : ''}`;
+ const langPart = `${language ? `${language}/` : ''}`;
+ const docUrl = doc => `${baseUrl}${docsPart}${langPart}${doc}`;
+
+ const SplashContainer = props => (
+
+ );
+
+ const Logo = props => (
+
+
+
+ );
+
+ const ProjectTitle = props => (
+
+ {props.tagline}
+
+ );
+
+ const PromoSection = props => (
+
+ );
+
+ const Button = props => (
+
+ );
+
+ return (
+
+
+
+
+
+ Get Started
+ Tutorials
+ GitHub
+
+
+
+ );
+ }
+}
+
+class Index extends React.Component {
+ render() {
+ const {config: siteConfig, language = ''} = this.props;
+ const {baseUrl} = siteConfig;
+
+ const Block = props => (
+
+
+
+ );
+
+ const Description = () => (
+
+ {[
+ {
+ content:
+ 'This is another description of how this project is useful',
+ image: `${baseUrl}img/placeholder.png`,
+ imageAlign: 'right',
+ title: 'Description',
+ },
+ ]}
+
+ );
+
+ const pre = '```';
+
+ const codeExample = `${pre}python
+# Import all the required components
+...
+
+# Load pre-trained model
+model = torch.hub.load('facebookresearch/pytorchvideo:main', 'slow_r50', pretrained=True)
+
+# Load video
+video = EncodedVideo.from_path('some_video.avi')
+
+# Compose video data transforms
+transform = ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ [
+ UniformTemporalSubsample(num_frames),
+ Lambda(lambda x: x/255.0),
+ NormalizeVideo(mean, std),
+ ShortSideScale(
+ size=side_size
+ ),
+ CenterCropVideo(crop_size=(crop_size, crop_size))
+ ]
+ ),
+)
+
+# Get clip
+clip_start_sec = 0.0 # secs
+clip_duration = 2.0 # secs
+video_data = video.get_clip(start_sec=clip_start_sec, end_sec=clip_start_sec + clip_duration)
+video_data = transform(video_data)
+
+# Run the model and generate the top 5 predictions
+preds = model(video_data["video"][None, ...])
+preds = torch.nn.functional.softmax(preds, dim=1)
+pred_class_ids = preds.topk(k=5).indices
+ `;
+ const install = `${pre}bash
+pip install pytorchvideo
+ `;
+
+ const QuickStart = () => (
+
+
Get Started
+
+
+
+ Install pytorchvideo (Confirm requirements following the instructions here )
+ {install}
+
+
+ Try Video classification with Model Zoo
+ (For detailed instructions, refer to the PyTorchVideo Model Zoo Inference Tutorial
+ {codeExample}
+
+
+
+
+ );
+
+ const UseCases = () => (
+
+
Some use cases
+
+
+
+
+
Detection (Add GIF)
+
+
+
+
+
Tracking (Add GIF)
+
+
+
+
+
Classification (Add GIF)
+
+
+
+ );
+
+ const Features = () => (
+
+
+ {[
+ {
+ content:
+ 'Built using PyTorch. Makes it easy to use all the PyTorch-ecosystem components.',
+ image: `${baseUrl}img/pytorch.svg`,
+ imageAlign: 'top',
+ title: 'Based on PyTorch',
+ },
+ {
+ content:
+ 'Variety of state of the art pretrained video models and their associated benchmarks that are ready to use.',
+ image: `${baseUrl}img/modelzoo.svg`,
+ imageAlign: 'top',
+ title: 'Reproducible Model Zoo',
+ },
+ // {
+ // content:
+ // 'Variety of benchmark tasks available to evaluate the models.',
+ // image: `${baseUrl}img/reproducible.svg`,
+ // imageAlign: 'top',
+ // title: 'Reproducible Benchmarks',
+ // },
+ {
+ content:
+ 'Video-focused fast and efficient components that are easy to use. Supports accelerated inference on hardware.',
+ image: `${baseUrl}img/efficient.svg`,
+ imageAlign: 'top',
+ title: 'Efficient Video Components',
+ },
+ ]}
+
+
+ );
+
+ const Showcase = () => {
+ if ((siteConfig.users || []).length === 0) {
+ return null;
+ }
+
+ const showcase = siteConfig.users
+ .filter(user => user.pinned)
+ .map(user => (
+
+
+
+ ));
+
+ const pageUrl = page => baseUrl + (language ? `${language}/` : '') + page;
+
+ return (
+
+
Who is Using This?
+
This project is used by all these people
+
{showcase}
+
+
+ );
+ };
+
+ return (
+
+ );
+ }
+}
+
+module.exports = Index;
diff --git a/code/pytorchvideo/website/website/sidebars.json b/code/pytorchvideo/website/website/sidebars.json
new file mode 100644
index 0000000000000000000000000000000000000000..756d284dd3814c4c111cfb897becb541fc6180fa
--- /dev/null
+++ b/code/pytorchvideo/website/website/sidebars.json
@@ -0,0 +1,8 @@
+{
+ "docs-other": {
+ "Tutorials": ["tutorial_overview"],
+ "Classification": ["tutorial_classification", "tutorial_torchhub_inference"],
+ "Detection": ["tutorial_torchhub_detection_inference"],
+ "Accelerator": ["tutorial_accelerator_build_your_model", "tutorial_accelerator_use_accelerator_model_zoo", "tutorial_accelerator_use_model_transmuter"]
+ }
+}
diff --git a/code/pytorchvideo/website/website/siteConfig.js b/code/pytorchvideo/website/website/siteConfig.js
new file mode 100644
index 0000000000000000000000000000000000000000..ea8ca2331d677e83fe513d75e378315e3923e3fe
--- /dev/null
+++ b/code/pytorchvideo/website/website/siteConfig.js
@@ -0,0 +1,66 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// See https://docusaurus.io/docs/site-config for all the possible
+// site configuration options.
+
+
+const siteConfig = {
+ title: 'PyTorchVideo', // Title for your website.
+ tagline: 'A deep learning library for video understanding research',
+ url: 'https://pytorchvideo.org', // Your website URL
+ baseUrl: '/',
+
+ // Used for publishing and more
+ projectName: 'pytorchvideo',
+ organizationName: 'facebookresearch',
+
+ // For no header links in the top nav bar -> headerLinks: [],
+ headerLinks: [
+ {doc: 'tutorial_overview', label: 'Tutorials'},
+ {href: "https://pytorchvideo.readthedocs.io/en/latest/index.html", label: 'Docs'}, // TODO: Change this after the repo becomes public.
+ {href: "https://github.com/facebookresearch/pytorchvideo/", label: 'GitHub'}, //TODO: Change this after repo becomes public
+ ],
+
+
+ /* path to images for header/footer */
+ headerIcon: 'img/logo.svg',
+ footerIcon: 'img/logo.svg',
+ favicon: 'img/favicon.png',
+
+ /* Colors for website */
+ colors: {
+ primaryColor: '#812ce5',
+ secondaryColor: '#cc33cc',
+ },
+
+ // This copyright info is used in /core/Footer.js and blog RSS/Atom feeds.
+ copyright: `Copyright © ${new Date().getFullYear()} Facebook, Inc`,
+
+ highlight: {
+ // Highlight.js theme to use for syntax highlighting in code blocks.
+ theme: 'atom-one-dark',
+ },
+
+ // Add custom scripts here that would be placed in