CHEN11102 commited on
Commit
708d62c
·
verified ·
1 Parent(s): 3dcc2fb

Upload 47 files

Browse files
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ moment.gif filter=lfs diff=lfs merge=lfs -text
36
+ photos/one.png filter=lfs diff=lfs merge=lfs -text
37
+ photos/two.png filter=lfs diff=lfs merge=lfs -text
CONTRIBUTING.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to Contribute
2
+
3
+ We'd love to accept your patches and contributions to this project. There are
4
+ just a few small guidelines you need to follow.
5
+
6
+ ## Contributor License Agreement
7
+
8
+ Contributions to this project must be accompanied by a Contributor License
9
+ Agreement (CLA). You (or your employer) retain the copyright to your
10
+ contribution; this simply gives us permission to use and redistribute your
11
+ contributions as part of the project. Head over to
12
+ <https://cla.developers.google.com/> to see your current agreements on file or
13
+ to sign a new one.
14
+
15
+ You generally only need to submit a CLA once, so if you've already submitted one
16
+ (even if it was for a different project), you probably don't need to do it
17
+ again.
18
+
19
+ ## Code Reviews
20
+
21
+ All submissions, including submissions by project members, require review. We
22
+ use GitHub pull requests for this purpose. Consult
23
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24
+ information on using pull requests.
25
+
26
+ ## Community Guidelines
27
+
28
+ This project follows
29
+ [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -1,20 +1,276 @@
1
- ---
2
- title: '1'
3
- emoji: 🌍
4
- colorFrom: yellow
5
- colorTo: indigo
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
 
11
- This is a templated Space for [Shiny for Python](https://shiny.rstudio.com/py/).
12
 
 
13
 
14
- To get started with a new app do the following:
 
 
 
15
 
16
- 1) Install Shiny with `pip install shiny`
17
- 2) Create a new app with `shiny create .`
18
- 3) Then run the app with `shiny run --reload`
19
 
20
- To learn more about this framework please see the [Documentation](https://shiny.rstudio.com/py/docs/overview.html).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILM: Frame Interpolation for Large Motion
 
 
 
 
 
 
 
 
2
 
3
+ ### [Website](https://film-net.github.io/) | [Paper](https://arxiv.org/pdf/2202.04901.pdf) | [Google AI Blog](https://ai.googleblog.com/2022/10/large-motion-frame-interpolation.html) | [Tensorflow Hub Colab](https://www.tensorflow.org/hub/tutorials/tf_hub_film_example) | [YouTube](https://www.youtube.com/watch?v=OAD-BieIjH4) <br>
4
 
5
+ The official Tensorflow 2 implementation of our high quality frame interpolation neural network. We present a unified single-network approach that doesn't use additional pre-trained networks, like optical flow or depth, and yet achieve state-of-the-art results. We use a multi-scale feature extractor that shares the same convolution weights across the scales. Our model is trainable from frame triplets alone. <br>
6
 
7
+ [FILM: Frame Interpolation for Large Motion](https://arxiv.org/abs/2202.04901) <br />
8
+ [Fitsum Reda](https://fitsumreda.github.io/)<sup>1</sup>, [Janne Kontkanen](https://scholar.google.com/citations?user=MnXc4JQAAAAJ&hl=en)<sup>1</sup>, [Eric Tabellion](http://www.tabellion.org/et/)<sup>1</sup>, [Deqing Sun](https://deqings.github.io/)<sup>1</sup>, [Caroline Pantofaru](https://scholar.google.com/citations?user=vKAKE1gAAAAJ&hl=en)<sup>1</sup>, [Brian Curless](https://homes.cs.washington.edu/~curless/)<sup>1,2</sup><br />
9
+ <sup>1</sup>Google Research, <sup>2</sup>University of Washington<br />
10
+ In ECCV 2022.
11
 
12
+ ![A sample 2 seconds moment.](https://github.com/googlestaging/frame-interpolation/blob/main/moment.gif)
13
+ FILM transforms near-duplicate photos into a slow motion footage that look like it is shot with a video camera.
 
14
 
15
+ ## Web Demo
16
+
17
+ Integrated into [Hugging Face Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/johngoad/frame-interpolation)
18
+
19
+ Try the interpolation model with the replicate web demo at
20
+ [![Replicate](https://replicate.com/google-research/frame-interpolation/badge)](https://replicate.com/google-research/frame-interpolation)
21
+
22
+ Try FILM to interpolate between two or more images with the PyTTI-Tools at [![PyTTI-Tools:FILM](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/pytti-tools/frame-interpolation/blob/main/PyTTI_Tools_FiLM-colab.ipynb#scrollTo=-7TD7YZJbsy_)
23
+
24
+ An alternative Colab for running FILM on arbitrarily more input images, not just on two images, [![FILM-Gdrive](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1NuaPPSvUhYafymUf2mEkvhnEtpD5oihs)
25
+
26
+ ## Change Log
27
+ * **Nov 28, 2022**: Upgrade `eval.interpolator_cli` for **high resolution frame interpolation**. `--block_height` and `--block_width` determine the total number of patches (`block_height*block_width`) to subdivide the input images. By default, both arguments are set to 1, and so no subdivision will be done.
28
+ * **Mar 12, 2022**: Support for Windows, see [WINDOWS_INSTALLATION.md](https://github.com/google-research/frame-interpolation/blob/main/WINDOWS_INSTALLATION.md).
29
+ * **Mar 09, 2022**: Support for **high resolution frame interpolation**. Set `--block_height` and `--block_width` in `eval.interpolator_test` to extract patches from the inputs, and reconstruct the interpolated frame from the iteratively interpolated patches.
30
+
31
+ ## Installation
32
+
33
+ * Get Frame Interpolation source codes
34
+
35
+ ```
36
+ git clone https://github.com/google-research/frame-interpolation
37
+ cd frame-interpolation
38
+ ```
39
+
40
+ * Optionally, pull the recommended Docker base image
41
+
42
+ ```
43
+ docker pull gcr.io/deeplearning-platform-release/tf2-gpu.2-6:latest
44
+ ```
45
+
46
+ * If you do not use Docker, set up your NVIDIA GPU environment with:
47
+ * [Anaconda Python 3.9](https://www.anaconda.com/products/individual)
48
+ * [CUDA Toolkit 11.2.1](https://developer.nvidia.com/cuda-11.2.1-download-archive)
49
+ * [cuDNN 8.1.0](https://developer.nvidia.com/rdp/cudnn-download)
50
+
51
+ * Install frame interpolation dependencies
52
+
53
+ ```
54
+ pip3 install -r requirements.txt
55
+ sudo apt-get install -y ffmpeg
56
+ ```
57
+
58
+ ### See [WINDOWS_INSTALLATION](https://github.com/google-research/frame-interpolation/blob/main/WINDOWS_INSTALLATION.md) for Windows Support
59
+
60
+ ## Pre-trained Models
61
+
62
+ * Create a directory where you can keep large files. Ideally, not in this
63
+ directory.
64
+
65
+ ```
66
+ mkdir -p <pretrained_models>
67
+ ```
68
+
69
+ * Download pre-trained TF2 Saved Models from
70
+ [google drive](https://drive.google.com/drive/folders/1q8110-qp225asX3DQvZnfLfJPkCHmDpy?usp=sharing)
71
+ and put into `<pretrained_models>`.
72
+
73
+ The downloaded folder should have the following structure:
74
+
75
+ ```
76
+ <pretrained_models>/
77
+ ├── film_net/
78
+ │ ├── L1/
79
+ │ ├── Style/
80
+ │ ├── VGG/
81
+ ├── vgg/
82
+ │ ├── imagenet-vgg-verydeep-19.mat
83
+ ```
84
+
85
+ ## Running the Codes
86
+
87
+ The following instructions run the interpolator on the photos provided in
88
+ 'frame-interpolation/photos'.
89
+
90
+ ### One mid-frame interpolation
91
+
92
+ To generate an intermediate photo from the input near-duplicate photos, simply run:
93
+
94
+ ```
95
+ python3 -m eval.interpolator_test \
96
+ --frame1 photos/one.png \
97
+ --frame2 photos/two.png \
98
+ --model_path <pretrained_models>/film_net/Style/saved_model \
99
+ --output_frame photos/output_middle.png
100
+ ```
101
+
102
+ This will produce the sub-frame at `t=0.5` and save as 'photos/output_middle.png'.
103
+
104
+ ### Many in-between frames interpolation
105
+
106
+ It takes in a set of directories identified by a glob (--pattern). Each directory
107
+ is expected to contain at least two input frames, with each contiguous frame
108
+ pair treated as an input to generate in-between frames. Frames should be named such that when sorted (naturally) with `natsort`, their desired order is unchanged.
109
+
110
+ ```
111
+ python3 -m eval.interpolator_cli \
112
+ --pattern "photos" \
113
+ --model_path <pretrained_models>/film_net/Style/saved_model \
114
+ --times_to_interpolate 6 \
115
+ --output_video
116
+ ```
117
+
118
+ You will find the interpolated frames (including the input frames) in
119
+ 'photos/interpolated_frames/', and the interpolated video at
120
+ 'photos/interpolated.mp4'.
121
+
122
+ The number of frames is determined by `--times_to_interpolate`, which controls
123
+ the number of times the frame interpolator is invoked. When the number of frames
124
+ in a directory is `num_frames`, the number of output frames will be
125
+ `(2^times_to_interpolate+1)*(num_frames-1)`.
126
+
127
+ ## Datasets
128
+
129
+ We use [Vimeo-90K](http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip) as
130
+ our main training dataset. For quantitative evaluations, we rely on commonly
131
+ used benchmark datasets, specifically:
132
+
133
+ * [Vimeo-90K](http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip)
134
+ * [Middlebury-Other](https://vision.middlebury.edu/flow/data)
135
+ * [UCF101](https://people.cs.umass.edu/~hzjiang/projects/superslomo/UCF101_results.zip)
136
+ * [Xiph](https://github.com/sniklaus/softmax-splatting/blob/master/benchmark.py)
137
+
138
+ ### Creating a TFRecord
139
+
140
+ The training and benchmark evaluation scripts expect the frame triplets in the
141
+ [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) storage format. <br />
142
+
143
+ We have included scripts that encode the relevant frame triplets into a
144
+ [tf.train.Example](https://www.tensorflow.org/api_docs/python/tf/train/Example)
145
+ data format, and export to a TFRecord file. <br />
146
+
147
+ You can use the commands `python3 -m
148
+ datasets.create_<dataset_name>_tfrecord --help` for more information.
149
+
150
+ For example, run the command below to create a TFRecord for the Middlebury-other
151
+ dataset. Download the [images](https://vision.middlebury.edu/flow/data) and point `--input_dir` to the unzipped folder path.
152
+
153
+ ```
154
+ python3 -m datasets.create_middlebury_tfrecord \
155
+ --input_dir=<root folder of middlebury-other> \
156
+ --output_tfrecord_filepath=<output tfrecord filepath> \
157
+ --num_shards=3
158
+ ```
159
+
160
+ The above command will output a TFRecord file with 3 shards as `<output tfrecord filepath>@3`.
161
+
162
+ ## Training
163
+
164
+ Below are our training gin configuration files for the different loss function:
165
+
166
+ ```
167
+ training/
168
+ ├── config/
169
+ │ ├── film_net-L1.gin
170
+ │ ├── film_net-VGG.gin
171
+ │ ├── film_net-Style.gin
172
+ ```
173
+
174
+ To launch a training, simply pass the configuration filepath to the desired
175
+ experiment. <br />
176
+ By default, it uses all visible GPUs for training. To debug or train
177
+ on a CPU, append `--mode cpu`.
178
+
179
+ ```
180
+ python3 -m training.train \
181
+ --gin_config training/config/<config filename>.gin \
182
+ --base_folder <base folder for all training runs> \
183
+ --label <descriptive label for the run>
184
+ ```
185
+
186
+ * When training finishes, the folder structure will look like this:
187
+
188
+ ```
189
+ <base_folder>/
190
+ ├── <label>/
191
+ │ ├── config.gin
192
+ │ ├── eval/
193
+ │ ├── train/
194
+ │ ├── saved_model/
195
+ ```
196
+
197
+ ### Build a SavedModel
198
+
199
+ Optionally, to build a
200
+ [SavedModel](https://www.tensorflow.org/guide/saved_model) format from a trained
201
+ checkpoints folder, you can use this command:
202
+
203
+ ```
204
+ python3 -m training.build_saved_model_cli \
205
+ --base_folder <base folder of training sessions> \
206
+ --label <the name of the run>
207
+ ```
208
+
209
+ * By default, a SavedModel is created when the training loop ends, and it will be saved at
210
+ `<base_folder>/<label>/saved_model`.
211
+
212
+ ## Evaluation on Benchmarks
213
+
214
+ Below, we provided the evaluation gin configuration files for the benchmarks we
215
+ have considered:
216
+
217
+ ```
218
+ eval/
219
+ ├── config/
220
+ │ ├── middlebury.gin
221
+ │ ├── ucf101.gin
222
+ │ ├── vimeo_90K.gin
223
+ │ ├── xiph_2K.gin
224
+ │ ├── xiph_4K.gin
225
+ ```
226
+
227
+ To run an evaluation, simply pass the configuration file of the desired evaluation dataset. <br />
228
+ If a GPU is visible, it runs on it.
229
+
230
+ ```
231
+ python3 -m eval.eval_cli \
232
+ --gin_config eval/config/<eval_dataset>.gin \
233
+ --model_path <pretrained_models>/film_net/L1/saved_model
234
+ ```
235
+
236
+ The above command will produce the PSNR and SSIM scores presented in the paper.
237
+
238
+ ## Citation
239
+
240
+ If you find this implementation useful in your works, please acknowledge it
241
+ appropriately by citing:
242
+
243
+ ```
244
+ @inproceedings{reda2022film,
245
+ title = {FILM: Frame Interpolation for Large Motion},
246
+ author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
247
+ booktitle = {European Conference on Computer Vision (ECCV)},
248
+ year = {2022}
249
+ }
250
+ ```
251
+
252
+ ```
253
+ @misc{film-tf,
254
+ title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"},
255
+ author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
256
+ year = {2022},
257
+ publisher = {GitHub},
258
+ journal = {GitHub repository},
259
+ howpublished = {\url{https://github.com/google-research/frame-interpolation}}
260
+ }
261
+ ```
262
+
263
+ ## Acknowledgments
264
+
265
+ We would like to thank Richard Tucker, Jason Lai and David Minnen. We would also
266
+ like to thank Jamie Aspinall for the imagery included in this repository.
267
+
268
+ ## Coding style
269
+
270
+ * 2 spaces for indentation
271
+ * 80 character line length
272
+ * PEP8 formatting
273
+
274
+ ## Disclaimer
275
+
276
+ This is not an officially supported Google product.
WINDOWS_INSTALLATION.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [FILM](https://github.com/google-research/frame-interpolation): Windows Installation Instructions
2
+
3
+ ## Anaconda Python 3.9 (Optional)
4
+
5
+ #### Install Anaconda3 Python3.9
6
+ * Go to [https://www.anaconda.com/products/individual](https://www.anaconda.com/products/individual) and click the "Download" button.
7
+ * Download the Windows [64-Bit](https://repo.anaconda.com/archive/Anaconda3-2021.11-Windows-x86_64.exe) or [32-bit](https://repo.anaconda.com/archive/Anaconda3-2021.11-Windows-x86.exe) Graphical Installer, depending on your system needs.
8
+ * Run the downloaded (`.exe`) file to begin the installation.
9
+ * (Optional) Check the "Add Anaconda3 to my PATH environment variable". You may get a 'red text' warning of its implications, you may ignore it for this setup.
10
+
11
+ #### Create a new Anaconda virtual environment
12
+ * Open a new Terminal
13
+ * Type the following command:
14
+ ```
15
+ conda create -n frame_interpolation pip python=3.9
16
+ ```
17
+ * The above command will create a new virtual environment with the name `frame_interpolation`
18
+
19
+ #### Activate the Anaconda virtual environment
20
+ * Activate the newly created virtual environment by typing in your terminal (Command Prompt or PowerShell)
21
+ ```
22
+ conda activate frame_interpolation
23
+ ```
24
+ * Once activated, your terminal should look like:
25
+ ```
26
+ (frame_interpolation) <present working directory> >
27
+ ```
28
+
29
+ ## NVIDIA GPU Support
30
+ #### Install CUDA Toolkit
31
+ * Go to [https://developer.nvidia.com/cuda-11.2.1-download-archive](https://developer.nvidia.com/cuda-11.2.1-download-archive) and select your `Windows`.
32
+ * Download and install `CUDA Tookit 11.2.1`.
33
+ * Additional CUDA installation information available [here](https://docs.nvidia.com/cuda/archive/11.2.2/cuda-installation-guide-microsoft-windows/index.html).
34
+
35
+ #### Install cuDNN
36
+ * Go to [https://developer.nvidia.com/rdp/cudnn-download](https://developer.nvidia.com/rdp/cudnn-download).
37
+ * Create a user profile (if needed) and login.
38
+ * Select `cuDNN v8.1.0 (January 26th, 2021), for CUDA 11.0,11.1 and 11.2`.
39
+ * Download [cuDNN Library for Widnows (x86)](https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.1.0.77/11.2_20210127/cudnn-11.2-windows-x64-v8.1.0.77.zip).
40
+ * Extract the contents of the zipped folder (it contains a folder named `cuda`) into `<INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\`. `<INSTALL_PATH>` points to the installation directory specified during CUDA Toolkit installation. By default, `<INSTAL_PATH> = C:\Program Files`.
41
+
42
+ #### Environment Setup
43
+ * Add the following paths to your 'Advanced System Settings' > 'Environment Variables ...' > Edit 'Path', and add:
44
+ * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\bin
45
+ * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\libnvvp
46
+ * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\include
47
+ * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\CUPTI\lib64
48
+ * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\cuda\bin
49
+
50
+ #### Verify Installation
51
+ * Open a **new** terminal and type `conda activate frame_interpolation`.
52
+ * Install (temporarily) tensorflow and run a simple operation, by typing:
53
+ ```
54
+ pip install --ignore-installed --upgrade tensorflow==2.6.0
55
+ python -c "import tensorflow as tf;print(tf.reduce_sum(tf.random.normal([1000, 1000])))"
56
+ ```
57
+ * You should see success messages: 'Created device /job:localhost/replica:0/task:0/device:GPU:0'.
58
+
59
+ ## FILM Installation
60
+ * Get Frame Interpolation source codes
61
+ ```
62
+ git clone https://github.com/google-research/frame-interpolation
63
+ cd frame-interpolation
64
+ ```
65
+ * Install dependencies
66
+ ```
67
+ pip install -r requirements.txt
68
+ conda install -c conda-forge ffmpeg
69
+ ```
70
+ * Download pre-traned models, detailed [here](https://github.com/google-research/frame-interpolation#pre-trained-models).
71
+
72
+ ## Running the Codes
73
+ * One mid-frame interpolation. Note: `python3` may not be recognized in Windows, so simply drop `3` as below.
74
+ ```
75
+ python -m eval.interpolator_test --frame1 photos\one.png --frame2 photos\two.png --model_path <pretrained_models>\film_net\Style\saved_model --output_frame photos\output_middle.png
76
+ ```
77
+
78
+ * Large resolution mid-frame interpolation: Set `block_height` and `--block_width` to subdivide along the height and width to create patches, where the interpolator will be run iteratively, and the resulting interpolated mid-patches will be reconstructed into a final mid-frame. In the example below, will create and run on 4 patches (2*2).
79
+ ```
80
+ python -m eval.interpolator_test --frame1 photos\one.png --frame2 photos\two.png --block_height 2 --block_wdith 2 --model_path <pretrained_models>\film_net\Style\saved_model --output_frame photos\output_middle.png
81
+ ```
82
+ * Many in-between frames interpolation
83
+ ```
84
+ python -m eval.interpolator_cli --pattern "photos" --model_path <pretrained_models>\film_net\Style\saved_model --times_to_interpolate 6 --output_video
85
+ ```
86
+
87
+ ## Acknowledgments
88
+
89
+ This windows installation guide is heavily based on [tensorflow-object-detection-api-tutorial](https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/install.html) .
cog.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build:
2
+ gpu: true
3
+ cuda: "11.2"
4
+ python_version: "3.8"
5
+ system_packages:
6
+ - "libgl1-mesa-glx"
7
+ - "libglib2.0-0"
8
+ python_packages:
9
+ - "ipython==7.30.1"
10
+ - "tensorflow-gpu==2.8.0"
11
+ - "tensorflow-datasets==4.4.0"
12
+ - "tensorflow-addons==0.15.0"
13
+ - "absl-py==0.12.0"
14
+ - "gin-config==0.5.0"
15
+ - "parameterized==0.8.1"
16
+ - "mediapy==1.0.3"
17
+ - "scikit-image==0.19.1"
18
+ - "apache-beam==2.34.0"
19
+ run:
20
+ - apt-get update && apt-get install -y software-properties-common
21
+ - apt-get install ffmpeg -y
22
+
23
+ predict: "predict.py:Predictor"
datasets/create_middlebury_tfrecord.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Beam pipeline that generates Middlebury `Other Datasets` triplet TFRecords.
16
+
17
+ Middlebury interpolation evaluation dataset consists of two subsets.
18
+
19
+ (1) Two frames only, without the intermediate golden frame. A total of 12 such
20
+ pairs, with folder names (Army, Backyard, Basketball, Dumptruck,
21
+ Evergreen, Grove, Mequon, Schefflera, Teddy, Urban, Wooden, Yosemite)
22
+
23
+ (2) Two frames together with the intermediate golden frame. A total of 12 such
24
+ triplets, with folder names (Beanbags, Dimetrodon, DogDance, Grove2,
25
+ Grove3, Hydrangea, MiniCooper, RubberWhale, Urban2, Urban3, Venus, Walking)
26
+
27
+ This script runs on (2), i.e. the dataset with the golden frames. For more
28
+ information, visit https://vision.middlebury.edu/flow/data.
29
+
30
+ Input to the script is the root-folder that contains the unzipped folders
31
+ of input pairs (other-data) and golen frames (other-gt-interp).
32
+
33
+ Output TFRecord is a tf.train.Example proto of each image triplet.
34
+ The feature_map takes the form:
35
+ feature_map {
36
+ 'frame_0/encoded':
37
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
38
+ 'frame_0/format':
39
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
40
+ 'frame_0/height':
41
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
42
+ 'frame_0/width':
43
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
44
+ 'frame_1/encoded':
45
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
46
+ 'frame_1/format':
47
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
48
+ 'frame_1/height':
49
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
50
+ 'frame_1/width':
51
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
52
+ 'frame_2/encoded':
53
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
54
+ 'frame_2/format':
55
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
56
+ 'frame_2/height':
57
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
58
+ 'frame_2/width':
59
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
60
+ 'path':
61
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
62
+ }
63
+
64
+ Usage example:
65
+ python3 -m frame_interpolation.datasets.create_middlebury_tfrecord \
66
+ --input_dir=<root folder of middlebury-other> \
67
+ --output_tfrecord_filepath=<output tfrecord filepath>
68
+ """
69
+
70
+ import os
71
+
72
+ from . import util
73
+ from absl import app
74
+ from absl import flags
75
+ from absl import logging
76
+ import apache_beam as beam
77
+ import tensorflow as tf
78
+
79
+ _INPUT_DIR = flags.DEFINE_string(
80
+ 'input_dir',
81
+ default='/root/path/to/middlebury-other',
82
+ help='Path to the root directory of the `Other Datasets` of the Middlebury '
83
+ 'interpolation evaluation data. '
84
+ 'We expect the data to have been downloaded and unzipped. \n'
85
+ 'Folder structures:\n'
86
+ '| raw_middlebury_other_dataset/\n'
87
+ '| other-data/\n'
88
+ '| | Beanbags\n'
89
+ '| | | frame10.png\n'
90
+ '| | | frame11.png\n'
91
+ '| | Dimetrodon\n'
92
+ '| | | frame10.png\n'
93
+ '| | | frame11.png\n'
94
+ '| | ...\n'
95
+ '| other-gt-interp/\n'
96
+ '| | Beanbags\n'
97
+ '| | | frame10i11.png\n'
98
+ '| | Dimetrodon\n'
99
+ '| | | frame10i11.png\n'
100
+ '| | ...\n')
101
+
102
+ _INPUT_PAIRS_FOLDERNAME = flags.DEFINE_string(
103
+ 'input_pairs_foldername',
104
+ default='other-data',
105
+ help='Foldername containing the folders of the input frame pairs.')
106
+
107
+ _GOLDEN_FOLDERNAME = flags.DEFINE_string(
108
+ 'golden_foldername',
109
+ default='other-gt-interp',
110
+ help='Foldername containing the folders of the golden frame.')
111
+
112
+ _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
113
+ 'output_tfrecord_filepath',
114
+ default=None,
115
+ required=True,
116
+ help='Filepath to the output TFRecord file.')
117
+
118
+ _NUM_SHARDS = flags.DEFINE_integer('num_shards',
119
+ default=3,
120
+ help='Number of shards used for the output.')
121
+
122
+ # Image key -> basename for frame interpolator: start / middle / end frames.
123
+ _INTERPOLATOR_IMAGES_MAP = {
124
+ 'frame_0': 'frame10.png',
125
+ 'frame_1': 'frame10i11.png',
126
+ 'frame_2': 'frame11.png',
127
+ }
128
+
129
+
130
+ def main(unused_argv):
131
+ """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
132
+ # Collect the list of folder paths containing the input and golen frames.
133
+ pairs_list = tf.io.gfile.listdir(
134
+ os.path.join(_INPUT_DIR.value, _INPUT_PAIRS_FOLDERNAME.value))
135
+
136
+ folder_names = [
137
+ _INPUT_PAIRS_FOLDERNAME.value, _GOLDEN_FOLDERNAME.value,
138
+ _INPUT_PAIRS_FOLDERNAME.value
139
+ ]
140
+ triplet_dicts = []
141
+ for pair in pairs_list:
142
+ triplet_dict = {
143
+ image_key: os.path.join(_INPUT_DIR.value, folder, pair, image_basename)
144
+ for folder, (image_key, image_basename
145
+ ) in zip(folder_names, _INTERPOLATOR_IMAGES_MAP.items())
146
+ }
147
+ triplet_dicts.append(triplet_dict)
148
+
149
+ p = beam.Pipeline('DirectRunner')
150
+ (p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
151
+ | 'GenerateSingleExample' >> beam.ParDo(
152
+ util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
153
+ | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
154
+ file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
155
+ num_shards=_NUM_SHARDS.value,
156
+ coder=beam.coders.BytesCoder()))
157
+ result = p.run()
158
+ result.wait_until_finish()
159
+
160
+ logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
161
+ _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
162
+
163
+ if __name__ == '__main__':
164
+ app.run(main)
datasets/create_ucf101_tfrecord.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Beam pipeline that generates UCF101 `interp_test` triplet TFRecords.
16
+
17
+ UCF101 interpolation evaluation dataset consists of 379 triplets, with the
18
+ middle frame being the golden intermediate. The dataset is available here:
19
+ https://people.cs.umass.edu/~hzjiang/projects/superslomo/UCF101_results.zip.
20
+
21
+ Input to the script is the root folder that contains the unzipped
22
+ `UCF101_results` folder.
23
+
24
+ Output TFRecord is a tf.train.Example proto of each image triplet.
25
+ The feature_map takes the form:
26
+ feature_map {
27
+ 'frame_0/encoded':
28
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
29
+ 'frame_0/format':
30
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
31
+ 'frame_0/height':
32
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
33
+ 'frame_0/width':
34
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
35
+ 'frame_1/encoded':
36
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
37
+ 'frame_1/format':
38
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
39
+ 'frame_1/height':
40
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
41
+ 'frame_1/width':
42
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
43
+ 'frame_2/encoded':
44
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
45
+ 'frame_2/format':
46
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
47
+ 'frame_2/height':
48
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
49
+ 'frame_2/width':
50
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
51
+ 'path':
52
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
53
+ }
54
+
55
+ Usage example:
56
+ python3 -m frame_interpolation.datasets.create_ucf101_tfrecord \
57
+ --input_dir=<root folder of UCF101_results> \
58
+ --output_tfrecord_filepath=<output tfrecord filepath>
59
+ """
60
+
61
+ import os
62
+
63
+ from . import util
64
+ from absl import app
65
+ from absl import flags
66
+ from absl import logging
67
+ import apache_beam as beam
68
+ import tensorflow as tf
69
+
70
+ _INPUT_DIR = flags.DEFINE_string(
71
+ 'input_dir',
72
+ default='/root/path/to/UCF101_results/ucf101_interp_ours',
73
+ help='Path to the root directory of the `UCF101_results` of the UCF101 '
74
+ 'interpolation evaluation data. '
75
+ 'We expect the data to have been downloaded and unzipped. \n'
76
+ 'Folder structures:\n'
77
+ '| raw_UCF101_results/\n'
78
+ '| ucf101_interp_ours/\n'
79
+ '| | 1/\n'
80
+ '| | | frame_00.png\n'
81
+ '| | | frame_01_gt.png\n'
82
+ '| | | frame_01_ours.png\n'
83
+ '| | | frame_02.png\n'
84
+ '| | 2/\n'
85
+ '| | | frame_00.png\n'
86
+ '| | | frame_01_gt.png\n'
87
+ '| | | frame_01_ours.png\n'
88
+ '| | | frame_02.png\n'
89
+ '| | ...\n'
90
+ '| ucf101_sepconv/\n'
91
+ '| ...\n')
92
+
93
+ _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
94
+ 'output_tfrecord_filepath',
95
+ default=None,
96
+ required=True,
97
+ help='Filepath to the output TFRecord file.')
98
+
99
+ _NUM_SHARDS = flags.DEFINE_integer('num_shards',
100
+ default=2,
101
+ help='Number of shards used for the output.')
102
+
103
+ # Image key -> basename for frame interpolator: start / middle / end frames.
104
+ _INTERPOLATOR_IMAGES_MAP = {
105
+ 'frame_0': 'frame_00.png',
106
+ 'frame_1': 'frame_01_gt.png',
107
+ 'frame_2': 'frame_02.png',
108
+ }
109
+
110
+
111
+ def main(unused_argv):
112
+ """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
113
+ # Collect the list of folder paths containing the input and golden frames.
114
+ triplets_list = tf.io.gfile.listdir(_INPUT_DIR.value)
115
+
116
+ triplet_dicts = []
117
+ for triplet in triplets_list:
118
+ triplet_dicts.append({
119
+ image_key: os.path.join(_INPUT_DIR.value, triplet, image_basename)
120
+ for image_key, image_basename in _INTERPOLATOR_IMAGES_MAP.items()
121
+ })
122
+
123
+ p = beam.Pipeline('DirectRunner')
124
+ (p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
125
+ | 'GenerateSingleExample' >> beam.ParDo(
126
+ util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
127
+ | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
128
+ file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
129
+ num_shards=_NUM_SHARDS.value,
130
+ coder=beam.coders.BytesCoder()))
131
+ result = p.run()
132
+ result.wait_until_finish()
133
+
134
+ logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
135
+ _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
136
+
137
+ if __name__ == '__main__':
138
+ app.run(main)
datasets/create_vimeo90K_tfrecord.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Beam pipeline that generates Vimeo-90K (train or test) triplet TFRecords.
16
+
17
+ Vimeo-90K dataset is built upon 5,846 videos downloaded from vimeo.com. The list
18
+ of the original video links are available here:
19
+ https://github.com/anchen1011/toflow/blob/master/data/original_vimeo_links.txt.
20
+ Each video is further cropped into a fixed spatial size of (448 x 256) to create
21
+ 89,000 video clips.
22
+
23
+ The Vimeo-90K dataset is designed for four video processing tasks. This script
24
+ creates the TFRecords of frame triplets for frame interpolation task.
25
+
26
+ Temporal frame interpolation triplet dataset:
27
+ - 73,171 triplets of size (448x256) extracted from 15K subsets of Vimeo-90K.
28
+ - The triplets are pre-split into (train,test) = (51313,3782)
29
+ - Download links:
30
+ Test-set: http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip
31
+ Train+test-set: http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip
32
+
33
+ For more information, see the arXiv paper, project page or the GitHub link.
34
+ @article{xue17toflow,
35
+ author = {Xue, Tianfan and
36
+ Chen, Baian and
37
+ Wu, Jiajun and
38
+ Wei, Donglai and
39
+ Freeman, William T},
40
+ title = {Video Enhancement with Task-Oriented Flow},
41
+ journal = {arXiv},
42
+ year = {2017}
43
+ }
44
+ Project: http://toflow.csail.mit.edu/
45
+ GitHub: https://github.com/anchen1011/toflow
46
+
47
+ Inputs to the script are (1) the directory to the downloaded and unzipped folder
48
+ (2) the filepath of the text-file that lists the subfolders of the triplets.
49
+
50
+ Output TFRecord is a tf.train.Example proto of each image triplet.
51
+ The feature_map takes the form:
52
+ feature_map {
53
+ 'frame_0/encoded':
54
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
55
+ 'frame_0/format':
56
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
57
+ 'frame_0/height':
58
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
59
+ 'frame_0/width':
60
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
61
+ 'frame_1/encoded':
62
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
63
+ 'frame_1/format':
64
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
65
+ 'frame_1/height':
66
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
67
+ 'frame_1/width':
68
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
69
+ 'frame_2/encoded':
70
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
71
+ 'frame_2/format':
72
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
73
+ 'frame_2/height':
74
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
75
+ 'frame_2/width':
76
+ tf.io.FixedLenFeature((), tf.int64, default_value=0)
77
+ 'path':
78
+ tf.io.FixedLenFeature((), tf.string, default_value='')
79
+ }
80
+
81
+ Usage example:
82
+ python3 -m frame_interpolation.datasets.create_vimeo90K_tfrecord \
83
+ --input_dir=<root folder of vimeo90K dataset> \
84
+ --input_triplet_list_filepath=<filepath of tri_{test|train}list.txt> \
85
+ --output_tfrecord_filepath=<output tfrecord filepath>
86
+ """
87
+ import os
88
+
89
+ from . import util
90
+ from absl import app
91
+ from absl import flags
92
+ from absl import logging
93
+ import apache_beam as beam
94
+ import numpy as np
95
+ import tensorflow as tf
96
+
97
+
98
+ _INPUT_DIR = flags.DEFINE_string(
99
+ 'input_dir',
100
+ default='/path/to/raw_vimeo_interp/sequences',
101
+ help='Path to the root directory of the vimeo frame interpolation dataset. '
102
+ 'We expect the data to have been downloaded and unzipped.\n'
103
+ 'Folder structures:\n'
104
+ '| raw_vimeo_dataset/\n'
105
+ '| sequences/\n'
106
+ '| | 00001\n'
107
+ '| | | 0389/\n'
108
+ '| | | | im1.png\n'
109
+ '| | | | im2.png\n'
110
+ '| | | | im3.png\n'
111
+ '| | | ...\n'
112
+ '| | 00002/\n'
113
+ '| | ...\n'
114
+ '| readme.txt\n'
115
+ '| tri_trainlist.txt\n'
116
+ '| tri_testlist.txt \n')
117
+
118
+ _INTPUT_TRIPLET_LIST_FILEPATH = flags.DEFINE_string(
119
+ 'input_triplet_list_filepath',
120
+ default='/path/to/raw_vimeo_dataset/tri_{test|train}list.txt',
121
+ help='Text file containing a list of sub-directories of input triplets.')
122
+
123
+ _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
124
+ 'output_tfrecord_filepath',
125
+ default=None,
126
+ help='Filepath to the output TFRecord file.')
127
+
128
+ _NUM_SHARDS = flags.DEFINE_integer('num_shards',
129
+ default=200, # set to 3 for vimeo_test, and 200 for vimeo_train.
130
+ help='Number of shards used for the output.')
131
+
132
+ # Image key -> basename for frame interpolator: start / middle / end frames.
133
+ _INTERPOLATOR_IMAGES_MAP = {
134
+ 'frame_0': 'im1.png',
135
+ 'frame_1': 'im2.png',
136
+ 'frame_2': 'im3.png',
137
+ }
138
+
139
+
140
+ def main(unused_argv):
141
+ """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
142
+ with tf.io.gfile.GFile(_INTPUT_TRIPLET_LIST_FILEPATH.value, 'r') as fid:
143
+ triplets_list = np.loadtxt(fid, dtype=str)
144
+
145
+ triplet_dicts = []
146
+ for triplet in triplets_list:
147
+ triplet_dict = {
148
+ image_key: os.path.join(_INPUT_DIR.value, triplet, image_basename)
149
+ for image_key, image_basename in _INTERPOLATOR_IMAGES_MAP.items()
150
+ }
151
+ triplet_dicts.append(triplet_dict)
152
+ p = beam.Pipeline('DirectRunner')
153
+ (p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
154
+ | 'GenerateSingleExample' >> beam.ParDo(
155
+ util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
156
+ | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
157
+ file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
158
+ num_shards=_NUM_SHARDS.value,
159
+ coder=beam.coders.BytesCoder()))
160
+ result = p.run()
161
+ result.wait_until_finish()
162
+
163
+ logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
164
+ _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
165
+
166
+ if __name__ == '__main__':
167
+ app.run(main)
datasets/create_xiph_tfrecord.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Beam pipeline that generates Xiph triplet TFRecords.
16
+
17
+ Xiph is a frame sequence dataset commonly used to assess video compression. See
18
+ here: https://media.xiph.org/video/derf/
19
+
20
+ The SoftSplat paper selected eight 4K clips with the most amount of motion and
21
+ extracted the first 100 frames from each clip. Each frame is then either resized
22
+ from 4K to 2K, or a 2K center crop from them is performed before interpolating
23
+ the even frames from the odd frames. These datasets are denoted as `Xiph-2K`
24
+ and `Xiph-4K` respectively. For more information see the project page:
25
+ https://github.com/sniklaus/softmax-splatting
26
+
27
+ Input is the root folder that contains the 800 frames of the eight clips. Set
28
+ center_crop_factor=2 and scale_factor=1 to generate `Xiph-4K`,and scale_factor=2
29
+ , center_crop_factor=1 to generate `Xiph-2K`. The scripts defaults to `Xiph-2K`.
30
+
31
+ Output TFRecord is a tf.train.Example proto of each image triplet.
32
+ The feature_map takes the form:
33
+ feature_map {
34
+ 'frame_0/encoded':
35
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
36
+ 'frame_0/format':
37
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
38
+ 'frame_0/height':
39
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
40
+ 'frame_0/width':
41
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
42
+ 'frame_1/encoded':
43
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
44
+ 'frame_1/format':
45
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
46
+ 'frame_1/height':
47
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
48
+ 'frame_1/width':
49
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
50
+ 'frame_2/encoded':
51
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
52
+ 'frame_2/format':
53
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
54
+ 'frame_2/height':
55
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
56
+ 'frame_2/width':
57
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
58
+ 'path':
59
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
60
+ }
61
+
62
+ Usage example:
63
+ python3 -m frame_interpolation.datasets.create_xiph_tfrecord \
64
+ --input_dir=<root folder of xiph dataset> \
65
+ --scale_factor=<scale factor for image resizing, default=2> \
66
+ --center_crop_factor=<center cropping factor, default=1> \
67
+ --output_tfrecord_filepath=<output tfrecord filepath>
68
+ """
69
+ import os
70
+
71
+ from . import util
72
+ from absl import app
73
+ from absl import flags
74
+ from absl import logging
75
+ import apache_beam as beam
76
+ import tensorflow as tf
77
+
78
+ _INPUT_DIR = flags.DEFINE_string(
79
+ 'input_dir',
80
+ default='/root/path/to/selected/xiph/clips',
81
+ help='Path to the root directory of the `Xiph` interpolation evaluation '
82
+ 'data. We expect the data to have been downloaded and unzipped.')
83
+ _CENTER_CROP_FACTOR = flags.DEFINE_integer(
84
+ 'center_crop_factor',
85
+ default=1,
86
+ help='Factor to center crop image. If set to 2, an image of the same '
87
+ 'resolution as the inputs but half the size is created.')
88
+ _SCALE_FACTOR = flags.DEFINE_integer(
89
+ 'scale_factor',
90
+ default=2,
91
+ help='Factor to downsample frames.')
92
+ _NUM_CLIPS = flags.DEFINE_integer(
93
+ 'num_clips', default=8, help='Number of clips.')
94
+ _NUM_FRAMES = flags.DEFINE_integer(
95
+ 'num_frames', default=100, help='Number of frames per clip.')
96
+ _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
97
+ 'output_tfrecord_filepath',
98
+ default=None,
99
+ required=True,
100
+ help='Filepath to the output TFRecord file.')
101
+ _NUM_SHARDS = flags.DEFINE_integer('num_shards',
102
+ default=2,
103
+ help='Number of shards used for the output.')
104
+
105
+ # Image key -> offset for frame interpolator: start / middle / end frame offset.
106
+ _INTERPOLATOR_IMAGES_MAP = {
107
+ 'frame_0': -1,
108
+ 'frame_1': 0,
109
+ 'frame_2': 1,
110
+ }
111
+
112
+
113
+ def main(unused_argv):
114
+ """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
115
+ # Collect the list of frame filenames.
116
+ frames_list = sorted(tf.io.gfile.listdir(_INPUT_DIR.value))
117
+
118
+ # Collect the triplets, even frames serving as golden to interpolate odds.
119
+ triplets_dict = []
120
+ for clip_index in range(_NUM_CLIPS.value):
121
+ for frame_index in range(1, _NUM_FRAMES.value - 1, 2):
122
+ index = clip_index * _NUM_FRAMES.value + frame_index
123
+ triplet_dict = {
124
+ image_key: os.path.join(_INPUT_DIR.value,
125
+ frames_list[index + image_offset])
126
+ for image_key, image_offset in _INTERPOLATOR_IMAGES_MAP.items()
127
+ }
128
+ triplets_dict.append(triplet_dict)
129
+
130
+ p = beam.Pipeline('DirectRunner')
131
+ (p | 'ReadInputTripletDicts' >> beam.Create(triplets_dict) # pylint: disable=expression-not-assigned
132
+ | 'GenerateSingleExample' >> beam.ParDo(
133
+ util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP, _SCALE_FACTOR.value,
134
+ _CENTER_CROP_FACTOR.value))
135
+ | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
136
+ file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
137
+ num_shards=_NUM_SHARDS.value,
138
+ coder=beam.coders.BytesCoder()))
139
+ result = p.run()
140
+ result.wait_until_finish()
141
+
142
+ logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
143
+ _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
144
+
145
+ if __name__ == '__main__':
146
+ app.run(main)
datasets/util.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions for creating a tf.train.Example proto of image triplets."""
16
+
17
+ import io
18
+ import os
19
+ from typing import Any, List, Mapping, Optional
20
+
21
+ from absl import logging
22
+ import apache_beam as beam
23
+ import numpy as np
24
+ import PIL.Image
25
+ import six
26
+ from skimage import transform
27
+ import tensorflow as tf
28
+
29
+ _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
30
+ _GAMMA = 2.2
31
+
32
+
33
+ def _resample_image(image: np.ndarray, resample_image_width: int,
34
+ resample_image_height: int) -> np.ndarray:
35
+ """Re-samples and returns an `image` to be `resample_image_size`."""
36
+ # Convert image from uint8 gamma [0..255] to float linear [0..1].
37
+ image = image.astype(np.float32) / _UINT8_MAX_F
38
+ image = np.power(np.clip(image, 0, 1), _GAMMA)
39
+
40
+ # Re-size the image
41
+ resample_image_size = (resample_image_height, resample_image_width)
42
+ image = transform.resize_local_mean(image, resample_image_size)
43
+
44
+ # Convert back from float linear [0..1] to uint8 gamma [0..255].
45
+ image = np.power(np.clip(image, 0, 1), 1.0 / _GAMMA)
46
+ image = np.clip(image * _UINT8_MAX_F + 0.5, 0.0,
47
+ _UINT8_MAX_F).astype(np.uint8)
48
+ return image
49
+
50
+
51
+ def generate_image_triplet_example(
52
+ triplet_dict: Mapping[str, str],
53
+ scale_factor: int = 1,
54
+ center_crop_factor: int = 1) -> Optional[tf.train.Example]:
55
+ """Generates and serializes a tf.train.Example proto from an image triplet.
56
+
57
+ Default setting creates a triplet Example with the input images unchanged.
58
+ Images are processed in the order of center-crop then downscale.
59
+
60
+ Args:
61
+ triplet_dict: A dict of image key to filepath of the triplet images.
62
+ scale_factor: An integer scale factor to isotropically downsample images.
63
+ center_crop_factor: An integer cropping factor to center crop images with
64
+ the original resolution but isotropically downsized by the factor.
65
+
66
+ Returns:
67
+ tf.train.Example proto, or None upon error.
68
+
69
+ Raises:
70
+ ValueError if triplet_dict length is different from three or the scale input
71
+ arguments are non-positive.
72
+ """
73
+ if len(triplet_dict) != 3:
74
+ raise ValueError(
75
+ f'Length of triplet_dict must be exactly 3, not {len(triplet_dict)}.')
76
+
77
+ if scale_factor <= 0 or center_crop_factor <= 0:
78
+ raise ValueError(f'(scale_factor, center_crop_factor) must be positive, '
79
+ f'Not ({scale_factor}, {center_crop_factor}).')
80
+
81
+ feature = {}
82
+
83
+ # Keep track of the path where the images came from for debugging purposes.
84
+ mid_frame_path = os.path.dirname(triplet_dict['frame_1'])
85
+ feature['path'] = tf.train.Feature(
86
+ bytes_list=tf.train.BytesList(value=[six.ensure_binary(mid_frame_path)]))
87
+
88
+ for image_key, image_path in triplet_dict.items():
89
+ if not tf.io.gfile.exists(image_path):
90
+ logging.error('File not found: %s', image_path)
91
+ return None
92
+
93
+ # Note: we need both the raw bytes and the image size.
94
+ # PIL.Image does not expose a method to grab the original bytes.
95
+ # (Also it is not aware of non-local file systems.)
96
+ # So we read with tf.io.gfile.GFile to get the bytes, and then wrap the
97
+ # bytes in BytesIO to let PIL.Image open the image.
98
+ try:
99
+ byte_array = tf.io.gfile.GFile(image_path, 'rb').read()
100
+ except tf.errors.InvalidArgumentError:
101
+ logging.exception('Cannot read image file: %s', image_path)
102
+ return None
103
+ try:
104
+ pil_image = PIL.Image.open(io.BytesIO(byte_array))
105
+ except PIL.UnidentifiedImageError:
106
+ logging.exception('Cannot decode image file: %s', image_path)
107
+ return None
108
+ width, height = pil_image.size
109
+ pil_image_format = pil_image.format
110
+
111
+ # Optionally center-crop images and downsize images
112
+ # by `center_crop_factor`.
113
+ if center_crop_factor > 1:
114
+ image = np.array(pil_image)
115
+ quarter_height = image.shape[0] // (2 * center_crop_factor)
116
+ quarter_width = image.shape[1] // (2 * center_crop_factor)
117
+ image = image[quarter_height:-quarter_height,
118
+ quarter_width:-quarter_width, :]
119
+ pil_image = PIL.Image.fromarray(image)
120
+
121
+ # Update image properties.
122
+ height, width, _ = image.shape
123
+ buffer = io.BytesIO()
124
+ try:
125
+ pil_image.save(buffer, format='PNG')
126
+ except OSError:
127
+ logging.exception('Cannot encode image file: %s', image_path)
128
+ return None
129
+ byte_array = buffer.getvalue()
130
+
131
+ # Optionally downsample images by `scale_factor`.
132
+ if scale_factor > 1:
133
+ image = np.array(pil_image)
134
+ image = _resample_image(image, image.shape[1] // scale_factor,
135
+ image.shape[0] // scale_factor)
136
+ pil_image = PIL.Image.fromarray(image)
137
+
138
+ # Update image properties.
139
+ height, width, _ = image.shape
140
+ buffer = io.BytesIO()
141
+ try:
142
+ pil_image.save(buffer, format='PNG')
143
+ except OSError:
144
+ logging.exception('Cannot encode image file: %s', image_path)
145
+ return None
146
+ byte_array = buffer.getvalue()
147
+
148
+ # Create tf Features.
149
+ image_feature = tf.train.Feature(
150
+ bytes_list=tf.train.BytesList(value=[byte_array]))
151
+ height_feature = tf.train.Feature(
152
+ int64_list=tf.train.Int64List(value=[height]))
153
+ width_feature = tf.train.Feature(
154
+ int64_list=tf.train.Int64List(value=[width]))
155
+ encoding = tf.train.Feature(
156
+ bytes_list=tf.train.BytesList(
157
+ value=[six.ensure_binary(pil_image_format.lower())]))
158
+
159
+ # Update feature map.
160
+ feature[f'{image_key}/encoded'] = image_feature
161
+ feature[f'{image_key}/format'] = encoding
162
+ feature[f'{image_key}/height'] = height_feature
163
+ feature[f'{image_key}/width'] = width_feature
164
+
165
+ # Create tf Example.
166
+ features = tf.train.Features(feature=feature)
167
+ example = tf.train.Example(features=features)
168
+ return example
169
+
170
+
171
+ class ExampleGenerator(beam.DoFn):
172
+ """Generate a tf.train.Example per input image triplet filepaths."""
173
+
174
+ def __init__(self,
175
+ images_map: Mapping[str, Any],
176
+ scale_factor: int = 1,
177
+ center_crop_factor: int = 1):
178
+ """Initializes the map of 3 images to add to each tf.train.Example.
179
+
180
+ Args:
181
+ images_map: Map from image key to image filepath.
182
+ scale_factor: A scale factor to downsample frames.
183
+ center_crop_factor: A factor to centercrop and downsize frames.
184
+ """
185
+ super().__init__()
186
+ self._images_map = images_map
187
+ self._scale_factor = scale_factor
188
+ self._center_crop_factor = center_crop_factor
189
+
190
+ def process(self, triplet_dict: Mapping[str, str]) -> List[bytes]:
191
+ """Generates a serialized tf.train.Example for a triplet of images.
192
+
193
+ Args:
194
+ triplet_dict: A dict of image key to filepath of the triplet images.
195
+
196
+ Returns:
197
+ A serialized tf.train.Example proto. No shuffling is applied.
198
+ """
199
+ example = generate_image_triplet_example(triplet_dict, self._scale_factor,
200
+ self._center_crop_factor)
201
+ if example:
202
+ return [example.SerializeToString()]
203
+ else:
204
+ return []
eval/.DS_Store ADDED
Binary file (6.15 kB). View file
 
eval/config/middlebury.gin ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ experiment.name = 'middlebury'
16
+ evaluation.max_examples = -1
17
+ evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18
+ evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3'
eval/config/ucf101.gin ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ experiment.name = 'ucf101'
16
+ evaluation.max_examples = -1
17
+ evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18
+ evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2'
eval/config/vimeo_90K.gin ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ experiment.name = 'vimeo_90K'
16
+ evaluation.max_examples = -1
17
+ evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18
+ evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3'
eval/config/xiph_2K.gin ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ experiment.name = 'xiph_2K'
16
+ evaluation.max_examples = -1
17
+ evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18
+ evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2'
eval/config/xiph_4K.gin ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ experiment.name = 'xiph_4K'
16
+ evaluation.max_examples = -1
17
+ evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18
+ evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2'
eval/eval_cli.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Evaluate the frame interpolation model from a tfrecord and store results.
16
+
17
+ This script runs the inference on examples in a tfrecord and generates images
18
+ and numeric results according to the gin config. For details, see the
19
+ run_evaluation() function below.
20
+
21
+ Usage example:
22
+ python3 -m frame_interpolation.eval.eval_cli -- \
23
+ --gin_config <path to eval_dataset.gin> \
24
+ --base_folder <the root directory to all training sessions> \
25
+ --label < the foldername of the training session>
26
+
27
+ or
28
+
29
+ python3 -m frame_interpolation.eval.eval_cli -- \
30
+ --gin_config <path to eval_dataset.gin> \
31
+ --model_path <The filepath of the TF2 saved model>
32
+
33
+ The output is saved at the parent directory of the `model_path`:
34
+ <parent directory of model_path>/batch_eval.
35
+
36
+ The evaluation is run on a GPU by default. Add the `--mode` argument for others.
37
+ """
38
+ import collections
39
+ import os
40
+ from typing import Any, Dict
41
+
42
+ from . import util
43
+ from absl import app
44
+ from absl import flags
45
+ from absl import logging
46
+ import gin.tf
47
+ from ..losses import losses
48
+ import numpy as np
49
+ import tensorflow as tf
50
+ from ..training import data_lib
51
+
52
+
53
+ _GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
54
+ _LABEL = flags.DEFINE_string(
55
+ 'label', None, 'Descriptive label for the training session to eval.')
56
+ _BASE_FOLDER = flags.DEFINE_string('base_folder', None,
57
+ 'Root folder of training sessions.')
58
+ _MODEL_PATH = flags.DEFINE_string(
59
+ name='model_path',
60
+ default=None,
61
+ help='The path of the TF2 saved model to use. If _MODEL_PATH argument is '
62
+ 'directly specified, _LABEL and _BASE_FOLDER arguments will be ignored.')
63
+ _OUTPUT_FRAMES = flags.DEFINE_boolean(
64
+ name='output_frames',
65
+ default=False,
66
+ help='If true, saves the the inputs, groud-truth and interpolated frames.')
67
+ _MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
68
+ 'Device to run evaluations.')
69
+
70
+
71
+ @gin.configurable('experiment')
72
+ def _get_experiment_config(name) -> Dict[str, Any]:
73
+ """Fetches the gin config."""
74
+ return {
75
+ 'name': name,
76
+ }
77
+
78
+
79
+ def _set_visible_devices():
80
+ """Set the visible devices according to running mode."""
81
+ mode_devices = tf.config.list_physical_devices(_MODE.value.upper())
82
+ tf.config.set_visible_devices([], 'GPU')
83
+ tf.config.set_visible_devices([], 'TPU')
84
+ tf.config.set_visible_devices(mode_devices, _MODE.value.upper())
85
+ return
86
+
87
+
88
+ @gin.configurable('evaluation')
89
+ def run_evaluation(model_path, tfrecord, output_dir, max_examples, metrics):
90
+ """Runs the eval loop for examples in the tfrecord.
91
+
92
+ The evaluation is run for the first 'max_examples' number of examples, and
93
+ resulting images are stored into the given output_dir. Any tensor that
94
+ appears like an image is stored with its name -- this may include intermediate
95
+ results, depending on what the model outputs.
96
+
97
+ Additionally, numeric results are stored into results.csv file within the same
98
+ directory. This includes per-example metrics and the mean across the whole
99
+ dataset.
100
+
101
+ Args:
102
+ model_path: Directory TF2 saved model.
103
+ tfrecord: Directory to the tfrecord eval data.
104
+ output_dir: Directory to store the results into.
105
+ max_examples: Maximum examples to evaluate.
106
+ metrics: The names of loss functions to use.
107
+ """
108
+ model = tf.saved_model.load(model_path)
109
+
110
+ # Store a 'readme.txt' that contains information on where the data came from.
111
+ with tf.io.gfile.GFile(os.path.join(output_dir, 'readme.txt'), mode='w') as f:
112
+ print('Results for:', file=f)
113
+ print(f' model: {model_path}', file=f)
114
+ print(f' tfrecord: {tfrecord}', file=f)
115
+
116
+ with tf.io.gfile.GFile(
117
+ os.path.join(output_dir, 'results.csv'), mode='w') as csv_file:
118
+ test_losses = losses.test_losses(metrics, [
119
+ 1.0,
120
+ ] * len(metrics))
121
+ title_row = ['key'] + list(test_losses)
122
+ print(', '.join(title_row), file=csv_file)
123
+
124
+ datasets = data_lib.create_eval_datasets(
125
+ batch_size=1,
126
+ files=[tfrecord],
127
+ names=[os.path.basename(output_dir)],
128
+ max_examples=max_examples)
129
+ dataset = datasets[os.path.basename(output_dir)]
130
+
131
+ all_losses = collections.defaultdict(list)
132
+ for example in dataset:
133
+ inputs = {
134
+ 'x0': example['x0'],
135
+ 'x1': example['x1'],
136
+ 'time': example['time'][..., tf.newaxis],
137
+ }
138
+ prediction = model(inputs, training=False)
139
+
140
+ # Get the key from encoded mid-frame path.
141
+ path = example['path'][0].numpy().decode('utf-8')
142
+ key = path.rsplit('.', 1)[0].rsplit(os.sep)[-1]
143
+
144
+ # Combines both inputs and outputs into a single dictionary:
145
+ combined = {**prediction, **example} if _OUTPUT_FRAMES.value else {}
146
+ for name in combined:
147
+ image = combined[name]
148
+ if isinstance(image, tf.Tensor):
149
+ # This saves any tensor that has a shape that can be interpreted
150
+ # as an image, e.g. (1, H, W, C), where the batch dimension is always
151
+ # 1, H and W are the image height and width, and C is either 1 or 3
152
+ # (grayscale or color image).
153
+ if len(image.shape) == 4 and (image.shape[-1] == 1 or
154
+ image.shape[-1] == 3):
155
+ util.write_image(
156
+ os.path.join(output_dir, f'{key}_{name}.png'), image[0].numpy())
157
+
158
+ # Evaluate losses if the dataset has ground truth 'y', otherwise just do
159
+ # a visual eval.
160
+ if 'y' in example:
161
+ loss_values = []
162
+ # Clip interpolator output to the range [0,1]. Clipping is done only
163
+ # on the eval loop to get better metrics, but not on the training loop
164
+ # so gradients are not killed.
165
+ prediction['image'] = tf.clip_by_value(prediction['image'], 0., 1.)
166
+ for loss_name, (loss_value_fn, loss_weight_fn) in test_losses.items():
167
+ loss_value = loss_value_fn(example, prediction) * loss_weight_fn(0)
168
+ loss_values.append(loss_value.numpy())
169
+ all_losses[loss_name].append(loss_value.numpy())
170
+ print(f'{key}, {str(loss_values)[1:-1]}', file=csv_file)
171
+
172
+ if all_losses:
173
+ totals = [np.mean(all_losses[loss_name]) for loss_name in test_losses]
174
+ print(f'mean, {str(totals)[1:-1]}', file=csv_file)
175
+ totals_dict = {
176
+ loss_name: np.mean(all_losses[loss_name]) for loss_name in test_losses
177
+ }
178
+ logging.info('mean, %s', totals_dict)
179
+
180
+
181
+ def main(argv):
182
+ if len(argv) > 1:
183
+ raise app.UsageError('Too many command-line arguments.')
184
+
185
+ if _MODEL_PATH.value is not None:
186
+ model_path = _MODEL_PATH.value
187
+ else:
188
+ model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'saved_model')
189
+
190
+ gin.parse_config_files_and_bindings(
191
+ config_files=[_GIN_CONFIG.value],
192
+ bindings=None,
193
+ skip_unknown=True)
194
+
195
+ config = _get_experiment_config() # pylint: disable=no-value-for-parameter
196
+ eval_name = config['name']
197
+ output_dir = os.path.join(
198
+ os.path.dirname(model_path), 'batch_eval', eval_name)
199
+ logging.info('Creating output_dir @ %s ...', output_dir)
200
+
201
+ # Copy config file to <base_folder>/<label>/batch_eval/<eval_name>/config.gin.
202
+ tf.io.gfile.makedirs(output_dir)
203
+ tf.io.gfile.copy(
204
+ _GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)
205
+
206
+ _set_visible_devices()
207
+ logging.info('Evaluating %s on %s ...', eval_name, [
208
+ el.name.split('/physical_device:')[-1]
209
+ for el in tf.config.get_visible_devices()
210
+ ])
211
+ run_evaluation(model_path=model_path, output_dir=output_dir) # pylint: disable=no-value-for-parameter
212
+
213
+ logging.info('Done. Evaluations saved @ %s.', output_dir)
214
+
215
+ if __name__ == '__main__':
216
+ app.run(main)
eval/interpolator.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """A wrapper class for running a frame interpolation TF2 saved model.
16
+
17
+ Usage:
18
+ model_path='/tmp/saved_model/'
19
+ it = Interpolator(model_path)
20
+ result_batch = it.interpolate(image_batch_0, image_batch_1, batch_dt)
21
+
22
+ Where image_batch_1 and image_batch_2 are numpy tensors with TF standard
23
+ (B,H,W,C) layout, batch_dt is the sub-frame time in range [0,1], (B,) layout.
24
+ """
25
+ from typing import List, Optional
26
+ import numpy as np
27
+ import tensorflow as tf
28
+
29
+
30
+ def _pad_to_align(x, align):
31
+ """Pad image batch x so width and height divide by align.
32
+
33
+ Args:
34
+ x: Image batch to align.
35
+ align: Number to align to.
36
+
37
+ Returns:
38
+ 1) An image padded so width % align == 0 and height % align == 0.
39
+ 2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
40
+ to undo the padding.
41
+ """
42
+ # Input checking.
43
+ assert np.ndim(x) == 4
44
+ assert align > 0, 'align must be a positive number.'
45
+
46
+ height, width = x.shape[-3:-1]
47
+ height_to_pad = (align - height % align) if height % align != 0 else 0
48
+ width_to_pad = (align - width % align) if width % align != 0 else 0
49
+
50
+ bbox_to_pad = {
51
+ 'offset_height': height_to_pad // 2,
52
+ 'offset_width': width_to_pad // 2,
53
+ 'target_height': height + height_to_pad,
54
+ 'target_width': width + width_to_pad
55
+ }
56
+ padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
57
+ bbox_to_crop = {
58
+ 'offset_height': height_to_pad // 2,
59
+ 'offset_width': width_to_pad // 2,
60
+ 'target_height': height,
61
+ 'target_width': width
62
+ }
63
+ return padded_x, bbox_to_crop
64
+
65
+
66
+ def image_to_patches(image: np.ndarray, block_shape: List[int]) -> np.ndarray:
67
+ """Folds an image into patches and stacks along the batch dimension.
68
+
69
+ Args:
70
+ image: The input image of shape [B, H, W, C].
71
+ block_shape: The number of patches along the height and width to extract.
72
+ Each patch is shaped (H/block_shape[0], W/block_shape[1])
73
+
74
+ Returns:
75
+ The extracted patches shaped [num_blocks, patch_height, patch_width,...],
76
+ with num_blocks = block_shape[0] * block_shape[1].
77
+ """
78
+ block_height, block_width = block_shape
79
+ num_blocks = block_height * block_width
80
+
81
+ height, width, channel = image.shape[-3:]
82
+ patch_height, patch_width = height//block_height, width//block_width
83
+
84
+ assert height == (
85
+ patch_height * block_height
86
+ ), 'block_height=%d should evenly divide height=%d.'%(block_height, height)
87
+ assert width == (
88
+ patch_width * block_width
89
+ ), 'block_width=%d should evenly divide width=%d.'%(block_width, width)
90
+
91
+ patch_size = patch_height * patch_width
92
+ paddings = 2*[[0, 0]]
93
+
94
+ patches = tf.space_to_batch(image, [patch_height, patch_width], paddings)
95
+ patches = tf.split(patches, patch_size, 0)
96
+ patches = tf.stack(patches, axis=3)
97
+ patches = tf.reshape(patches,
98
+ [num_blocks, patch_height, patch_width, channel])
99
+ return patches.numpy()
100
+
101
+
102
+ def patches_to_image(patches: np.ndarray, block_shape: List[int]) -> np.ndarray:
103
+ """Unfolds patches (stacked along batch) into an image.
104
+
105
+ Args:
106
+ patches: The input patches, shaped [num_patches, patch_H, patch_W, C].
107
+ block_shape: The number of patches along the height and width to unfold.
108
+ Each patch assumed to be shaped (H/block_shape[0], W/block_shape[1]).
109
+
110
+ Returns:
111
+ The unfolded image shaped [B, H, W, C].
112
+ """
113
+ block_height, block_width = block_shape
114
+ paddings = 2 * [[0, 0]]
115
+
116
+ patch_height, patch_width, channel = patches.shape[-3:]
117
+ patch_size = patch_height * patch_width
118
+
119
+ patches = tf.reshape(patches,
120
+ [1, block_height, block_width, patch_size, channel])
121
+ patches = tf.split(patches, patch_size, axis=3)
122
+ patches = tf.stack(patches, axis=0)
123
+ patches = tf.reshape(patches,
124
+ [patch_size, block_height, block_width, channel])
125
+ image = tf.batch_to_space(patches, [patch_height, patch_width], paddings)
126
+ return image.numpy()
127
+
128
+
129
+ class Interpolator:
130
+ """A class for generating interpolated frames between two input frames.
131
+
132
+ Uses TF2 saved model format.
133
+ """
134
+
135
+ def __init__(self, model_path: str,
136
+ align: Optional[int] = None,
137
+ block_shape: Optional[List[int]] = None) -> None:
138
+ """Loads a saved model.
139
+
140
+ Args:
141
+ model_path: Path to the saved model. If none are provided, uses the
142
+ default model.
143
+ align: 'If >1, pad the input size so it divides with this before
144
+ inference.'
145
+ block_shape: Number of patches along the (height, width) to sid-divide
146
+ input images.
147
+ """
148
+ self._model = tf.compat.v2.saved_model.load(model_path)
149
+ self._align = align or None
150
+ self._block_shape = block_shape or None
151
+
152
+ def interpolate(self, x0: np.ndarray, x1: np.ndarray,
153
+ dt: np.ndarray) -> np.ndarray:
154
+ """Generates an interpolated frame between given two batches of frames.
155
+
156
+ All input tensors should be np.float32 datatype.
157
+
158
+ Args:
159
+ x0: First image batch. Dimensions: (batch_size, height, width, channels)
160
+ x1: Second image batch. Dimensions: (batch_size, height, width, channels)
161
+ dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
162
+
163
+ Returns:
164
+ The result with dimensions (batch_size, height, width, channels).
165
+ """
166
+ if self._align is not None:
167
+ x0, bbox_to_crop = _pad_to_align(x0, self._align)
168
+ x1, _ = _pad_to_align(x1, self._align)
169
+
170
+ inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
171
+ result = self._model(inputs, training=False)
172
+ image = result['image']
173
+
174
+ if self._align is not None:
175
+ image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
176
+ return image.numpy()
177
+
178
+ def __call__(self, x0: np.ndarray, x1: np.ndarray,
179
+ dt: np.ndarray) -> np.ndarray:
180
+ """Generates an interpolated frame between given two batches of frames.
181
+
182
+ All input tensors should be np.float32 datatype.
183
+
184
+ Args:
185
+ x0: First image batch. Dimensions: (batch_size, height, width, channels)
186
+ x1: Second image batch. Dimensions: (batch_size, height, width, channels)
187
+ dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
188
+
189
+ Returns:
190
+ The result with dimensions (batch_size, height, width, channels).
191
+ """
192
+ if self._block_shape is not None and np.prod(self._block_shape) > 1:
193
+ # Subdivide high-res images into managable non-overlapping patches.
194
+ x0_patches = image_to_patches(x0, self._block_shape)
195
+ x1_patches = image_to_patches(x1, self._block_shape)
196
+
197
+ # Run the interpolator on each patch pair.
198
+ output_patches = []
199
+ for image_0, image_1 in zip(x0_patches, x1_patches):
200
+ mid_patch = self.interpolate(image_0[np.newaxis, ...],
201
+ image_1[np.newaxis, ...], dt)
202
+ output_patches.append(mid_patch)
203
+
204
+ # Reconstruct interpolated image by stitching interpolated patches.
205
+ output_patches = np.concatenate(output_patches, axis=0)
206
+ return patches_to_image(output_patches, self._block_shape)
207
+ else:
208
+ # Invoke the interpolator once.
209
+ return self.interpolate(x0, x1, dt)
eval/interpolator_cli.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Runs the FILM frame interpolator on a pair of frames on beam.
16
+
17
+ This script is used evaluate the output quality of the FILM Tensorflow frame
18
+ interpolator. Optionally, it outputs a video of the interpolated frames.
19
+
20
+ A beam pipeline for invoking the frame interpolator on a set of directories
21
+ identified by a glob (--pattern). Each directory is expected to contain two
22
+ input frames that are the inputs to the frame interpolator. If a directory has
23
+ more than two frames, then each contiguous frame pair is treated as input to
24
+ generate in-between frames.
25
+
26
+ The output video is stored to interpolator.mp4 in each directory. The number of
27
+ frames is determined by --times_to_interpolate, which controls the number of
28
+ times the frame interpolator is invoked. When the number of input frames is 2,
29
+ the number of output frames is 2^times_to_interpolate+1.
30
+
31
+ This expects a directory structure such as:
32
+ <root directory of the eval>/01/frame1.png
33
+ frame2.png
34
+ <root directory of the eval>/02/frame1.png
35
+ frame2.png
36
+ <root directory of the eval>/03/frame1.png
37
+ frame2.png
38
+ ...
39
+
40
+ And will produce:
41
+ <root directory of the eval>/01/interpolated_frames/frame0.png
42
+ frame1.png
43
+ frame2.png
44
+ <root directory of the eval>/02/interpolated_frames/frame0.png
45
+ frame1.png
46
+ frame2.png
47
+ <root directory of the eval>/03/interpolated_frames/frame0.png
48
+ frame1.png
49
+ frame2.png
50
+ ...
51
+
52
+ And optionally will produce:
53
+ <root directory of the eval>/01/interpolated.mp4
54
+ <root directory of the eval>/02/interpolated.mp4
55
+ <root directory of the eval>/03/interpolated.mp4
56
+ ...
57
+
58
+ Usage example:
59
+ python3 -m frame_interpolation.eval.interpolator_cli \
60
+ --model_path <path to TF2 saved model> \
61
+ --pattern "<root directory of the eval>/*" \
62
+ --times_to_interpolate <Number of times to interpolate>
63
+ """
64
+
65
+ import functools
66
+ import os
67
+ from typing import List, Sequence
68
+
69
+ from . import interpolator as interpolator_lib
70
+ from . import util
71
+ from absl import app
72
+ from absl import flags
73
+ from absl import logging
74
+ import apache_beam as beam
75
+ import mediapy as media
76
+ import natsort
77
+ import numpy as np
78
+ import tensorflow as tf
79
+ from tqdm.auto import tqdm
80
+
81
+ # Controls TF_CCP log level.
82
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
83
+
84
+
85
+ _PATTERN = flags.DEFINE_string(
86
+ name='pattern',
87
+ default=None,
88
+ help='The pattern to determine the directories with the input frames.',
89
+ required=True)
90
+ _MODEL_PATH = flags.DEFINE_string(
91
+ name='model_path',
92
+ default=None,
93
+ help='The path of the TF2 saved model to use.')
94
+ _TIMES_TO_INTERPOLATE = flags.DEFINE_integer(
95
+ name='times_to_interpolate',
96
+ default=5,
97
+ help='The number of times to run recursive midpoint interpolation. '
98
+ 'The number of output frames will be 2^times_to_interpolate+1.')
99
+ _FPS = flags.DEFINE_integer(
100
+ name='fps',
101
+ default=30,
102
+ help='Frames per second to play interpolated videos in slow motion.')
103
+ _ALIGN = flags.DEFINE_integer(
104
+ name='align',
105
+ default=64,
106
+ help='If >1, pad the input size so it is evenly divisible by this value.')
107
+ _BLOCK_HEIGHT = flags.DEFINE_integer(
108
+ name='block_height',
109
+ default=1,
110
+ help='An int >= 1, number of patches along height, '
111
+ 'patch_height = height//block_height, should be evenly divisible.')
112
+ _BLOCK_WIDTH = flags.DEFINE_integer(
113
+ name='block_width',
114
+ default=1,
115
+ help='An int >= 1, number of patches along width, '
116
+ 'patch_width = width//block_width, should be evenly divisible.')
117
+ _OUTPUT_VIDEO = flags.DEFINE_boolean(
118
+ name='output_video',
119
+ default=False,
120
+ help='If true, creates a video of the frames in the interpolated_frames/ '
121
+ 'subdirectory')
122
+
123
+ # Add other extensions, if not either.
124
+ _INPUT_EXT = ['png', 'jpg', 'jpeg']
125
+
126
+
127
+ def _output_frames(frames: List[np.ndarray], frames_dir: str):
128
+ """Writes PNG-images to a directory.
129
+
130
+ If frames_dir doesn't exist, it is created. If frames_dir contains existing
131
+ PNG-files, they are removed before saving the new ones.
132
+
133
+ Args:
134
+ frames: List of images to save.
135
+ frames_dir: The output directory to save the images.
136
+
137
+ """
138
+ if tf.io.gfile.isdir(frames_dir):
139
+ old_frames = tf.io.gfile.glob(f'{frames_dir}/frame_*.png')
140
+ if old_frames:
141
+ logging.info('Removing existing frames from %s.', frames_dir)
142
+ for old_frame in old_frames:
143
+ tf.io.gfile.remove(old_frame)
144
+ else:
145
+ tf.io.gfile.makedirs(frames_dir)
146
+ for idx, frame in tqdm(
147
+ enumerate(frames), total=len(frames), ncols=100, colour='green'):
148
+ util.write_image(f'{frames_dir}/frame_{idx:03d}.png', frame)
149
+ logging.info('Output frames saved in %s.', frames_dir)
150
+
151
+
152
+ class ProcessDirectory(beam.DoFn):
153
+ """DoFn for running the interpolator on a single directory at the time."""
154
+
155
+ def setup(self):
156
+ self.interpolator = interpolator_lib.Interpolator(
157
+ _MODEL_PATH.value, _ALIGN.value,
158
+ [_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])
159
+
160
+ if _OUTPUT_VIDEO.value:
161
+ ffmpeg_path = util.get_ffmpeg_path()
162
+ media.set_ffmpeg(ffmpeg_path)
163
+
164
+ def process(self, directory: str):
165
+ input_frames_list = [
166
+ natsort.natsorted(tf.io.gfile.glob(f'{directory}/*.{ext}'))
167
+ for ext in _INPUT_EXT
168
+ ]
169
+ input_frames = functools.reduce(lambda x, y: x + y, input_frames_list)
170
+ logging.info('Generating in-between frames for %s.', directory)
171
+ frames = list(
172
+ util.interpolate_recursively_from_files(
173
+ input_frames, _TIMES_TO_INTERPOLATE.value, self.interpolator))
174
+ _output_frames(frames, f'{directory}/interpolated_frames')
175
+ if _OUTPUT_VIDEO.value:
176
+ media.write_video(f'{directory}/interpolated.mp4', frames, fps=_FPS.value)
177
+ logging.info('Output video saved at %s/interpolated.mp4.', directory)
178
+
179
+
180
+ def _run_pipeline() -> None:
181
+ directories = tf.io.gfile.glob(_PATTERN.value)
182
+ pipeline = beam.Pipeline('DirectRunner')
183
+ (pipeline | 'Create directory names' >> beam.Create(directories) # pylint: disable=expression-not-assigned
184
+ | 'Process directories' >> beam.ParDo(ProcessDirectory()))
185
+
186
+ result = pipeline.run()
187
+ result.wait_until_finish()
188
+
189
+
190
+ def main(argv: Sequence[str]) -> None:
191
+ if len(argv) > 1:
192
+ raise app.UsageError('Too many command-line arguments.')
193
+ _run_pipeline()
194
+
195
+
196
+ if __name__ == '__main__':
197
+ app.run(main)
eval/interpolator_test.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""A test script for mid frame interpolation from two input frames.
16
+
17
+ Usage example:
18
+ python3 -m frame_interpolation.eval.interpolator_test \
19
+ --frame1 <filepath of the first frame> \
20
+ --frame2 <filepath of the second frame> \
21
+ --model_path <The filepath of the TF2 saved model to use>
22
+
23
+ The output is saved to <the directory of the input frames>/output_frame.png. If
24
+ `--output_frame` filepath is provided, it will be used instead.
25
+ """
26
+ import os
27
+ from typing import Sequence
28
+
29
+ from . import interpolator as interpolator_lib
30
+ from . import util
31
+ from absl import app
32
+ from absl import flags
33
+ import numpy as np
34
+
35
+ # Controls TF_CCP log level.
36
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
37
+
38
+
39
+ _FRAME1 = flags.DEFINE_string(
40
+ name='frame1',
41
+ default=None,
42
+ help='The filepath of the first input frame.',
43
+ required=True)
44
+ _FRAME2 = flags.DEFINE_string(
45
+ name='frame2',
46
+ default=None,
47
+ help='The filepath of the second input frame.',
48
+ required=True)
49
+ _MODEL_PATH = flags.DEFINE_string(
50
+ name='model_path',
51
+ default=None,
52
+ help='The path of the TF2 saved model to use.')
53
+ _OUTPUT_FRAME = flags.DEFINE_string(
54
+ name='output_frame',
55
+ default=None,
56
+ help='The output filepath of the interpolated mid-frame.')
57
+ _ALIGN = flags.DEFINE_integer(
58
+ name='align',
59
+ default=64,
60
+ help='If >1, pad the input size so it is evenly divisible by this value.')
61
+ _BLOCK_HEIGHT = flags.DEFINE_integer(
62
+ name='block_height',
63
+ default=1,
64
+ help='An int >= 1, number of patches along height, '
65
+ 'patch_height = height//block_height, should be evenly divisible.')
66
+ _BLOCK_WIDTH = flags.DEFINE_integer(
67
+ name='block_width',
68
+ default=1,
69
+ help='An int >= 1, number of patches along width, '
70
+ 'patch_width = width//block_width, should be evenly divisible.')
71
+
72
+
73
+ def _run_interpolator() -> None:
74
+ """Writes interpolated mid frame from a given two input frame filepaths."""
75
+
76
+ interpolator = interpolator_lib.Interpolator(
77
+ model_path=_MODEL_PATH.value,
78
+ align=_ALIGN.value,
79
+ block_shape=[_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])
80
+
81
+ # First batched image.
82
+ image_1 = util.read_image(_FRAME1.value)
83
+ image_batch_1 = np.expand_dims(image_1, axis=0)
84
+
85
+ # Second batched image.
86
+ image_2 = util.read_image(_FRAME2.value)
87
+ image_batch_2 = np.expand_dims(image_2, axis=0)
88
+
89
+ # Batched time.
90
+ batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
91
+
92
+ # Invoke the model for one mid-frame interpolation.
93
+ mid_frame = interpolator(image_batch_1, image_batch_2, batch_dt)[0]
94
+
95
+ # Write interpolated mid-frame.
96
+ mid_frame_filepath = _OUTPUT_FRAME.value
97
+ if not mid_frame_filepath:
98
+ mid_frame_filepath = f'{os.path.dirname(_FRAME1.value)}/output_frame.png'
99
+ util.write_image(mid_frame_filepath, mid_frame)
100
+
101
+
102
+ def main(argv: Sequence[str]) -> None:
103
+ if len(argv) > 1:
104
+ raise app.UsageError('Too many command-line arguments.')
105
+ _run_interpolator()
106
+
107
+
108
+ if __name__ == '__main__':
109
+ app.run(main)
eval/util.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions for frame interpolation on a set of video frames."""
16
+ import os
17
+ import shutil
18
+ from typing import Generator, Iterable, List, Optional
19
+
20
+ from . import interpolator as interpolator_lib
21
+ import numpy as np
22
+ import tensorflow as tf
23
+ from tqdm import tqdm
24
+
25
+ _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
26
+ _CONFIG_FFMPEG_NAME_OR_PATH = 'ffmpeg'
27
+
28
+
29
+ def read_image(filename: str) -> np.ndarray:
30
+ """Reads an sRgb 8-bit image.
31
+
32
+ Args:
33
+ filename: The input filename to read.
34
+
35
+ Returns:
36
+ A float32 3-channel (RGB) ndarray with colors in the [0..1] range.
37
+ """
38
+ image_data = tf.io.read_file(filename)
39
+ image = tf.io.decode_image(image_data, channels=3)
40
+ image_numpy = tf.cast(image, dtype=tf.float32).numpy()
41
+ return image_numpy / _UINT8_MAX_F
42
+
43
+
44
+ def write_image(filename: str, image: np.ndarray) -> None:
45
+ """Writes a float32 3-channel RGB ndarray image, with colors in range [0..1].
46
+
47
+ Args:
48
+ filename: The output filename to save.
49
+ image: A float32 3-channel (RGB) ndarray with colors in the [0..1] range.
50
+ """
51
+ image_in_uint8_range = np.clip(image * _UINT8_MAX_F, 0.0, _UINT8_MAX_F)
52
+ image_in_uint8 = (image_in_uint8_range + 0.5).astype(np.uint8)
53
+
54
+ extension = os.path.splitext(filename)[1]
55
+ if extension == '.jpg':
56
+ image_data = tf.io.encode_jpeg(image_in_uint8)
57
+ else:
58
+ image_data = tf.io.encode_png(image_in_uint8)
59
+ tf.io.write_file(filename, image_data)
60
+
61
+
62
+ def _recursive_generator(
63
+ frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
64
+ interpolator: interpolator_lib.Interpolator,
65
+ bar: Optional[tqdm] = None
66
+ ) -> Generator[np.ndarray, None, None]:
67
+ """Splits halfway to repeatedly generate more frames.
68
+
69
+ Args:
70
+ frame1: Input image 1.
71
+ frame2: Input image 2.
72
+ num_recursions: How many times to interpolate the consecutive image pairs.
73
+ interpolator: The frame interpolator instance.
74
+
75
+ Yields:
76
+ The interpolated frames, including the first frame (frame1), but excluding
77
+ the final frame2.
78
+ """
79
+ if num_recursions == 0:
80
+ yield frame1
81
+ else:
82
+ # Adds the batch dimension to all inputs before calling the interpolator,
83
+ # and remove it afterwards.
84
+ time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
85
+ mid_frame = interpolator(frame1[np.newaxis, ...], frame2[np.newaxis, ...],
86
+ time)[0]
87
+ bar.update(1) if bar is not None else bar
88
+ yield from _recursive_generator(frame1, mid_frame, num_recursions - 1,
89
+ interpolator, bar)
90
+ yield from _recursive_generator(mid_frame, frame2, num_recursions - 1,
91
+ interpolator, bar)
92
+
93
+
94
+ def interpolate_recursively_from_files(
95
+ frames: List[str], times_to_interpolate: int,
96
+ interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
97
+ """Generates interpolated frames by repeatedly interpolating the midpoint.
98
+
99
+ Loads the files on demand and uses the yield paradigm to return the frames
100
+ to allow streamed processing of longer videos.
101
+
102
+ Recursive interpolation is useful if the interpolator is trained to predict
103
+ frames at midpoint only and is thus expected to perform poorly elsewhere.
104
+
105
+ Args:
106
+ frames: List of input frames. Expected shape (H, W, 3). The colors should be
107
+ in the range[0, 1] and in gamma space.
108
+ times_to_interpolate: Number of times to do recursive midpoint
109
+ interpolation.
110
+ interpolator: The frame interpolation model to use.
111
+
112
+ Yields:
113
+ The interpolated frames (including the inputs).
114
+ """
115
+ n = len(frames)
116
+ num_frames = (n - 1) * (2**(times_to_interpolate) - 1)
117
+ bar = tqdm(total=num_frames, ncols=100, colour='green')
118
+ for i in range(1, n):
119
+ yield from _recursive_generator(
120
+ read_image(frames[i - 1]), read_image(frames[i]), times_to_interpolate,
121
+ interpolator, bar)
122
+ # Separately yield the final frame.
123
+ yield read_image(frames[-1])
124
+
125
+ def interpolate_recursively_from_memory(
126
+ frames: List[np.ndarray], times_to_interpolate: int,
127
+ interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
128
+ """Generates interpolated frames by repeatedly interpolating the midpoint.
129
+
130
+ This is functionally equivalent to interpolate_recursively_from_files(), but
131
+ expects the inputs frames in memory, instead of loading them on demand.
132
+
133
+ Recursive interpolation is useful if the interpolator is trained to predict
134
+ frames at midpoint only and is thus expected to perform poorly elsewhere.
135
+
136
+ Args:
137
+ frames: List of input frames. Expected shape (H, W, 3). The colors should be
138
+ in the range[0, 1] and in gamma space.
139
+ times_to_interpolate: Number of times to do recursive midpoint
140
+ interpolation.
141
+ interpolator: The frame interpolation model to use.
142
+
143
+ Yields:
144
+ The interpolated frames (including the inputs).
145
+ """
146
+ n = len(frames)
147
+ num_frames = (n - 1) * (2**(times_to_interpolate) - 1)
148
+ bar = tqdm(total=num_frames, ncols=100, colour='green')
149
+ for i in range(1, n):
150
+ yield from _recursive_generator(frames[i - 1], frames[i],
151
+ times_to_interpolate, interpolator, bar)
152
+ # Separately yield the final frame.
153
+ yield frames[-1]
154
+
155
+
156
+ def get_ffmpeg_path() -> str:
157
+ path = shutil.which(_CONFIG_FFMPEG_NAME_OR_PATH)
158
+ if not path:
159
+ raise RuntimeError(
160
+ f"Program '{_CONFIG_FFMPEG_NAME_OR_PATH}' is not found;"
161
+ " perhaps install ffmpeg using 'apt-get install ffmpeg'.")
162
+ return path
losses/losses.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Loss functions used to train the FILM interpolation model.
16
+
17
+ The losses for training and test loops are configurable via gin. Training can
18
+ use more than one loss function. Test loop can also evaluate one ore more loss
19
+ functions, each of which can be summarized separately.
20
+ """
21
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
22
+
23
+ from . import vgg19_loss as vgg19
24
+ import gin.tf
25
+ import numpy as np
26
+ import tensorflow as tf
27
+
28
+
29
+ @gin.configurable('vgg', denylist=['example', 'prediction'])
30
+ def vgg_loss(example: Mapping[str, tf.Tensor],
31
+ prediction: Mapping[str, tf.Tensor],
32
+ vgg_model_file: str,
33
+ weights: Optional[List[float]] = None) -> tf.Tensor:
34
+ """Perceptual loss for images in [0,1] color range.
35
+
36
+ Args:
37
+ example: A dictionary with the ground truth image as 'y'.
38
+ prediction: The prediction dictionary with the image as 'image'.
39
+ vgg_model_file: The path containing the vgg19 weights in MATLAB format.
40
+ weights: An optional array of weights for different VGG layers. If None, the
41
+ default weights are used (see vgg19.vgg_loss documentation).
42
+
43
+ Returns:
44
+ The perceptual loss.
45
+ """
46
+ return vgg19.vgg_loss(prediction['image'], example['y'], vgg_model_file,
47
+ weights)
48
+
49
+
50
+ @gin.configurable('style', denylist=['example', 'prediction'])
51
+ def style_loss(example: Mapping[str, tf.Tensor],
52
+ prediction: Mapping[str, tf.Tensor],
53
+ vgg_model_file: str,
54
+ weights: Optional[List[float]] = None) -> tf.Tensor:
55
+ """Computes style loss from images in [0..1] color range.
56
+
57
+ Args:
58
+ example: A dictionary with the ground truth image as 'y'.
59
+ prediction: The prediction dictionary with the image as 'image'.
60
+ vgg_model_file: The path containing the vgg19 weights in MATLAB format.
61
+ weights: An optional array of weights for different VGG layers. If None, the
62
+ default weights are used (see vgg19.vgg_loss documentation).
63
+
64
+ Returns:
65
+ A tf.Tensor of a scalar representing the style loss computed over multiple
66
+ vgg layer features.
67
+ """
68
+ return vgg19.style_loss(prediction['image'], example['y'], vgg_model_file,
69
+ weights)
70
+
71
+
72
+ def l1_loss(example: Mapping[str, tf.Tensor],
73
+ prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
74
+ return tf.reduce_mean(tf.abs(prediction['image'] - example['y']))
75
+
76
+
77
+ def l1_warped_loss(example: Mapping[str, tf.Tensor],
78
+ prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
79
+ """Computes an l1 loss using only warped images.
80
+
81
+ Args:
82
+ example: A dictionary with the ground truth image as 'y'.
83
+ prediction: The prediction dictionary with the image(s) as 'x0_warped'
84
+ and/or 'x1_warped'.
85
+
86
+ Returns:
87
+ A tf.Tensor of a scalar representing the linear combination of l1 losses
88
+ between prediction images and y.
89
+ """
90
+ loss = tf.constant(0.0, dtype=tf.float32)
91
+ if 'x0_warped' in prediction:
92
+ loss += tf.reduce_mean(tf.abs(prediction['x0_warped'] - example['y']))
93
+ if 'x1_warped' in prediction:
94
+ loss += tf.reduce_mean(tf.abs(prediction['x1_warped'] - example['y']))
95
+ return loss
96
+
97
+
98
+ def l2_loss(example: Mapping[str, tf.Tensor],
99
+ prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
100
+ return tf.reduce_mean(tf.square(prediction['image'] - example['y']))
101
+
102
+
103
+ def ssim_loss(example: Mapping[str, tf.Tensor],
104
+ prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
105
+ image = prediction['image']
106
+ y = example['y']
107
+ return tf.reduce_mean(tf.image.ssim(image, y, max_val=1.0))
108
+
109
+
110
+ def psnr_loss(example: Mapping[str, tf.Tensor],
111
+ prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
112
+ return tf.reduce_mean(
113
+ tf.image.psnr(prediction['image'], example['y'], max_val=1.0))
114
+
115
+
116
+ def get_loss(loss_name: str) -> Callable[[Any, Any], tf.Tensor]:
117
+ """Returns the loss function corresponding to the given name."""
118
+ if loss_name == 'l1':
119
+ return l1_loss
120
+ elif loss_name == 'l2':
121
+ return l2_loss
122
+ elif loss_name == 'ssim':
123
+ return ssim_loss
124
+ elif loss_name == 'vgg':
125
+ return vgg_loss
126
+ elif loss_name == 'style':
127
+ return style_loss
128
+ elif loss_name == 'psnr':
129
+ return psnr_loss
130
+ elif loss_name == 'l1_warped':
131
+ return l1_warped_loss
132
+ else:
133
+ raise ValueError('Invalid loss function %s' % loss_name)
134
+
135
+
136
+ # pylint: disable=unnecessary-lambda
137
+ def get_loss_op(loss_name):
138
+ """Returns a function for creating a loss calculation op."""
139
+ loss = get_loss(loss_name)
140
+ return lambda example, prediction: loss(example, prediction)
141
+
142
+
143
+ def get_weight_op(weight_schedule):
144
+ """Returns a function for creating an iteration dependent loss weight op."""
145
+ return lambda iterations: weight_schedule(iterations)
146
+
147
+
148
+ def create_losses(
149
+ loss_names: List[str], loss_weight_schedules: List[
150
+ tf.keras.optimizers.schedules.LearningRateSchedule]
151
+ ) -> Dict[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
152
+ tf.Tensor]]]:
153
+ """Returns a dictionary of functions for creating loss and loss_weight ops.
154
+
155
+ As an example, create_losses(['l1', 'l2'], [PiecewiseConstantDecay(),
156
+ PiecewiseConstantDecay()]) returns a dictionary with two keys, and each value
157
+ being a tuple of ops for loss calculation and loss_weight sampling.
158
+
159
+ Args:
160
+ loss_names: Names of the losses.
161
+ loss_weight_schedules: Instances of loss weight schedules.
162
+
163
+ Returns:
164
+ A dictionary that contains the loss and weight schedule ops keyed by the
165
+ names.
166
+ """
167
+ losses = dict()
168
+ for name, weight_schedule in zip(loss_names, loss_weight_schedules):
169
+ unique_values = np.unique(weight_schedule.values)
170
+ if len(unique_values) == 1 and unique_values[0] == 1.0:
171
+ # Special case 'no weight' for prettier TensorBoard summaries.
172
+ weighted_name = name
173
+ else:
174
+ # Weights are variable/scheduled, a constant "k" is used to
175
+ # indicate weights are iteration dependent.
176
+ weighted_name = 'k*' + name
177
+ losses[weighted_name] = (get_loss_op(name), get_weight_op(weight_schedule))
178
+ return losses
179
+
180
+
181
+ @gin.configurable
182
+ def training_losses(
183
+ loss_names: List[str],
184
+ loss_weights: Optional[List[float]] = None,
185
+ loss_weight_schedules: Optional[List[
186
+ tf.keras.optimizers.schedules.LearningRateSchedule]] = None,
187
+ loss_weight_parameters: Optional[List[Mapping[str, List[Any]]]] = None
188
+ ) -> Mapping[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
189
+ tf.Tensor]]]:
190
+ """Creates the training loss functions and loss weight schedules."""
191
+ weight_schedules = []
192
+ if not loss_weights:
193
+ for weight_schedule, weight_parameters in zip(loss_weight_schedules,
194
+ loss_weight_parameters):
195
+ weight_schedules.append(weight_schedule(**weight_parameters))
196
+ else:
197
+ for loss_weight in loss_weights:
198
+ weight_parameters = {
199
+ 'boundaries': [0],
200
+ 'values': 2 * [
201
+ loss_weight,
202
+ ]
203
+ }
204
+ weight_schedules.append(
205
+ tf.keras.optimizers.schedules.PiecewiseConstantDecay(
206
+ **weight_parameters))
207
+
208
+ return create_losses(loss_names, weight_schedules)
209
+
210
+
211
+ @gin.configurable
212
+ def test_losses(
213
+ loss_names: List[str],
214
+ loss_weights: Optional[List[float]] = None,
215
+ loss_weight_schedules: Optional[List[
216
+ tf.keras.optimizers.schedules.LearningRateSchedule]] = None,
217
+ loss_weight_parameters: Optional[List[Mapping[str, List[Any]]]] = None
218
+ ) -> Mapping[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
219
+ tf.Tensor]]]:
220
+ """Creates the test loss functions and loss weight schedules."""
221
+ weight_schedules = []
222
+ if not loss_weights:
223
+ for weight_schedule, weight_parameters in zip(loss_weight_schedules,
224
+ loss_weight_parameters):
225
+ weight_schedules.append(weight_schedule(**weight_parameters))
226
+ else:
227
+ for loss_weight in loss_weights:
228
+ weight_parameters = {
229
+ 'boundaries': [0],
230
+ 'values': 2 * [
231
+ loss_weight,
232
+ ]
233
+ }
234
+ weight_schedules.append(
235
+ tf.keras.optimizers.schedules.PiecewiseConstantDecay(
236
+ **weight_parameters))
237
+
238
+ return create_losses(loss_names, weight_schedules)
239
+
240
+
241
+ def aggregate_batch_losses(
242
+ batch_losses: List[Mapping[str, float]]) -> Mapping[str, float]:
243
+ """Averages per batch losses into single dictionary for the whole epoch.
244
+
245
+ As an example, if the batch_losses contained per batch losses:
246
+ batch_losses = { {'l1': 0.2, 'ssim': 0.9}, {'l1': 0.3, 'ssim': 0.8}}
247
+ The returned dictionary would look like: { 'l1': 0.25, 'ssim': 0.95 }
248
+
249
+ Args:
250
+ batch_losses: A list of dictionary objects, with one entry for each loss.
251
+
252
+ Returns:
253
+ Single dictionary with the losses aggregated.
254
+ """
255
+ transp_losses = {}
256
+ # Loop through all losses
257
+ for batch_loss in batch_losses:
258
+ # Loop through per batch losses of a single type:
259
+ for loss_name, loss in batch_loss.items():
260
+ if loss_name not in transp_losses:
261
+ transp_losses[loss_name] = []
262
+ transp_losses[loss_name].append(loss)
263
+ aggregate_losses = {}
264
+ for loss_name in transp_losses:
265
+ aggregate_losses[loss_name] = np.mean(transp_losses[loss_name])
266
+ return aggregate_losses
losses/vgg19_loss.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Feature loss based on 19 layer VGG network.
16
+
17
+
18
+ The network layers in the feature loss is weighted as described in
19
+ 'Stereo Magnification: Learning View Synthesis using Multiplane Images',
20
+ Tinghui Zhou, Richard Tucker, Flynn, Graham Fyffe, Noah Snavely, SIGGRAPH 2018.
21
+ """
22
+
23
+ from typing import Any, Callable, Dict, Optional, Sequence, Tuple
24
+
25
+ import numpy as np
26
+ import scipy.io as sio
27
+ import tensorflow.compat.v1 as tf
28
+
29
+
30
+ def _build_net(layer_type: str,
31
+ input_tensor: tf.Tensor,
32
+ weight_bias: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
33
+ name: Optional[str] = None) -> Callable[[Any], Any]:
34
+ """Build a layer of the VGG network.
35
+
36
+ Args:
37
+ layer_type: A string, type of this layer.
38
+ input_tensor: A tensor.
39
+ weight_bias: A tuple of weight and bias.
40
+ name: A string, name of this layer.
41
+
42
+ Returns:
43
+ A callable function of the tensorflow layer.
44
+
45
+ Raises:
46
+ ValueError: If layer_type is not conv or pool.
47
+ """
48
+
49
+ if layer_type == 'conv':
50
+ return tf.nn.relu(
51
+ tf.nn.conv2d(
52
+ input_tensor,
53
+ weight_bias[0],
54
+ strides=[1, 1, 1, 1],
55
+ padding='SAME',
56
+ name=name) + weight_bias[1])
57
+ elif layer_type == 'pool':
58
+ return tf.nn.avg_pool(
59
+ input_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
60
+ else:
61
+ raise ValueError('Unsupported layer %s' % layer_type)
62
+
63
+
64
+ def _get_weight_and_bias(vgg_layers: np.ndarray,
65
+ index: int) -> Tuple[tf.Tensor, tf.Tensor]:
66
+ """Get the weight and bias of a specific layer from the VGG pretrained model.
67
+
68
+ Args:
69
+ vgg_layers: An array, the VGG pretrained model.
70
+ index: An integer, index of the layer.
71
+
72
+ Returns:
73
+ weights: A tensor.
74
+ bias: A tensor.
75
+ """
76
+
77
+ weights = vgg_layers[index][0][0][2][0][0]
78
+ weights = tf.constant(weights)
79
+ bias = vgg_layers[index][0][0][2][0][1]
80
+ bias = tf.constant(np.reshape(bias, (bias.size)))
81
+
82
+ return weights, bias
83
+
84
+
85
+ def _build_vgg19(image: tf.Tensor, model_filepath: str) -> Dict[str, tf.Tensor]:
86
+ """Builds the VGG network given the model weights.
87
+
88
+ The weights are loaded only for the first time this code is invoked.
89
+
90
+ Args:
91
+ image: A tensor, input image.
92
+ model_filepath: A string, path to the VGG pretrained model.
93
+
94
+ Returns:
95
+ net: A dict mapping a layer name to a tensor.
96
+ """
97
+
98
+ with tf.variable_scope('vgg', reuse=True):
99
+ net = {}
100
+ if not hasattr(_build_vgg19, 'vgg_rawnet'):
101
+ with tf.io.gfile.GFile(model_filepath, 'rb') as f:
102
+ _build_vgg19.vgg_rawnet = sio.loadmat(f)
103
+ vgg_layers = _build_vgg19.vgg_rawnet['layers'][0]
104
+ imagenet_mean = tf.constant([123.6800, 116.7790, 103.9390],
105
+ shape=[1, 1, 1, 3])
106
+ net['input'] = image - imagenet_mean
107
+ net['conv1_1'] = _build_net(
108
+ 'conv',
109
+ net['input'],
110
+ _get_weight_and_bias(vgg_layers, 0),
111
+ name='vgg_conv1_1')
112
+ net['conv1_2'] = _build_net(
113
+ 'conv',
114
+ net['conv1_1'],
115
+ _get_weight_and_bias(vgg_layers, 2),
116
+ name='vgg_conv1_2')
117
+ net['pool1'] = _build_net('pool', net['conv1_2'])
118
+ net['conv2_1'] = _build_net(
119
+ 'conv',
120
+ net['pool1'],
121
+ _get_weight_and_bias(vgg_layers, 5),
122
+ name='vgg_conv2_1')
123
+ net['conv2_2'] = _build_net(
124
+ 'conv',
125
+ net['conv2_1'],
126
+ _get_weight_and_bias(vgg_layers, 7),
127
+ name='vgg_conv2_2')
128
+ net['pool2'] = _build_net('pool', net['conv2_2'])
129
+ net['conv3_1'] = _build_net(
130
+ 'conv',
131
+ net['pool2'],
132
+ _get_weight_and_bias(vgg_layers, 10),
133
+ name='vgg_conv3_1')
134
+ net['conv3_2'] = _build_net(
135
+ 'conv',
136
+ net['conv3_1'],
137
+ _get_weight_and_bias(vgg_layers, 12),
138
+ name='vgg_conv3_2')
139
+ net['conv3_3'] = _build_net(
140
+ 'conv',
141
+ net['conv3_2'],
142
+ _get_weight_and_bias(vgg_layers, 14),
143
+ name='vgg_conv3_3')
144
+ net['conv3_4'] = _build_net(
145
+ 'conv',
146
+ net['conv3_3'],
147
+ _get_weight_and_bias(vgg_layers, 16),
148
+ name='vgg_conv3_4')
149
+ net['pool3'] = _build_net('pool', net['conv3_4'])
150
+ net['conv4_1'] = _build_net(
151
+ 'conv',
152
+ net['pool3'],
153
+ _get_weight_and_bias(vgg_layers, 19),
154
+ name='vgg_conv4_1')
155
+ net['conv4_2'] = _build_net(
156
+ 'conv',
157
+ net['conv4_1'],
158
+ _get_weight_and_bias(vgg_layers, 21),
159
+ name='vgg_conv4_2')
160
+ net['conv4_3'] = _build_net(
161
+ 'conv',
162
+ net['conv4_2'],
163
+ _get_weight_and_bias(vgg_layers, 23),
164
+ name='vgg_conv4_3')
165
+ net['conv4_4'] = _build_net(
166
+ 'conv',
167
+ net['conv4_3'],
168
+ _get_weight_and_bias(vgg_layers, 25),
169
+ name='vgg_conv4_4')
170
+ net['pool4'] = _build_net('pool', net['conv4_4'])
171
+ net['conv5_1'] = _build_net(
172
+ 'conv',
173
+ net['pool4'],
174
+ _get_weight_and_bias(vgg_layers, 28),
175
+ name='vgg_conv5_1')
176
+ net['conv5_2'] = _build_net(
177
+ 'conv',
178
+ net['conv5_1'],
179
+ _get_weight_and_bias(vgg_layers, 30),
180
+ name='vgg_conv5_2')
181
+
182
+ return net
183
+
184
+
185
+ def _compute_error(fake: tf.Tensor,
186
+ real: tf.Tensor,
187
+ mask: Optional[tf.Tensor] = None) -> tf.Tensor:
188
+ """Computes the L1 loss and reweights by the mask."""
189
+ if mask is None:
190
+ return tf.reduce_mean(tf.abs(fake - real))
191
+ else:
192
+ # Resizes mask to the same size as the input.
193
+ size = (tf.shape(fake)[1], tf.shape(fake)[2])
194
+ resized_mask = tf.image.resize(
195
+ mask, size, method=tf.image.ResizeMethod.BILINEAR)
196
+ return tf.reduce_mean(tf.abs(fake - real) * resized_mask)
197
+
198
+
199
+ # Normalized VGG loss (from
200
+ # https://github.com/CQFIO/PhotographicImageSynthesis)
201
+ def vgg_loss(image: tf.Tensor,
202
+ reference: tf.Tensor,
203
+ vgg_model_file: str,
204
+ weights: Optional[Sequence[float]] = None,
205
+ mask: Optional[tf.Tensor] = None) -> tf.Tensor:
206
+ """Computes the VGG loss for an image pair.
207
+
208
+ The VGG loss is the average feature vector difference between the two images.
209
+
210
+ The input images must be in [0, 1] range in (B, H, W, 3) RGB format and
211
+ the recommendation seems to be to have them in gamma space.
212
+
213
+ The pretrained weights are publicly available in
214
+ http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat
215
+
216
+ Args:
217
+ image: A tensor, typically the prediction from a network.
218
+ reference: A tensor, the image to compare against, i.e. the golden image.
219
+ vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB
220
+ format.
221
+ weights: A list of float, optional weights for the layers. The defaults are
222
+ from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with
223
+ cascaded refinement networks," ICCV 2017.
224
+ mask: An optional image-shape and single-channel tensor, the mask values are
225
+ per-pixel weights to be applied on the losses. The mask will be resized to
226
+ the same spatial resolution with the feature maps before been applied to
227
+ the losses. When the mask value is zero, pixels near the boundary of the
228
+ mask can still influence the loss if they fall into the receptive field of
229
+ the VGG convolutional layers.
230
+
231
+ Returns:
232
+ vgg_loss: The linear combination of losses from five VGG layers.
233
+ """
234
+
235
+ if not weights:
236
+ weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
237
+
238
+ vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
239
+ vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
240
+ p1 = _compute_error(vgg_ref['conv1_2'], vgg_img['conv1_2'], mask) * weights[0]
241
+ p2 = _compute_error(vgg_ref['conv2_2'], vgg_img['conv2_2'], mask) * weights[1]
242
+ p3 = _compute_error(vgg_ref['conv3_2'], vgg_img['conv3_2'], mask) * weights[2]
243
+ p4 = _compute_error(vgg_ref['conv4_2'], vgg_img['conv4_2'], mask) * weights[3]
244
+ p5 = _compute_error(vgg_ref['conv5_2'], vgg_img['conv5_2'], mask) * weights[4]
245
+
246
+ final_loss = p1 + p2 + p3 + p4 + p5
247
+
248
+ # Scale to range [0..1].
249
+ final_loss /= 255.0
250
+
251
+ return final_loss
252
+
253
+
254
+ def _compute_gram_matrix(input_features: tf.Tensor,
255
+ mask: tf.Tensor) -> tf.Tensor:
256
+ """Computes Gram matrix of `input_features`.
257
+
258
+ Gram matrix described in https://en.wikipedia.org/wiki/Gramian_matrix.
259
+
260
+ Args:
261
+ input_features: A tf.Tensor of shape (B, H, W, C) representing a feature map
262
+ obtained by a convolutional layer of a VGG network.
263
+ mask: A tf.Tensor of shape (B, H, W, 1) representing the per-pixel weights
264
+ to be applied on the `input_features`. The mask will be resized to the
265
+ same spatial resolution as the `input_featues`. When the mask value is
266
+ zero, pixels near the boundary of the mask can still influence the loss if
267
+ they fall into the receptive field of the VGG convolutional layers.
268
+
269
+ Returns:
270
+ A tf.Tensor of shape (B, C, C) representing the gram matrix of the masked
271
+ `input_features`.
272
+ """
273
+ _, h, w, c = tuple([
274
+ i if (isinstance(i, int) or i is None) else i.value
275
+ for i in input_features.shape
276
+ ])
277
+ if mask is None:
278
+ reshaped_features = tf.reshape(input_features, (-1, h * w, c))
279
+ else:
280
+ # Resize mask to match the shape of `input_features`
281
+ resized_mask = tf.image.resize(
282
+ mask, (h, w), method=tf.image.ResizeMethod.BILINEAR)
283
+ reshaped_features = tf.reshape(input_features * resized_mask,
284
+ (-1, h * w, c))
285
+ return tf.matmul(
286
+ reshaped_features, reshaped_features, transpose_a=True) / float(h * w)
287
+
288
+
289
+ def style_loss(image: tf.Tensor,
290
+ reference: tf.Tensor,
291
+ vgg_model_file: str,
292
+ weights: Optional[Sequence[float]] = None,
293
+ mask: Optional[tf.Tensor] = None) -> tf.Tensor:
294
+ """Computes style loss as used in `A Neural Algorithm of Artistic Style`.
295
+
296
+ Based on the work in https://github.com/cysmith/neural-style-tf. Weights are
297
+ first initilaized to the inverse of the number of elements in each VGG layer
298
+ considerd. After 1.5M iterations, they are rescaled to normalize the
299
+ contribution of the Style loss to be equal to other losses (L1/VGG). This is
300
+ based on the works of image inpainting (https://arxiv.org/abs/1804.07723)
301
+ and frame prediction (https://arxiv.org/abs/1811.00684).
302
+
303
+ The style loss is the average gram matrix difference between `image` and
304
+ `reference`. The gram matrix is the inner product of a feature map of shape
305
+ (B, H*W, C) with itself. Results in a symmetric gram matrix shaped (B, C, C).
306
+
307
+ The input images must be in [0, 1] range in (B, H, W, 3) RGB format and
308
+ the recommendation seems to be to have them in gamma space.
309
+
310
+ The pretrained weights are publicly available in
311
+ http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat
312
+
313
+ Args:
314
+ image: A tensor, typically the prediction from a network.
315
+ reference: A tensor, the image to compare against, i.e. the golden image.
316
+ vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB
317
+ format.
318
+ weights: A list of float, optional weights for the layers. The defaults are
319
+ from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with
320
+ cascaded refinement networks," ICCV 2017.
321
+ mask: An optional image-shape and single-channel tensor, the mask values are
322
+ per-pixel weights to be applied on the losses. The mask will be resized to
323
+ the same spatial resolution with the feature maps before been applied to
324
+ the losses. When the mask value is zero, pixels near the boundary of the
325
+ mask can still influence the loss if they fall into the receptive field of
326
+ the VGG convolutional layers.
327
+
328
+ Returns:
329
+ Style loss, a linear combination of gram matrix L2 differences of from five
330
+ VGG layer features.
331
+ """
332
+
333
+ if not weights:
334
+ weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
335
+
336
+ vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
337
+ vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
338
+
339
+ p1 = tf.reduce_mean(
340
+ tf.squared_difference(
341
+ _compute_gram_matrix(vgg_ref['conv1_2'] / 255.0, mask),
342
+ _compute_gram_matrix(vgg_img['conv1_2'] / 255.0, mask))) * weights[0]
343
+ p2 = tf.reduce_mean(
344
+ tf.squared_difference(
345
+ _compute_gram_matrix(vgg_ref['conv2_2'] / 255.0, mask),
346
+ _compute_gram_matrix(vgg_img['conv2_2'] / 255.0, mask))) * weights[1]
347
+ p3 = tf.reduce_mean(
348
+ tf.squared_difference(
349
+ _compute_gram_matrix(vgg_ref['conv3_2'] / 255.0, mask),
350
+ _compute_gram_matrix(vgg_img['conv3_2'] / 255.0, mask))) * weights[2]
351
+ p4 = tf.reduce_mean(
352
+ tf.squared_difference(
353
+ _compute_gram_matrix(vgg_ref['conv4_2'] / 255.0, mask),
354
+ _compute_gram_matrix(vgg_img['conv4_2'] / 255.0, mask))) * weights[3]
355
+ p5 = tf.reduce_mean(
356
+ tf.squared_difference(
357
+ _compute_gram_matrix(vgg_ref['conv5_2'] / 255.0, mask),
358
+ _compute_gram_matrix(vgg_img['conv5_2'] / 255.0, mask))) * weights[4]
359
+
360
+ final_loss = p1 + p2 + p3 + p4 + p5
361
+
362
+ return final_loss
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/film_net/feature_extractor.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """TF2 layer for extracting image features for the film_net interpolator.
16
+
17
+ The feature extractor implemented here converts an image pyramid into a pyramid
18
+ of deep features. The feature pyramid serves a similar purpose as U-Net
19
+ architecture's encoder, but we use a special cascaded architecture described in
20
+ Multi-view Image Fusion [1].
21
+
22
+ For comprehensiveness, below is a short description of the idea. While the
23
+ description is a bit involved, the cascaded feature pyramid can be used just
24
+ like any image feature pyramid.
25
+
26
+ Why cascaded architeture?
27
+ =========================
28
+ To understand the concept it is worth reviewing a traditional feature pyramid
29
+ first: *A traditional feature pyramid* as in U-net or in many optical flow
30
+ networks is built by alternating between convolutions and pooling, starting
31
+ from the input image.
32
+
33
+ It is well known that early features of such architecture correspond to low
34
+ level concepts such as edges in the image whereas later layers extract
35
+ semantically higher level concepts such as object classes etc. In other words,
36
+ the meaning of the filters in each resolution level is different. For problems
37
+ such as semantic segmentation and many others this is a desirable property.
38
+
39
+ However, the asymmetric features preclude sharing weights across resolution
40
+ levels in the feature extractor itself and in any subsequent neural networks
41
+ that follow. This can be a downside, since optical flow prediction, for
42
+ instance is symmetric across resolution levels. The cascaded feature
43
+ architecture addresses this shortcoming.
44
+
45
+ How is it built?
46
+ ================
47
+ The *cascaded* feature pyramid contains feature vectors that have constant
48
+ length and meaning on each resolution level, except few of the finest ones. The
49
+ advantage of this is that the subsequent optical flow layer can learn
50
+ synergically from many resolutions. This means that coarse level prediction can
51
+ benefit from finer resolution training examples, which can be useful with
52
+ moderately sized datasets to avoid overfitting.
53
+
54
+ The cascaded feature pyramid is built by extracting shallower subtree pyramids,
55
+ each one of them similar to the traditional architecture. Each subtree
56
+ pyramid S_i is extracted starting from each resolution level:
57
+
58
+ image resolution 0 -> S_0
59
+ image resolution 1 -> S_1
60
+ image resolution 2 -> S_2
61
+ ...
62
+
63
+ If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
64
+ is constructed by concatenating features as follows (assuming subtree depth=3):
65
+
66
+ lvl
67
+ feat_0 = concat( S_0_0 )
68
+ feat_1 = concat( S_1_0 S_0_1 )
69
+ feat_2 = concat( S_2_0 S_1_1 S_0_2 )
70
+ feat_3 = concat( S_3_0 S_2_1 S_1_2 )
71
+ feat_4 = concat( S_4_0 S_3_1 S_2_2 )
72
+ feat_5 = concat( S_5_0 S_4_1 S_3_2 )
73
+ ....
74
+
75
+ In above, all levels except feat_0 and feat_1 have the same number of features
76
+ with similar semantic meaning. This enables training a single optical flow
77
+ predictor module shared by levels 2,3,4,5... . For more details and evaluation
78
+ see [1].
79
+
80
+ [1] Multi-view Image Fusion, Trinidad et al. 2019
81
+ """
82
+
83
+ from typing import List
84
+
85
+ from . import options
86
+ import tensorflow as tf
87
+
88
+
89
+ def _relu(x: tf.Tensor) -> tf.Tensor:
90
+ return tf.nn.leaky_relu(x, alpha=0.2)
91
+
92
+
93
+ def _conv(filters: int, name: str):
94
+ return tf.keras.layers.Conv2D(
95
+ name=name,
96
+ filters=filters,
97
+ kernel_size=3,
98
+ padding='same',
99
+ activation=_relu)
100
+
101
+
102
+ class SubTreeExtractor(tf.keras.layers.Layer):
103
+ """Extracts a hierarchical set of features from an image.
104
+
105
+ This is a conventional, hierarchical image feature extractor, that extracts
106
+ [k, k*2, k*4... ] filters for the image pyramid where k=options.sub_levels.
107
+ Each level is followed by average pooling.
108
+
109
+ Attributes:
110
+ name: Name for the layer
111
+ config: Options for the fusion_net frame interpolator
112
+ """
113
+
114
+ def __init__(self, name: str, config: options.Options):
115
+ super().__init__(name=name)
116
+ k = config.filters
117
+ n = config.sub_levels
118
+ self.convs = []
119
+ for i in range(n):
120
+ self.convs.append(
121
+ _conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i)))
122
+ self.convs.append(
123
+ _conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i + 1)))
124
+
125
+ def call(self, image: tf.Tensor, n: int) -> List[tf.Tensor]:
126
+ """Extracts a pyramid of features from the image.
127
+
128
+ Args:
129
+ image: tf.Tensor with shape BATCH_SIZE x HEIGHT x WIDTH x CHANNELS.
130
+ n: number of pyramid levels to extract. This can be less or equal to
131
+ options.sub_levels given in the __init__.
132
+ Returns:
133
+ The pyramid of features, starting from the finest level. Each element
134
+ contains the output after the last convolution on the corresponding
135
+ pyramid level.
136
+ """
137
+ head = image
138
+ pool = tf.keras.layers.AveragePooling2D(
139
+ pool_size=2, strides=2, padding='valid')
140
+ pyramid = []
141
+ for i in range(n):
142
+ head = self.convs[2*i](head)
143
+ head = self.convs[2*i+1](head)
144
+ pyramid.append(head)
145
+ if i < n-1:
146
+ head = pool(head)
147
+ return pyramid
148
+
149
+
150
+ class FeatureExtractor(tf.keras.layers.Layer):
151
+ """Extracts features from an image pyramid using a cascaded architecture.
152
+
153
+ Attributes:
154
+ name: Name of the layer
155
+ config: Options for the fusion_net frame interpolator
156
+ """
157
+
158
+ def __init__(self, name: str, config: options.Options):
159
+ super().__init__(name=name)
160
+ self.extract_sublevels = SubTreeExtractor('sub_extractor', config)
161
+ self.options = config
162
+
163
+ def call(self, image_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
164
+ """Extracts a cascaded feature pyramid.
165
+
166
+ Args:
167
+ image_pyramid: Image pyramid as a list, starting from the finest level.
168
+ Returns:
169
+ A pyramid of cascaded features.
170
+ """
171
+ sub_pyramids = []
172
+ for i in range(len(image_pyramid)):
173
+ # At each level of the image pyramid, creates a sub_pyramid of features
174
+ # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor.
175
+ # We use the same instance since we want to share the weights.
176
+ #
177
+ # However, we cap the depth of the sub_pyramid so we don't create features
178
+ # that are beyond the coarsest level of the cascaded feature pyramid we
179
+ # want to generate.
180
+ capped_sub_levels = min(len(image_pyramid) - i, self.options.sub_levels)
181
+ sub_pyramids.append(
182
+ self.extract_sublevels(image_pyramid[i], capped_sub_levels))
183
+ # Below we generate the cascades of features on each level of the feature
184
+ # pyramid. Assuming sub_levels=3, The layout of the features will be
185
+ # as shown in the example on file documentation above.
186
+ feature_pyramid = []
187
+ for i in range(len(image_pyramid)):
188
+ features = sub_pyramids[i][0]
189
+ for j in range(1, self.options.sub_levels):
190
+ if j <= i:
191
+ features = tf.concat([features, sub_pyramids[i - j][j]], axis=-1)
192
+ feature_pyramid.append(features)
193
+ return feature_pyramid
models/film_net/fusion.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """The final fusion stage for the film_net frame interpolator.
16
+
17
+ The inputs to this module are the warped input images, image features and
18
+ flow fields, all aligned to the target frame (often midway point between the
19
+ two original inputs). The output is the final image. FILM has no explicit
20
+ occlusion handling -- instead using the abovementioned information this module
21
+ automatically decides how to best blend the inputs together to produce content
22
+ in areas where the pixels can only be borrowed from one of the inputs.
23
+
24
+ Similarly, this module also decides on how much to blend in each input in case
25
+ of fractional timestep that is not at the halfway point. For example, if the two
26
+ inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1,
27
+ it often makes most sense to favor the first input. However, this is not
28
+ always the case -- in particular in occluded pixels.
29
+
30
+ The architecture of the Fusion module follows U-net [1] architecture's decoder
31
+ side, e.g. each pyramid level consists of concatenation with upsampled coarser
32
+ level output, and two 3x3 convolutions.
33
+
34
+ The upsampling is implemented as 'resize convolution', e.g. nearest neighbor
35
+ upsampling followed by 2x2 convolution as explained in [2]. The classic U-net
36
+ uses max-pooling which has a tendency to create checkerboard artifacts.
37
+
38
+ [1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
39
+ Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
40
+ [2] https://distill.pub/2016/deconv-checkerboard/
41
+ """
42
+
43
+ from typing import List
44
+
45
+ from . import options
46
+ import tensorflow as tf
47
+
48
+
49
+ def _relu(x: tf.Tensor) -> tf.Tensor:
50
+ return tf.nn.leaky_relu(x, alpha=0.2)
51
+
52
+
53
+ _NUMBER_OF_COLOR_CHANNELS = 3
54
+
55
+
56
+ class Fusion(tf.keras.layers.Layer):
57
+ """The decoder."""
58
+
59
+ def __init__(self, name: str, config: options.Options):
60
+ super().__init__(name=name)
61
+
62
+ # Each item 'convs[i]' will contain the list of convolutions to be applied
63
+ # for pyramid level 'i'.
64
+ self.convs: List[List[tf.keras.layers.Layer]] = []
65
+
66
+ # Store the levels, so we can verify right number of levels in call().
67
+ self.levels = config.fusion_pyramid_levels
68
+
69
+ # Create the convolutions. Roughly following the feature extractor, we
70
+ # double the number of filters when the resolution halves, but only up to
71
+ # the specialized_levels, after which we use the same number of filters on
72
+ # all levels.
73
+ #
74
+ # We create the convs in fine-to-coarse order, so that the array index
75
+ # for the convs will correspond to our normal indexing (0=finest level).
76
+ for i in range(config.fusion_pyramid_levels - 1):
77
+ m = config.specialized_levels
78
+ k = config.filters
79
+ num_filters = (k << i) if i < m else (k << m)
80
+
81
+ convs: List[tf.keras.layers.Layer] = []
82
+ convs.append(
83
+ tf.keras.layers.Conv2D(
84
+ filters=num_filters, kernel_size=[2, 2], padding='same'))
85
+ convs.append(
86
+ tf.keras.layers.Conv2D(
87
+ filters=num_filters,
88
+ kernel_size=[3, 3],
89
+ padding='same',
90
+ activation=_relu))
91
+ convs.append(
92
+ tf.keras.layers.Conv2D(
93
+ filters=num_filters,
94
+ kernel_size=[3, 3],
95
+ padding='same',
96
+ activation=_relu))
97
+ self.convs.append(convs)
98
+
99
+ # The final convolution that outputs RGB:
100
+ self.output_conv = tf.keras.layers.Conv2D(
101
+ filters=_NUMBER_OF_COLOR_CHANNELS, kernel_size=1)
102
+
103
+ def call(self, pyramid: List[tf.Tensor]) -> tf.Tensor:
104
+ """Runs the fusion module.
105
+
106
+ Args:
107
+ pyramid: The input feature pyramid as list of tensors. Each tensor being
108
+ in (B x H x W x C) format, with finest level tensor first.
109
+
110
+ Returns:
111
+ A batch of RGB images.
112
+ Raises:
113
+ ValueError, if len(pyramid) != config.fusion_pyramid_levels as provided in
114
+ the constructor.
115
+ """
116
+ if len(pyramid) != self.levels:
117
+ raise ValueError(
118
+ 'Fusion called with different number of pyramid levels '
119
+ f'{len(pyramid)} than it was configured for, {self.levels}.')
120
+
121
+ # As a slight difference to a conventional decoder (e.g. U-net), we don't
122
+ # apply any extra convolutions to the coarsest level, but just pass it
123
+ # to finer levels for concatenation. This choice has not been thoroughly
124
+ # evaluated, but is motivated by the educated guess that the fusion part
125
+ # probably does not need large spatial context, because at this point the
126
+ # features are spatially aligned by the preceding warp.
127
+ net = pyramid[-1]
128
+
129
+ # Loop starting from the 2nd coarsest level:
130
+ for i in reversed(range(0, self.levels - 1)):
131
+ # Resize the tensor from coarser level to match for concatenation.
132
+ level_size = tf.shape(pyramid[i])[1:3]
133
+ net = tf.image.resize(net, level_size,
134
+ tf.image.ResizeMethod.NEAREST_NEIGHBOR)
135
+ net = self.convs[i][0](net)
136
+ net = tf.concat([pyramid[i], net], axis=-1)
137
+ net = self.convs[i][1](net)
138
+ net = self.convs[i][2](net)
139
+ net = self.output_conv(net)
140
+ return net
models/film_net/interpolator.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """The film_net frame interpolator main model code.
16
+
17
+ Basics
18
+ ======
19
+ The film_net is an end-to-end learned neural frame interpolator implemented as
20
+ a TF2 model. It has the following inputs and outputs:
21
+
22
+ Inputs:
23
+ x0: image A.
24
+ x1: image B.
25
+ time: desired sub-frame time.
26
+
27
+ Outputs:
28
+ image: the predicted in-between image at the chosen time in range [0, 1].
29
+
30
+ Additional outputs include forward and backward warped image pyramids, flow
31
+ pyramids, etc., that can be visualized for debugging and analysis.
32
+
33
+ Note that many training sets only contain triplets with ground truth at
34
+ time=0.5. If a model has been trained with such training set, it will only work
35
+ well for synthesizing frames at time=0.5. Such models can only generate more
36
+ in-between frames using recursion.
37
+
38
+ Architecture
39
+ ============
40
+ The inference consists of three main stages: 1) feature extraction 2) warping
41
+ 3) fusion. On high-level, the architecture has similarities to Context-aware
42
+ Synthesis for Video Frame Interpolation [1], but the exact architecture is
43
+ closer to Multi-view Image Fusion [2] with some modifications for the frame
44
+ interpolation use-case.
45
+
46
+ Feature extraction stage employs the cascaded multi-scale architecture described
47
+ in [2]. The advantage of this architecture is that coarse level flow prediction
48
+ can be learned from finer resolution image samples. This is especially useful
49
+ to avoid overfitting with moderately sized datasets.
50
+
51
+ The warping stage uses a residual flow prediction idea that is similar to
52
+ PWC-Net [3], Multi-view Image Fusion [2] and many others.
53
+
54
+ The fusion stage is similar to U-Net's decoder where the skip connections are
55
+ connected to warped image and feature pyramids. This is described in [2].
56
+
57
+ Implementation Conventions
58
+ ====================
59
+ Pyramids
60
+ --------
61
+ Throughtout the model, all image and feature pyramids are stored as python lists
62
+ with finest level first followed by downscaled versions obtained by successively
63
+ halving the resolution. The depths of all pyramids are determined by
64
+ options.pyramid_levels. The only exception to this is internal to the feature
65
+ extractor, where smaller feature pyramids are temporarily constructed with depth
66
+ options.sub_levels.
67
+
68
+ Color ranges & gamma
69
+ --------------------
70
+ The model code makes no assumptions on whether the images are in gamma or
71
+ linearized space or what is the range of RGB color values. So a model can be
72
+ trained with different choices. This does not mean that all the choices lead to
73
+ similar results. In practice the model has been proven to work well with RGB
74
+ scale = [0,1] with gamma-space images (i.e. not linearized).
75
+
76
+ [1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018
77
+ [2] Multi-view Image Fusion, Trinidad et al, 2019
78
+ [3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
79
+ """
80
+
81
+ from . import feature_extractor
82
+ from . import fusion
83
+ from . import options
84
+ from . import pyramid_flow_estimator
85
+ from . import util
86
+ import tensorflow as tf
87
+
88
+
89
+ def create_model(x0: tf.Tensor, x1: tf.Tensor, time: tf.Tensor,
90
+ config: options.Options) -> tf.keras.Model:
91
+ """Creates a frame interpolator model.
92
+
93
+ The frame interpolator is used to warp the two images to the in-between frame
94
+ at given time. Note that training data is often restricted such that
95
+ supervision only exists at 'time'=0.5. If trained with such data, the model
96
+ will overfit to predicting images that are halfway between the two inputs and
97
+ will not be as accurate elsewhere.
98
+
99
+ Args:
100
+ x0: first input image as BxHxWxC tensor.
101
+ x1: second input image as BxHxWxC tensor.
102
+ time: ignored by film_net. We always infer a frame at t = 0.5.
103
+ config: FilmNetOptions object.
104
+
105
+ Returns:
106
+ A tf.Model that takes 'x0', 'x1', and 'time' as input and returns a
107
+ dictionary with the interpolated result in 'image'. For additional
108
+ diagnostics or supervision, the following intermediate results are
109
+ also stored in the dictionary:
110
+ 'x0_warped': an intermediate result obtained by warping from x0
111
+ 'x1_warped': an intermediate result obtained by warping from x1
112
+ 'forward_residual_flow_pyramid': pyramid with forward residual flows
113
+ 'backward_residual_flow_pyramid': pyramid with backward residual flows
114
+ 'forward_flow_pyramid': pyramid with forward flows
115
+ 'backward_flow_pyramid': pyramid with backward flows
116
+
117
+ Raises:
118
+ ValueError, if config.pyramid_levels < config.fusion_pyramid_levels.
119
+ """
120
+ if config.pyramid_levels < config.fusion_pyramid_levels:
121
+ raise ValueError('config.pyramid_levels must be greater than or equal to '
122
+ 'config.fusion_pyramid_levels.')
123
+
124
+ x0_decoded = x0
125
+ x1_decoded = x1
126
+
127
+ # shuffle images
128
+ image_pyramids = [
129
+ util.build_image_pyramid(x0_decoded, config),
130
+ util.build_image_pyramid(x1_decoded, config)
131
+ ]
132
+
133
+ # Siamese feature pyramids:
134
+ extract = feature_extractor.FeatureExtractor('feat_net', config)
135
+ feature_pyramids = [extract(image_pyramids[0]), extract(image_pyramids[1])]
136
+
137
+ predict_flow = pyramid_flow_estimator.PyramidFlowEstimator(
138
+ 'predict_flow', config)
139
+
140
+ # Predict forward flow.
141
+ forward_residual_flow_pyramid = predict_flow(feature_pyramids[0],
142
+ feature_pyramids[1])
143
+ # Predict backward flow.
144
+ backward_residual_flow_pyramid = predict_flow(feature_pyramids[1],
145
+ feature_pyramids[0])
146
+
147
+ # Concatenate features and images:
148
+
149
+ # Note that we keep up to 'fusion_pyramid_levels' levels as only those
150
+ # are used by the fusion module.
151
+ fusion_pyramid_levels = config.fusion_pyramid_levels
152
+
153
+ forward_flow_pyramid = util.flow_pyramid_synthesis(
154
+ forward_residual_flow_pyramid)[:fusion_pyramid_levels]
155
+ backward_flow_pyramid = util.flow_pyramid_synthesis(
156
+ backward_residual_flow_pyramid)[:fusion_pyramid_levels]
157
+
158
+ # We multiply the flows with t and 1-t to warp to the desired fractional time.
159
+ #
160
+ # Note: In film_net we fix time to be 0.5, and recursively invoke the interpo-
161
+ # lator for multi-frame interpolation. Below, we create a constant tensor of
162
+ # shape [B]. We use the `time` tensor to infer the batch size.
163
+ mid_time = tf.keras.layers.Lambda(lambda x: tf.ones_like(x) * 0.5)(time)
164
+ backward_flow = util.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
165
+ forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])
166
+
167
+ pyramids_to_warp = [
168
+ util.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels],
169
+ feature_pyramids[0][:fusion_pyramid_levels]),
170
+ util.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels],
171
+ feature_pyramids[1][:fusion_pyramid_levels])
172
+ ]
173
+
174
+ # Warp features and images using the flow. Note that we use backward warping
175
+ # and backward flow is used to read from image 0 and forward flow from
176
+ # image 1.
177
+ forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
178
+ backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)
179
+
180
+ aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
181
+ backward_warped_pyramid)
182
+ aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
183
+ aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)
184
+
185
+ fuse = fusion.Fusion('fusion', config)
186
+ prediction = fuse(aligned_pyramid)
187
+
188
+ output_color = prediction[..., :3]
189
+ outputs = {'image': output_color}
190
+
191
+ if config.use_aux_outputs:
192
+ outputs.update({
193
+ 'x0_warped': forward_warped_pyramid[0][..., 0:3],
194
+ 'x1_warped': backward_warped_pyramid[0][..., 0:3],
195
+ 'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
196
+ 'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
197
+ 'forward_flow_pyramid': forward_flow_pyramid,
198
+ 'backward_flow_pyramid': backward_flow_pyramid,
199
+ })
200
+
201
+ model = tf.keras.Model(
202
+ inputs={
203
+ 'x0': x0,
204
+ 'x1': x1,
205
+ 'time': time
206
+ }, outputs=outputs)
207
+ return model
models/film_net/options.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Options for the film_net video frame interpolator."""
16
+
17
+ import gin.tf
18
+
19
+
20
+ @gin.configurable('film_net')
21
+ class Options(object):
22
+ """Options for the film_net video frame interpolator.
23
+
24
+ To further understand these options, see the paper here:
25
+ https://augmentedperception.github.io/pixelfusion/.
26
+
27
+ The default values are suitable for up to 64 pixel motions. For larger motions
28
+ the number of flow convolutions and/or pyramid levels can be increased, but
29
+ usually with the cost of accuracy on solving the smaller motions.
30
+
31
+ The maximum motion in pixels that the system can resolve is equivalent to
32
+ 2^(pyramid_levels-1) * flow_convs[-1]. I.e. the downsampling factor times
33
+ the receptive field radius on the coarsest pyramid level. This, of course,
34
+ assumes that the training data contains such motions.
35
+
36
+ Note that to avoid a run-time error, the input image width and height have to
37
+ be divisible by 2^(pyramid_levels-1).
38
+
39
+ Attributes:
40
+ pyramid_levels: How many pyramid levels to use for the feature pyramid and
41
+ the flow prediction.
42
+ fusion_pyramid_levels: How many pyramid levels to use for the fusion module
43
+ this must be less or equal to 'pyramid_levels'.
44
+ specialized_levels: How many fine levels of the pyramid shouldn't share the
45
+ weights. If specialized_levels = 3, it means that two finest levels are
46
+ independently learned, whereas the third will be learned together with the
47
+ rest of the pyramid. Valid range [1, pyramid_levels].
48
+ flow_convs: Convolutions per residual flow predictor. This array should have
49
+ specialized_levels+1 items on it, the last item representing the number of
50
+ convs used by any pyramid level that uses shared weights.
51
+ flow_filters: Base number of filters in residual flow predictors. This array
52
+ should have specialized_levels+1 items on it, the last item representing
53
+ the number of filters used by any pyramid level that uses shared weights.
54
+ sub_levels: The depth of the cascaded feature tree each pyramid level
55
+ concatenates together to compute the flow. This must be within range [1,
56
+ specialized_level+1]. It is recommended to set this to specialized_levels
57
+ + 1
58
+ filters: Base number of features to extract. On each pyramid level the
59
+ number doubles. This is used by both feature extraction and fusion stages.
60
+ use_aux_outputs: Set to True to include auxiliary outputs along with the
61
+ predicted image.
62
+ """
63
+
64
+ def __init__(self,
65
+ pyramid_levels=5,
66
+ fusion_pyramid_levels=5,
67
+ specialized_levels=3,
68
+ flow_convs=None,
69
+ flow_filters=None,
70
+ sub_levels=4,
71
+ filters=16,
72
+ use_aux_outputs=True):
73
+ self.pyramid_levels = pyramid_levels
74
+ self.fusion_pyramid_levels = fusion_pyramid_levels
75
+ self.specialized_levels = specialized_levels
76
+ self.flow_convs = flow_convs or [4, 4, 4, 4]
77
+ self.flow_filters = flow_filters or [64, 128, 256, 256]
78
+ self.sub_levels = sub_levels
79
+ self.filters = filters
80
+ self.use_aux_outputs = use_aux_outputs
81
+
models/film_net/pyramid_flow_estimator.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """TF2 layer for estimating optical flow by a residual flow pyramid.
16
+
17
+ This approach of estimating optical flow between two images can be traced back
18
+ to [1], but is also used by later neural optical flow computation methods such
19
+ as SpyNet [2] and PWC-Net [3].
20
+
21
+ The basic idea is that the optical flow is first estimated in a coarse
22
+ resolution, then the flow is upsampled to warp the higher resolution image and
23
+ then a residual correction is computed and added to the estimated flow. This
24
+ process is repeated in a pyramid on coarse to fine order to successively
25
+ increase the resolution of both optical flow and the warped image.
26
+
27
+ In here, the optical flow predictor is used as an internal component for the
28
+ film_net frame interpolator, to warp the two input images into the inbetween,
29
+ target frame.
30
+
31
+ [1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
32
+ [2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid
33
+ Network. 2016
34
+ [3] D. Sun X. Yang, M-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using
35
+ Pyramid, Warping, and Cost Volume, 2017
36
+ """
37
+
38
+ from typing import List
39
+
40
+ from . import options
41
+ from . import util
42
+ import tensorflow as tf
43
+
44
+
45
+ def _relu(x: tf.Tensor) -> tf.Tensor:
46
+ return tf.nn.leaky_relu(x, alpha=0.2)
47
+
48
+
49
+ class FlowEstimator(tf.keras.layers.Layer):
50
+ """Small-receptive field predictor for computing the flow between two images.
51
+
52
+ This is used to compute the residual flow fields in PyramidFlowEstimator.
53
+
54
+ Note that while the number of 3x3 convolutions & filters to apply is
55
+ configurable, two extra 1x1 convolutions are appended to extract the flow in
56
+ the end.
57
+
58
+ Attributes:
59
+ name: The name of the layer
60
+ num_convs: Number of 3x3 convolutions to apply
61
+ num_filters: Number of filters in each 3x3 convolution
62
+ """
63
+
64
+ def __init__(self, name: str, num_convs: int, num_filters: int):
65
+ super(FlowEstimator, self).__init__(name=name)
66
+ def conv(filters, size, name, activation=_relu):
67
+ return tf.keras.layers.Conv2D(
68
+ name=name,
69
+ filters=filters,
70
+ kernel_size=size,
71
+ padding='same',
72
+ activation=activation)
73
+
74
+ self._convs = []
75
+ for i in range(num_convs):
76
+ self._convs.append(conv(filters=num_filters, size=3, name=f'conv_{i}'))
77
+ self._convs.append(conv(filters=num_filters/2, size=1, name=f'conv_{i+1}'))
78
+ # For the final convolution, we want no activation at all to predict the
79
+ # optical flow vector values. We have done extensive testing on explicitly
80
+ # bounding these values using sigmoid, but it turned out that having no
81
+ # activation gives better results.
82
+ self._convs.append(
83
+ conv(filters=2, size=1, name=f'conv_{i+2}', activation=None))
84
+
85
+ def call(self, features_a: tf.Tensor, features_b: tf.Tensor) -> tf.Tensor:
86
+ """Estimates optical flow between two images.
87
+
88
+ Args:
89
+ features_a: per pixel feature vectors for image A (B x H x W x C)
90
+ features_b: per pixel feature vectors for image B (B x H x W x C)
91
+
92
+ Returns:
93
+ A tensor with optical flow from A to B
94
+ """
95
+ net = tf.concat([features_a, features_b], axis=-1)
96
+ for conv in self._convs:
97
+ net = conv(net)
98
+ return net
99
+
100
+
101
+ class PyramidFlowEstimator(tf.keras.layers.Layer):
102
+ """Predicts optical flow by coarse-to-fine refinement.
103
+
104
+ Attributes:
105
+ name: The name of the layer
106
+ config: Options for the film_net frame interpolator
107
+ """
108
+
109
+ def __init__(self, name: str, config: options.Options):
110
+ super(PyramidFlowEstimator, self).__init__(name=name)
111
+ self._predictors = []
112
+ for i in range(config.specialized_levels):
113
+ self._predictors.append(
114
+ FlowEstimator(
115
+ name=f'flow_predictor_{i}',
116
+ num_convs=config.flow_convs[i],
117
+ num_filters=config.flow_filters[i]))
118
+ shared_predictor = FlowEstimator(
119
+ name='flow_predictor_shared',
120
+ num_convs=config.flow_convs[-1],
121
+ num_filters=config.flow_filters[-1])
122
+ for i in range(config.specialized_levels, config.pyramid_levels):
123
+ self._predictors.append(shared_predictor)
124
+
125
+ def call(self, feature_pyramid_a: List[tf.Tensor],
126
+ feature_pyramid_b: List[tf.Tensor]) -> List[tf.Tensor]:
127
+ """Estimates residual flow pyramids between two image pyramids.
128
+
129
+ Each image pyramid is represented as a list of tensors in fine-to-coarse
130
+ order. Each individual image is represented as a tensor where each pixel is
131
+ a vector of image features.
132
+
133
+ util.flow_pyramid_synthesis can be used to convert the residual flow
134
+ pyramid returned by this method into a flow pyramid, where each level
135
+ encodes the flow instead of a residual correction.
136
+
137
+ Args:
138
+ feature_pyramid_a: image pyramid as a list in fine-to-coarse order
139
+ feature_pyramid_b: image pyramid as a list in fine-to-coarse order
140
+
141
+ Returns:
142
+ List of flow tensors, in fine-to-coarse order, each level encoding the
143
+ difference against the bilinearly upsampled version from the coarser
144
+ level. The coarsest flow tensor, e.g. the last element in the array is the
145
+ 'DC-term', e.g. not a residual (alternatively you can think of it being a
146
+ residual against zero).
147
+ """
148
+ levels = len(feature_pyramid_a)
149
+ v = self._predictors[-1](feature_pyramid_a[-1], feature_pyramid_b[-1])
150
+ residuals = [v]
151
+ for i in reversed(range(0, levels-1)):
152
+ # Upsamples the flow to match the current pyramid level. Also, scales the
153
+ # magnitude by two to reflect the new size.
154
+ level_size = tf.shape(feature_pyramid_a[i])[1:3]
155
+ v = tf.image.resize(images=2*v, size=level_size)
156
+ # Warp feature_pyramid_b[i] image based on the current flow estimate.
157
+ warped = util.warp(feature_pyramid_b[i], v)
158
+ # Estimate the residual flow between pyramid_a[i] and warped image:
159
+ v_residual = self._predictors[i](feature_pyramid_a[i], warped)
160
+ residuals.append(v_residual)
161
+ v = v_residual + v
162
+ # Use reversed() to return in the 'standard' finest-first-order:
163
+ return list(reversed(residuals))
models/film_net/util.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Various utilities used in the film_net frame interpolator model."""
16
+ from typing import List
17
+
18
+ from .options import Options
19
+ import tensorflow as tf
20
+ import tensorflow_addons.image as tfa_image
21
+
22
+
23
+ def build_image_pyramid(image: tf.Tensor,
24
+ options: Options) -> List[tf.Tensor]:
25
+ """Builds an image pyramid from a given image.
26
+
27
+ The original image is included in the pyramid and the rest are generated by
28
+ successively halving the resolution.
29
+
30
+ Args:
31
+ image: the input image.
32
+ options: film_net options object
33
+
34
+ Returns:
35
+ A list of images starting from the finest with options.pyramid_levels items
36
+ """
37
+ levels = options.pyramid_levels
38
+ pyramid = []
39
+ pool = tf.keras.layers.AveragePooling2D(
40
+ pool_size=2, strides=2, padding='valid')
41
+ for i in range(0, levels):
42
+ pyramid.append(image)
43
+ if i < levels-1:
44
+ image = pool(image)
45
+ return pyramid
46
+
47
+
48
+ def warp(image: tf.Tensor, flow: tf.Tensor) -> tf.Tensor:
49
+ """Backward warps the image using the given flow.
50
+
51
+ Specifically, the output pixel in batch b, at position x, y will be computed
52
+ as follows:
53
+ (flowed_y, flowed_x) = (y+flow[b, y, x, 1], x+flow[b, y, x, 0])
54
+ output[b, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)
55
+
56
+ Note that the flow vectors are expected as [x, y], e.g. x in position 0 and
57
+ y in position 1.
58
+
59
+ Args:
60
+ image: An image with shape BxHxWxC.
61
+ flow: A flow with shape BxHxWx2, with the two channels denoting the relative
62
+ offset in order: (dx, dy).
63
+ Returns:
64
+ A warped image.
65
+ """
66
+ # tfa_image.dense_image_warp expects unconventional negated optical flow, so
67
+ # negate the flow here. Also revert x and y for compatibility with older saved
68
+ # models trained with custom warp op that stored (x, y) instead of (y, x) flow
69
+ # vectors.
70
+ flow = -flow[..., ::-1]
71
+
72
+ # Note: we have to wrap tfa_image.dense_image_warp into a Keras Lambda,
73
+ # because it is not compatible with Keras symbolic tensors and we want to use
74
+ # this code as part of a Keras model. Wrapping it into a lambda has the
75
+ # consequence that tfa_image.dense_image_warp is only called once the tensors
76
+ # are concrete, e.g. actually contain data. The inner lambda is a workaround
77
+ # for passing two parameters, e.g you would really want to write:
78
+ # tf.keras.layers.Lambda(tfa_image.dense_image_warp)(image, flow), but this is
79
+ # not supported by the Keras Lambda.
80
+ warped = tf.keras.layers.Lambda(
81
+ lambda x: tfa_image.dense_image_warp(*x))((image, flow))
82
+ return tf.reshape(warped, shape=tf.shape(image))
83
+
84
+
85
+ def multiply_pyramid(pyramid: List[tf.Tensor],
86
+ scalar: tf.Tensor) -> List[tf.Tensor]:
87
+ """Multiplies all image batches in the pyramid by a batch of scalars.
88
+
89
+ Args:
90
+ pyramid: Pyramid of image batches.
91
+ scalar: Batch of scalars.
92
+
93
+ Returns:
94
+ An image pyramid with all images multiplied by the scalar.
95
+ """
96
+ # To multiply each image with its corresponding scalar, we first transpose
97
+ # the batch of images from BxHxWxC-format to CxHxWxB. This can then be
98
+ # multiplied with a batch of scalars, then we transpose back to the standard
99
+ # BxHxWxC form.
100
+ return [
101
+ tf.transpose(tf.transpose(image, [3, 1, 2, 0]) * scalar, [3, 1, 2, 0])
102
+ for image in pyramid
103
+ ]
104
+
105
+
106
+ def flow_pyramid_synthesis(
107
+ residual_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
108
+ """Converts a residual flow pyramid into a flow pyramid."""
109
+ flow = residual_pyramid[-1]
110
+ flow_pyramid = [flow]
111
+ for residual_flow in reversed(residual_pyramid[:-1]):
112
+ level_size = tf.shape(residual_flow)[1:3]
113
+ flow = tf.image.resize(images=2*flow, size=level_size)
114
+ flow = residual_flow + flow
115
+ flow_pyramid.append(flow)
116
+ # Use reversed() to return in the 'standard' finest-first-order:
117
+ return list(reversed(flow_pyramid))
118
+
119
+
120
+ def pyramid_warp(feature_pyramid: List[tf.Tensor],
121
+ flow_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
122
+ """Warps the feature pyramid using the flow pyramid.
123
+
124
+ Args:
125
+ feature_pyramid: feature pyramid starting from the finest level.
126
+ flow_pyramid: flow fields, starting from the finest level.
127
+
128
+ Returns:
129
+ Reverse warped feature pyramid.
130
+ """
131
+ warped_feature_pyramid = []
132
+ for features, flow in zip(feature_pyramid, flow_pyramid):
133
+ warped_feature_pyramid.append(warp(features, flow))
134
+ return warped_feature_pyramid
135
+
136
+
137
+ def concatenate_pyramids(pyramid1: List[tf.Tensor],
138
+ pyramid2: List[tf.Tensor]) -> List[tf.Tensor]:
139
+ """Concatenates each pyramid level together in the channel dimension."""
140
+ result = []
141
+ for features1, features2 in zip(pyramid1, pyramid2):
142
+ result.append(tf.concat([features1, features2], axis=-1))
143
+ return result
moment.gif ADDED

Git LFS Details

  • SHA256: e2624128b6b5ed9c8093a7cfdbc36a815d56ddd62987d03d8476dacb23ad4f2e
  • Pointer size: 133 Bytes
  • Size of remote file: 21.4 MB
photos/one.png ADDED

Git LFS Details

  • SHA256: 8bad1c97feb31a4bec60a809f808e1b0a26f55219fa991c4caa2e696bce8e81f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.44 MB
photos/two.png ADDED

Git LFS Details

  • SHA256: d80058cede12e10b9d7fe49ea022d1cc4f9c28bd2a00a1c3d4830d048c55f3fa
  • Pointer size: 132 Bytes
  • Size of remote file: 3.39 MB
predict.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import tempfile
5
+ import tensorflow as tf
6
+ import mediapy
7
+ from PIL import Image
8
+ import cog
9
+
10
+ from eval import interpolator, util
11
+
12
+ _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
13
+
14
+
15
+ class Predictor(cog.Predictor):
16
+ def setup(self):
17
+ import tensorflow as tf
18
+ print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
19
+ self.interpolator = interpolator.Interpolator("pretrained_models/film_net/Style/saved_model", None)
20
+
21
+ # Batched time.
22
+ self.batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
23
+
24
+ @cog.input(
25
+ "frame1",
26
+ type=Path,
27
+ help="The first input frame",
28
+ )
29
+ @cog.input(
30
+ "frame2",
31
+ type=Path,
32
+ help="The second input frame",
33
+ )
34
+ @cog.input(
35
+ "times_to_interpolate",
36
+ type=int,
37
+ default=1,
38
+ min=1,
39
+ max=8,
40
+ help="Controls the number of times the frame interpolator is invoked If set to 1, the output will be the "
41
+ "sub-frame at t=0.5; when set to > 1, the output will be the interpolation video with "
42
+ "(2^times_to_interpolate + 1) frames, fps of 30.",
43
+ )
44
+ def predict(self, frame1, frame2, times_to_interpolate):
45
+ INPUT_EXT = ['.png', '.jpg', '.jpeg']
46
+ assert os.path.splitext(str(frame1))[-1] in INPUT_EXT and os.path.splitext(str(frame2))[-1] in INPUT_EXT, \
47
+ "Please provide png, jpg or jpeg images."
48
+
49
+ # make sure 2 images are the same size
50
+ img1 = Image.open(str(frame1))
51
+ img2 = Image.open(str(frame2))
52
+ if not img1.size == img2.size:
53
+ img1 = img1.crop((0, 0, min(img1.size[0], img2.size[0]), min(img1.size[1], img2.size[1])))
54
+ img2 = img2.crop((0, 0, min(img1.size[0], img2.size[0]), min(img1.size[1], img2.size[1])))
55
+ frame1 = 'new_frame1.png'
56
+ frame2 = 'new_frame2.png'
57
+ img1.save(frame1)
58
+ img2.save(frame2)
59
+
60
+ if times_to_interpolate == 1:
61
+ # First batched image.
62
+ image_1 = util.read_image(str(frame1))
63
+ image_batch_1 = np.expand_dims(image_1, axis=0)
64
+
65
+ # Second batched image.
66
+ image_2 = util.read_image(str(frame2))
67
+ image_batch_2 = np.expand_dims(image_2, axis=0)
68
+
69
+ # Invoke the model once.
70
+
71
+ mid_frame = self.interpolator.interpolate(image_batch_1, image_batch_2, self.batch_dt)[0]
72
+ out_path = Path(tempfile.mkdtemp()) / "out.png"
73
+ util.write_image(str(out_path), mid_frame)
74
+ return out_path
75
+
76
+
77
+ input_frames = [str(frame1), str(frame2)]
78
+
79
+ frames = list(
80
+ util.interpolate_recursively_from_files(
81
+ input_frames, times_to_interpolate, self.interpolator))
82
+ print('Interpolated frames generated, saving now as output video.')
83
+
84
+ ffmpeg_path = util.get_ffmpeg_path()
85
+ mediapy.set_ffmpeg(ffmpeg_path)
86
+ out_path = Path(tempfile.mkdtemp()) / "out.mp4"
87
+ mediapy.write_video(str(out_path), frames, fps=30)
88
+ return out_path
requirements.txt CHANGED
@@ -1,4 +1,14 @@
1
- shiny==0.10.2
2
- shinyswatch==0.6.1
3
- seaborn==0.12.2
4
- matplotlib==3.7.1
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker base image: `gcr.io/deeplearning-platform-release/tf2-gpu.2-6:latest`
2
+ tensorflow==2.6.2 # The latest should include tensorflow-gpu
3
+ tensorflow-datasets==4.4.0
4
+ tensorflow-addons==0.15.0
5
+ absl-py==0.12.0
6
+ gin-config==0.5.0
7
+ parameterized==0.8.1
8
+ mediapy==1.0.3
9
+ scikit-image==0.19.1
10
+ apache-beam==2.34.0
11
+ google-cloud-bigquery-storage==1.1.0 # Suppresses a harmless error from beam
12
+ natsort==8.1.0
13
+ gdown==4.5.4
14
+ tqdm==4.64.1
training/.DS_Store ADDED
Binary file (6.15 kB). View file
 
training/augmentation_lib.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Dataset augmentation for frame interpolation."""
16
+ from typing import Callable, Dict, List
17
+
18
+ import gin.tf
19
+ import numpy as np
20
+ import tensorflow as tf
21
+ import tensorflow.math as tfm
22
+ import tensorflow_addons.image as tfa_image
23
+
24
+ _PI = 3.141592653589793
25
+
26
+
27
+ def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
28
+ r"""Rotate the (u,v) vector of each pixel with angle in radians.
29
+
30
+ Flow matrix system of coordinates.
31
+ . . . . u (x)
32
+ .
33
+ .
34
+ . v (-y)
35
+
36
+ Rotation system of coordinates.
37
+ . y
38
+ .
39
+ .
40
+ . . . . x
41
+ Args:
42
+ flow: Flow map which has been image-rotated.
43
+ angle_rad: The rotation angle in radians.
44
+
45
+ Returns:
46
+ A flow with the same map but each (u,v) vector rotated by angle_rad.
47
+ """
48
+ u, v = tf.split(flow, 2, axis=-1)
49
+ # rotu = u * cos(angle) - (-v) * sin(angle)
50
+ rot_u = tfm.cos(angle_rad) * u + tfm.sin(angle_rad) * v
51
+ # rotv = -(u * sin(theta) + (-v) * cos(theta))
52
+ rot_v = -tfm.sin(angle_rad) * u + tfm.cos(angle_rad) * v
53
+ return tf.concat((rot_u, rot_v), axis=-1)
54
+
55
+
56
+ def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor:
57
+ """Rotates a flow by a multiple of 90 degrees.
58
+
59
+ Args:
60
+ flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees.
61
+ k: The multiplier factor.
62
+
63
+ Returns:
64
+ A flow image of the same shape as the input rotated by multiples of 90
65
+ degrees.
66
+ """
67
+ angle_rad = tf.cast(k, dtype=tf.float32) * 90. * (_PI/180.)
68
+ flow = tf.image.rot90(flow, k)
69
+ return _rotate_flow_vectors(flow, angle_rad)
70
+
71
+
72
+ def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
73
+ """Rotates a flow by a the provided angle in radians.
74
+
75
+ Args:
76
+ flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees.
77
+ angle_rad: The angle to ratate the flow in radians.
78
+
79
+ Returns:
80
+ A flow image of the same shape as the input rotated by the provided angle in
81
+ radians.
82
+ """
83
+ flow = tfa_image.rotate(
84
+ flow,
85
+ angles=angle_rad,
86
+ interpolation='bilinear',
87
+ fill_mode='reflect')
88
+ return _rotate_flow_vectors(flow, angle_rad)
89
+
90
+
91
+ def flow_flip(flow: tf.Tensor) -> tf.Tensor:
92
+ """Flips a flow left to right.
93
+
94
+ Args:
95
+ flow: The flow image shaped (H, W, 2) to flip left to right.
96
+
97
+ Returns:
98
+ A flow image of the same shape as the input flipped left to right.
99
+ """
100
+ flow = tf.image.flip_left_right(tf.identity(flow))
101
+ flow_u, flow_v = tf.split(flow, 2, axis=-1)
102
+ return tf.stack([-1 * flow_u, flow_v], axis=-1)
103
+
104
+
105
+ def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
106
+ """Rotates a stack of images by a random multiples of 90 degrees.
107
+
108
+ Args:
109
+ images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the
110
+ channel's axis.
111
+ Returns:
112
+ A tf.Tensor of the same rank as the `images` after random rotation by
113
+ multiples of 90 degrees applied counter-clock wise.
114
+ """
115
+ random_k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
116
+ for key in images:
117
+ images[key] = tf.image.rot90(images[key], k=random_k)
118
+ return images
119
+
120
+
121
+ def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
122
+ """Flips a stack of images randomly.
123
+
124
+ Args:
125
+ images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the
126
+ channel's axis.
127
+
128
+ Returns:
129
+ A tf.Tensor of the images after random left to right flip.
130
+ """
131
+ prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
132
+ prob = tf.cast(prob, tf.bool)
133
+
134
+ def _identity(image):
135
+ return image
136
+
137
+ def _flip_left_right(image):
138
+ return tf.image.flip_left_right(image)
139
+
140
+ # pylint: disable=cell-var-from-loop
141
+ for key in images:
142
+ images[key] = tf.cond(prob, lambda: _flip_left_right(images[key]),
143
+ lambda: _identity(images[key]))
144
+ return images
145
+
146
+
147
+ def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
148
+ """Reverses a stack of images randomly.
149
+
150
+ Args:
151
+ images: A dictionary of tf.Tensors, each shaped (H, W, num_channels), with
152
+ each tensor being a stack of iamges along the last channel axis.
153
+
154
+ Returns:
155
+ A dictionary of tf.Tensors, each shaped the same as the input images dict.
156
+ """
157
+ prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
158
+ prob = tf.cast(prob, tf.bool)
159
+
160
+ def _identity(images):
161
+ return images
162
+
163
+ def _reverse(images):
164
+ images['x0'], images['x1'] = images['x1'], images['x0']
165
+ return images
166
+
167
+ return tf.cond(prob, lambda: _reverse(images), lambda: _identity(images))
168
+
169
+
170
+ def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
171
+ """Rotates image randomly with [-45 to 45 degrees].
172
+
173
+ Args:
174
+ images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the
175
+ channel's axis.
176
+
177
+ Returns:
178
+ A tf.Tensor of the images after random rotation with a bound of -72 to 72
179
+ degrees.
180
+ """
181
+ prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
182
+ prob = tf.cast(prob, tf.float32)
183
+ random_angle = tf.random.uniform((),
184
+ minval=-0.25 * np.pi,
185
+ maxval=0.25 * np.pi,
186
+ dtype=tf.float32)
187
+
188
+ for key in images:
189
+ images[key] = tfa_image.rotate(
190
+ images[key],
191
+ angles=random_angle * prob,
192
+ interpolation='bilinear',
193
+ fill_mode='constant')
194
+ return images
195
+
196
+
197
+ @gin.configurable('data_augmentation')
198
+ def data_augmentations(
199
+ names: List[str]) -> Dict[str, Callable[..., tf.Tensor]]:
200
+ """Creates the data augmentation functions.
201
+
202
+ Args:
203
+ names: The list of augmentation function names.
204
+ Returns:
205
+ A dictionary of Callables to the augmentation functions, keyed by their
206
+ names.
207
+ """
208
+ augmentations = dict()
209
+ for name in names:
210
+ if name == 'random_image_rot90':
211
+ augmentations[name] = random_image_rot90
212
+ elif name == 'random_rotate':
213
+ augmentations[name] = random_rotate
214
+ elif name == 'random_flip':
215
+ augmentations[name] = random_flip
216
+ elif name == 'random_reverse':
217
+ augmentations[name] = random_reverse
218
+ else:
219
+ raise AttributeError('Invalid augmentation function %s' % name)
220
+ return augmentations
training/build_saved_model_cli.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Converts TF2 training checkpoint to a saved model.
16
+
17
+ The model must match the checkpoint, so the gin config must be given.
18
+
19
+ Usage example:
20
+ python3 -m frame_interpolation.training.build_saved_model_cli \
21
+ --gin_config <filepath of the gin config the training session was based> \
22
+ --base_folder <base folder of training sessions> \
23
+ --label <the name of the run>
24
+
25
+ This will produce a saved model into: <base_folder>/<label>/saved_model
26
+ """
27
+ import os
28
+ from typing import Sequence
29
+
30
+ from . import model_lib
31
+ from absl import app
32
+ from absl import flags
33
+ from absl import logging
34
+ import gin.tf
35
+ import tensorflow as tf
36
+ tf.get_logger().setLevel('ERROR')
37
+
38
+ _GIN_CONFIG = flags.DEFINE_string(
39
+ name='gin_config',
40
+ default='config.gin',
41
+ help='Gin config file, saved in the training session <root folder>.')
42
+ _LABEL = flags.DEFINE_string(
43
+ name='label',
44
+ default=None,
45
+ required=True,
46
+ help='Descriptive label for the training session.')
47
+ _BASE_FOLDER = flags.DEFINE_string(
48
+ name='base_folder',
49
+ default=None,
50
+ help='Path to all training sessions.')
51
+ _MODE = flags.DEFINE_enum(
52
+ name='mode',
53
+ default=None,
54
+ enum_values=['cpu', 'gpu', 'tpu'],
55
+ help='Distributed strategy approach.')
56
+
57
+
58
+ def _build_saved_model(checkpoint_path: str, config_files: Sequence[str],
59
+ output_model_path: str):
60
+ """Builds a saved model based on the checkpoint directory."""
61
+ gin.parse_config_files_and_bindings(
62
+ config_files=config_files,
63
+ bindings=None,
64
+ skip_unknown=True)
65
+ model = model_lib.create_model()
66
+ checkpoint = tf.train.Checkpoint(model=model)
67
+ checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
68
+ try:
69
+ logging.info('Restoring from %s', checkpoint_file)
70
+ status = checkpoint.restore(checkpoint_file)
71
+ status.assert_existing_objects_matched()
72
+ status.expect_partial()
73
+ model.save(output_model_path)
74
+ except (tf.errors.NotFoundError, AssertionError) as err:
75
+ logging.info('Failed to restore checkpoint from %s. Error:\n%s',
76
+ checkpoint_file, err)
77
+
78
+
79
+ def main(argv):
80
+ if len(argv) > 1:
81
+ raise app.UsageError('Too many command-line arguments.')
82
+
83
+ checkpoint_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'train')
84
+ if not tf.io.gfile.exists(_GIN_CONFIG.value):
85
+ config_file = os.path.join(_BASE_FOLDER.value, _LABEL.value,
86
+ _GIN_CONFIG.value)
87
+ else:
88
+ config_file = _GIN_CONFIG.value
89
+ output_model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value,
90
+ 'saved_model')
91
+ _build_saved_model(
92
+ checkpoint_path=checkpoint_path,
93
+ config_files=[config_file],
94
+ output_model_path=output_model_path)
95
+ logging.info('The saved model stored into %s/.', output_model_path)
96
+
97
+ if __name__ == '__main__':
98
+ app.run(main)
training/config/film_net-L1.gin ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ model.name = 'film_net'
16
+
17
+ film_net.pyramid_levels = 7
18
+ film_net.fusion_pyramid_levels = 5
19
+ film_net.specialized_levels = 3
20
+ film_net.sub_levels = 4
21
+ film_net.flow_convs = [3, 3, 3, 3]
22
+ film_net.flow_filters = [32, 64, 128, 256]
23
+ film_net.filters = 64
24
+
25
+ training.learning_rate = 0.0001
26
+ training.learning_rate_decay_steps = 750000
27
+ training.learning_rate_decay_rate = 0.464158
28
+ training.learning_rate_staircase = True
29
+ training.num_steps = 3000000
30
+
31
+ # in the sweep
32
+ training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
33
+ training_dataset.batch_size = 8
34
+ training_dataset.crop_size = 256
35
+
36
+ eval_datasets.batch_size = 1
37
+ eval_datasets.max_examples = -1
38
+ # eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
39
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
40
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
41
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
42
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
43
+ # eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
44
+ eval_datasets.files = []
45
+ eval_datasets.names = []
46
+
47
+ # Training augmentation (in addition to random crop)
48
+ data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
49
+
50
+ # Loss functions
51
+ training_losses.loss_names = ['l1']
52
+ training_losses.loss_weights = [1.0]
53
+
54
+ test_losses.loss_names = ['l1', 'psnr', 'ssim']
55
+ test_losses.loss_weights = [1.0, 1.0, 1.0]
training/config/film_net-Style.gin ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ model.name = 'film_net'
16
+
17
+ film_net.pyramid_levels = 7
18
+ film_net.fusion_pyramid_levels = 5
19
+ film_net.specialized_levels = 3
20
+ film_net.sub_levels = 4
21
+ film_net.flow_convs = [3, 3, 3, 3]
22
+ film_net.flow_filters = [32, 64, 128, 256]
23
+ film_net.filters = 64
24
+
25
+ training.learning_rate = 0.0001
26
+ training.learning_rate_decay_steps = 750000
27
+ training.learning_rate_decay_rate = 0.464158
28
+ training.learning_rate_staircase = True
29
+ training.num_steps = 3000000
30
+
31
+ # in the sweep
32
+ training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
33
+ training_dataset.batch_size = 8
34
+ training_dataset.crop_size = 256
35
+
36
+ eval_datasets.batch_size = 1
37
+ eval_datasets.max_examples = -1
38
+ # eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
39
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
40
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
41
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
42
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
43
+ # eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
44
+ eval_datasets.files = []
45
+ eval_datasets.names = []
46
+
47
+ # Training augmentation (in addition to random crop)
48
+ data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
49
+
50
+ # Loss functions
51
+ training_losses.loss_names = ['l1', 'vgg', 'style']
52
+ training_losses.loss_weight_schedules = [
53
+ @tf.keras.optimizers.schedules.PiecewiseConstantDecay,
54
+ @tf.keras.optimizers.schedules.PiecewiseConstantDecay,
55
+ @tf.keras.optimizers.schedules.PiecewiseConstantDecay]
56
+ # Increase the weight of style loss at 1.5M steps.
57
+ training_losses.loss_weight_parameters = [
58
+ {'boundaries':[0], 'values':[1.0, 1.0]},
59
+ {'boundaries':[1500000], 'values':[1.0, 0.25]},
60
+ {'boundaries':[1500000], 'values':[0.0, 40.0]}]
61
+
62
+ test_losses.loss_names = ['l1', 'psnr', 'ssim']
63
+ test_losses.loss_weights = [1.0, 1.0, 1.0]
64
+
65
+ vgg.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
66
+ style.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
training/config/film_net-VGG.gin ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ model.name = 'film_net'
16
+
17
+ film_net.pyramid_levels = 7
18
+ film_net.fusion_pyramid_levels = 5
19
+ film_net.specialized_levels = 3
20
+ film_net.sub_levels = 4
21
+ film_net.flow_convs = [3, 3, 3, 3]
22
+ film_net.flow_filters = [32, 64, 128, 256]
23
+ film_net.filters = 64
24
+
25
+ training.learning_rate = 0.0001
26
+ training.learning_rate_decay_steps = 750000
27
+ training.learning_rate_decay_rate = 0.464158
28
+ training.learning_rate_staircase = True
29
+ training.num_steps = 3000000
30
+
31
+ # in the sweep
32
+ training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
33
+ training_dataset.batch_size = 8
34
+ training_dataset.crop_size = 256
35
+
36
+ eval_datasets.batch_size = 1
37
+ eval_datasets.max_examples = -1
38
+ # eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
39
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
40
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
41
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
42
+ # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
43
+ # eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
44
+ eval_datasets.files = []
45
+ eval_datasets.names = []
46
+
47
+ # Training augmentation (in addition to random crop)
48
+ data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
49
+
50
+ # Loss functions
51
+ training_losses.loss_names = ['l1', 'vgg']
52
+ training_losses.loss_weight_schedules = [
53
+ @tf.keras.optimizers.schedules.PiecewiseConstantDecay,
54
+ @tf.keras.optimizers.schedules.PiecewiseConstantDecay]
55
+
56
+ # Decrease the weight of VGG loss at 1.5M steps.
57
+ training_losses.loss_weight_parameters = [
58
+ {'boundaries':[0], 'values':[1.0, 1.0]},
59
+ {'boundaries':[1500000], 'values':[1.0, 0.25]}]
60
+
61
+ test_losses.loss_names = ['l1', 'psnr', 'ssim']
62
+ test_losses.loss_weights = [1.0, 1.0, 1.0]
63
+
64
+ vgg.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
training/data_lib.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Dataset creation for frame interpolation."""
16
+ from typing import Callable, Dict, List, Optional
17
+
18
+ from absl import logging
19
+ import gin.tf
20
+ import tensorflow as tf
21
+
22
+
23
+ def _create_feature_map() -> Dict[str, tf.io.FixedLenFeature]:
24
+ """Creates the feature map for extracting the frame triplet."""
25
+ feature_map = {
26
+ 'frame_0/encoded':
27
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
28
+ 'frame_0/format':
29
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
30
+ 'frame_0/height':
31
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
32
+ 'frame_0/width':
33
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
34
+ 'frame_1/encoded':
35
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
36
+ 'frame_1/format':
37
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
38
+ 'frame_1/height':
39
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
40
+ 'frame_1/width':
41
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
42
+ 'frame_2/encoded':
43
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
44
+ 'frame_2/format':
45
+ tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
46
+ 'frame_2/height':
47
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
48
+ 'frame_2/width':
49
+ tf.io.FixedLenFeature((), tf.int64, default_value=0),
50
+ 'path':
51
+ tf.io.FixedLenFeature((), tf.string, default_value=''),
52
+ }
53
+ return feature_map
54
+
55
+
56
+ def _parse_example(sample):
57
+ """Parses a serialized sample.
58
+
59
+ Args:
60
+ sample: A serialized tf.Example to be parsed.
61
+
62
+ Returns:
63
+ dictionary containing the following:
64
+ encoded_image
65
+ image_height
66
+ image_width
67
+ """
68
+ feature_map = _create_feature_map()
69
+ features = tf.io.parse_single_example(sample, feature_map)
70
+ output_dict = {
71
+ 'x0': tf.io.decode_image(features['frame_0/encoded'], dtype=tf.float32),
72
+ 'x1': tf.io.decode_image(features['frame_2/encoded'], dtype=tf.float32),
73
+ 'y': tf.io.decode_image(features['frame_1/encoded'], dtype=tf.float32),
74
+ # The fractional time value of frame_1 is not included in our tfrecords,
75
+ # but is always at 0.5. The model will expect this to be specificed, so
76
+ # we insert it here.
77
+ 'time': 0.5,
78
+ # Store the original mid frame filepath for identifying examples.
79
+ 'path': features['path'],
80
+ }
81
+
82
+ return output_dict
83
+
84
+
85
+ def _random_crop_images(crop_size: int, images: tf.Tensor,
86
+ total_channel_size: int) -> tf.Tensor:
87
+ """Crops the tensor with random offset to the given size."""
88
+ if crop_size > 0:
89
+ crop_shape = tf.constant([crop_size, crop_size, total_channel_size])
90
+ images = tf.image.random_crop(images, crop_shape)
91
+ return images
92
+
93
+
94
+ def crop_example(example: tf.Tensor, crop_size: int,
95
+ crop_keys: Optional[List[str]] = None):
96
+ """Random crops selected images in the example to given size and keys.
97
+
98
+ Args:
99
+ example: Input tensor representing images to be cropped.
100
+ crop_size: The size to crop images to. This value is used for both
101
+ height and width.
102
+ crop_keys: The images in the input example to crop.
103
+
104
+ Returns:
105
+ Example with cropping applied to selected images.
106
+ """
107
+ if crop_keys is None:
108
+ crop_keys = ['x0', 'x1', 'y']
109
+ channels = [3, 3, 3]
110
+
111
+ # Stack images along channel axis, and perform a random crop once.
112
+ image_to_crop = [example[key] for key in crop_keys]
113
+ stacked_images = tf.concat(image_to_crop, axis=-1)
114
+ cropped_images = _random_crop_images(crop_size, stacked_images, sum(channels))
115
+ cropped_images = tf.split(
116
+ cropped_images, num_or_size_splits=channels, axis=-1)
117
+ for key, cropped_image in zip(crop_keys, cropped_images):
118
+ example[key] = cropped_image
119
+ return example
120
+
121
+
122
+ def apply_data_augmentation(
123
+ augmentation_fns: Dict[str, Callable[..., tf.Tensor]],
124
+ example: tf.Tensor,
125
+ augmentation_keys: Optional[List[str]] = None) -> tf.Tensor:
126
+ """Applies random augmentation in succession to selected image keys.
127
+
128
+ Args:
129
+ augmentation_fns: A Dict of Callables to data augmentation functions.
130
+ example: Input tensor representing images to be augmented.
131
+ augmentation_keys: The images in the input example to augment.
132
+
133
+ Returns:
134
+ Example with augmentation applied to selected images.
135
+ """
136
+ if augmentation_keys is None:
137
+ augmentation_keys = ['x0', 'x1', 'y']
138
+
139
+ # Apply each augmentation in sequence
140
+ augmented_images = {key: example[key] for key in augmentation_keys}
141
+ for augmentation_function in augmentation_fns.values():
142
+ augmented_images = augmentation_function(augmented_images)
143
+
144
+ for key in augmentation_keys:
145
+ example[key] = augmented_images[key]
146
+ return example
147
+
148
+
149
+ def _create_from_tfrecord(batch_size, file, augmentation_fns,
150
+ crop_size) -> tf.data.Dataset:
151
+ """Creates a dataset from TFRecord."""
152
+ dataset = tf.data.TFRecordDataset(file)
153
+ dataset = dataset.map(
154
+ _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
155
+
156
+ # Perform data_augmentation before cropping and batching
157
+ if augmentation_fns is not None:
158
+ dataset = dataset.map(
159
+ lambda x: apply_data_augmentation(augmentation_fns, x),
160
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
161
+
162
+ if crop_size > 0:
163
+ dataset = dataset.map(
164
+ lambda x: crop_example(x, crop_size=crop_size),
165
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
166
+ dataset = dataset.batch(batch_size, drop_remainder=True)
167
+ return dataset
168
+
169
+
170
+ def _generate_sharded_filenames(filename: str) -> List[str]:
171
+ """Generates filenames of the each file in the sharded filepath.
172
+
173
+ Based on github.com/google/revisiting-self-supervised/blob/master/datasets.py.
174
+
175
+ Args:
176
+ filename: The sharded filepath.
177
+
178
+ Returns:
179
+ A list of filepaths for each file in the shard.
180
+ """
181
+ base, count = filename.split('@')
182
+ count = int(count)
183
+ return ['{}-{:05d}-of-{:05d}'.format(base, i, count) for i in range(count)]
184
+
185
+
186
+ def _create_from_sharded_tfrecord(batch_size,
187
+ train_mode,
188
+ file,
189
+ augmentation_fns,
190
+ crop_size,
191
+ max_examples=-1) -> tf.data.Dataset:
192
+ """Creates a dataset from a sharded tfrecord."""
193
+ dataset = tf.data.Dataset.from_tensor_slices(
194
+ _generate_sharded_filenames(file))
195
+
196
+ # pylint: disable=g-long-lambda
197
+ dataset = dataset.interleave(
198
+ lambda x: _create_from_tfrecord(
199
+ batch_size,
200
+ file=x,
201
+ augmentation_fns=augmentation_fns,
202
+ crop_size=crop_size),
203
+ num_parallel_calls=tf.data.AUTOTUNE,
204
+ deterministic=not train_mode)
205
+ # pylint: enable=g-long-lambda
206
+ dataset = dataset.prefetch(buffer_size=2)
207
+ if max_examples > 0:
208
+ return dataset.take(max_examples)
209
+ return dataset
210
+
211
+
212
+ @gin.configurable('training_dataset')
213
+ def create_training_dataset(
214
+ batch_size: int,
215
+ file: Optional[str] = None,
216
+ files: Optional[List[str]] = None,
217
+ crop_size: int = -1,
218
+ crop_sizes: Optional[List[int]] = None,
219
+ augmentation_fns: Optional[Dict[str, Callable[..., tf.Tensor]]] = None
220
+ ) -> tf.data.Dataset:
221
+ """Creates the training dataset.
222
+
223
+ The given tfrecord should contain data in a format produced by
224
+ frame_interpolation/datasets/create_*_tfrecord.py
225
+
226
+ Args:
227
+ batch_size: The number of images to batch per example.
228
+ file: (deprecated) A path to a sharded tfrecord in <tfrecord>@N format.
229
+ Deprecated. Use 'files' instead.
230
+ files: A list of paths to sharded tfrecords in <tfrecord>@N format.
231
+ crop_size: (deprecated) If > 0, images are cropped to crop_size x crop_size
232
+ using tensorflow's random cropping. Deprecated: use 'files' and
233
+ 'crop_sizes' instead.
234
+ crop_sizes: List of crop sizes. If > 0, images are cropped to
235
+ crop_size x crop_size using tensorflow's random cropping.
236
+ augmentation_fns: A Dict of Callables to data augmentation functions.
237
+ Returns:
238
+ A tensorflow dataset for accessing examples that contain the input images
239
+ 'x0', 'x1', ground truth 'y' and time of the ground truth 'time'=[0,1] in a
240
+ dictionary of tensors.
241
+ """
242
+ if file:
243
+ logging.warning('gin-configurable training_dataset.file is deprecated. '
244
+ 'Use training_dataset.files instead.')
245
+ return _create_from_sharded_tfrecord(batch_size, True, file,
246
+ augmentation_fns, crop_size)
247
+ else:
248
+ if not crop_sizes or len(crop_sizes) != len(files):
249
+ raise ValueError('Please pass crop_sizes[] with training_dataset.files.')
250
+ if crop_size > 0:
251
+ raise ValueError(
252
+ 'crop_size should not be used with files[], use crop_sizes[] instead.'
253
+ )
254
+ tables = []
255
+ for file, crop_size in zip(files, crop_sizes):
256
+ tables.append(
257
+ _create_from_sharded_tfrecord(batch_size, True, file,
258
+ augmentation_fns, crop_size))
259
+ return tf.data.experimental.sample_from_datasets(tables)
260
+
261
+
262
+ @gin.configurable('eval_datasets')
263
+ def create_eval_datasets(batch_size: int,
264
+ files: List[str],
265
+ names: List[str],
266
+ crop_size: int = -1,
267
+ max_examples: int = -1) -> Dict[str, tf.data.Dataset]:
268
+ """Creates the evaluation datasets.
269
+
270
+ As opposed to create_training_dataset this function makes sure that the
271
+ examples for each dataset are always read in a deterministic (same) order.
272
+
273
+ Each given tfrecord should contain data in a format produced by
274
+ frame_interpolation/datasets/create_*_tfrecord.py
275
+
276
+ The (batch_size, crop_size, max_examples) are specified for all eval datasets.
277
+
278
+ Args:
279
+ batch_size: The number of images to batch per example.
280
+ files: List of paths to a sharded tfrecord in <tfrecord>@N format.
281
+ names: List of names of eval datasets.
282
+ crop_size: If > 0, images are cropped to crop_size x crop_size using
283
+ tensorflow's random cropping.
284
+ max_examples: If > 0, truncate the dataset to 'max_examples' in length. This
285
+ can be useful for speeding up evaluation loop in case the tfrecord for the
286
+ evaluation set is very large.
287
+ Returns:
288
+ A dict of name to tensorflow dataset for accessing examples that contain the
289
+ input images 'x0', 'x1', ground truth 'y' and time of the ground truth
290
+ 'time'=[0,1] in a dictionary of tensors.
291
+ """
292
+ return {
293
+ name: _create_from_sharded_tfrecord(batch_size, False, file, None,
294
+ crop_size, max_examples)
295
+ for name, file in zip(names, files)
296
+ }
training/eval_lib.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Evaluation library for frame interpolation."""
16
+ from typing import Dict, Mapping, Text
17
+
18
+ from absl import logging
19
+ import tensorflow as tf
20
+
21
+
22
+ def _collect_tensors(tensors: tf.Tensor) -> tf.Tensor:
23
+ """Collect tensors of the different replicas into a list."""
24
+ return tf.nest.flatten(tensors, expand_composites=True)
25
+
26
+
27
+ @tf.function
28
+ def _distributed_eval_step(strategy: tf.distribute.Strategy,
29
+ batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
30
+ metrics: Dict[Text, tf.keras.metrics.Metric],
31
+ checkpoint_step: int) -> Dict[Text, tf.Tensor]:
32
+ """Distributed eval step.
33
+
34
+ Args:
35
+ strategy: A Tensorflow distribution strategy.
36
+ batch: A batch of training examples.
37
+ model: The Keras model to evaluate.
38
+ metrics: The Keras metrics used for evaluation (a dictionary).
39
+ checkpoint_step: The iteration number at which the checkpoint is restored.
40
+
41
+ Returns:
42
+ list of predictions from each replica.
43
+ """
44
+
45
+ def _eval_step(
46
+ batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
47
+ """Eval for one step."""
48
+ predictions = model(batch, training=False)
49
+ # Note: these metrics expect batch and prediction dictionaries rather than
50
+ # tensors like standard TF metrics do. This allows our losses and metrics to
51
+ # use a richer set of inputs than just the predicted final image.
52
+ for metric in metrics.values():
53
+ metric.update_state(batch, predictions, checkpoint_step=checkpoint_step)
54
+ return predictions
55
+
56
+ return strategy.run(_eval_step, args=(batch,))
57
+
58
+
59
+ def _summarize_image_tensors(combined, prefix, step):
60
+ for name in combined:
61
+ image = combined[name]
62
+ if isinstance(image, tf.Tensor):
63
+ if len(image.shape) == 4 and (image.shape[-1] == 1 or
64
+ image.shape[-1] == 3):
65
+ tf.summary.image(prefix + '/' + name, image, step=step)
66
+
67
+
68
+ def eval_loop(strategy: tf.distribute.Strategy,
69
+ eval_base_folder: str,
70
+ model: tf.keras.Model,
71
+ metrics: Dict[str, tf.keras.metrics.Metric],
72
+ datasets: Mapping[str, tf.data.Dataset],
73
+ summary_writer: tf.summary.SummaryWriter,
74
+ checkpoint_step: int):
75
+ """Eval function that is strategy agnostic.
76
+
77
+ Args:
78
+ strategy: A Tensorflow distributed strategy.
79
+ eval_base_folder: A path to where the summaries event files and
80
+ checkpoints will be saved.
81
+ model: A function that returns the model.
82
+ metrics: A function that returns the metrics dictionary.
83
+ datasets: A dict of tf.data.Dataset to evaluate on.
84
+ summary_writer: Eval summary writer.
85
+ checkpoint_step: The number of iterations completed.
86
+ """
87
+ logging.info('Saving eval summaries to: %s...', eval_base_folder)
88
+ summary_writer.set_as_default()
89
+
90
+ for dataset_name, dataset in datasets.items():
91
+ for metric in metrics.values():
92
+ metric.reset_states()
93
+
94
+ logging.info('Loading %s testing data ...', dataset_name)
95
+ dataset = strategy.experimental_distribute_dataset(dataset)
96
+
97
+ logging.info('Evaluating %s ...', dataset_name)
98
+ batch_idx = 0
99
+ max_batches_to_summarize = 10
100
+ for batch in dataset:
101
+ predictions = _distributed_eval_step(strategy, batch, model, metrics,
102
+ checkpoint_step)
103
+ # Clip interpolator output to [0,1]. Clipping is done only
104
+ # on the eval loop to get better metrics, but not on the training loop
105
+ # so gradients are not killed.
106
+ if strategy.num_replicas_in_sync > 1:
107
+ predictions = {
108
+ 'image': tf.concat(predictions['image'].values, axis=0)
109
+ }
110
+ predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)
111
+ if batch_idx % 10 == 0:
112
+ logging.info('Evaluating batch %s', batch_idx)
113
+ batch_idx = batch_idx + 1
114
+ if batch_idx < max_batches_to_summarize:
115
+ # Loop through the global batch:
116
+ prefix = f'{dataset_name}/eval_{batch_idx}'
117
+ # Find all tensors that look like images, and summarize:
118
+ combined = {**batch, **predictions}
119
+ _summarize_image_tensors(combined, prefix, step=checkpoint_step)
120
+
121
+ elif batch_idx == max_batches_to_summarize:
122
+ tf.summary.flush()
123
+
124
+ for name, metric in metrics.items():
125
+ tf.summary.scalar(
126
+ f'{dataset_name}/{name}', metric.result(), step=checkpoint_step)
127
+ tf.summary.flush()
128
+ logging.info('Step {:2}, {} {}'.format(checkpoint_step,
129
+ f'{dataset_name}/{name}',
130
+ metric.result().numpy()))
131
+ metric.reset_states()
training/metrics_lib.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """A library for instantiating frame interpolation evaluation metrics."""
16
+
17
+ from typing import Callable, Dict, Text
18
+
19
+ from ..losses import losses
20
+ import tensorflow as tf
21
+
22
+
23
+ class TrainLossMetric(tf.keras.metrics.Metric):
24
+ """Compute training loss for our example and prediction format.
25
+
26
+ The purpose of this is to ensure that we always include a loss that is exactly
27
+ like the training loss into the evaluation in order to detect possible
28
+ overfitting.
29
+ """
30
+
31
+ def __init__(self, name='eval_loss', **kwargs):
32
+ super(TrainLossMetric, self).__init__(name=name, **kwargs)
33
+ self.acc = self.add_weight(name='train_metric_acc', initializer='zeros')
34
+ self.count = self.add_weight(name='train_metric_count', initializer='zeros')
35
+
36
+ def update_state(self,
37
+ batch,
38
+ predictions,
39
+ sample_weight=None,
40
+ checkpoint_step=0):
41
+ loss_functions = losses.training_losses()
42
+ loss_list = []
43
+ for (loss_value, loss_weight) in loss_functions.values():
44
+ loss_list.append(
45
+ loss_value(batch, predictions) * loss_weight(checkpoint_step))
46
+ loss = tf.add_n(loss_list)
47
+ self.acc.assign_add(loss)
48
+ self.count.assign_add(1)
49
+
50
+ def result(self):
51
+ return self.acc / self.count
52
+
53
+ def reset_states(self):
54
+ self.acc.assign(0)
55
+ self.count.assign(0)
56
+
57
+
58
+ class L1Metric(tf.keras.metrics.Metric):
59
+ """Compute L1 over our training example and prediction format.
60
+
61
+ The purpose of this is to ensure that we have at least one metric that is
62
+ compatible across all eval the session and allows us to quickly compare models
63
+ against each other.
64
+ """
65
+
66
+ def __init__(self, name='eval_loss', **kwargs):
67
+ super(L1Metric, self).__init__(name=name, **kwargs)
68
+ self.acc = self.add_weight(name='l1_metric_acc', initializer='zeros')
69
+ self.count = self.add_weight(name='l1_metric_count', initializer='zeros')
70
+
71
+ def update_state(self, batch, prediction, sample_weight=None,
72
+ checkpoint_step=0):
73
+ self.acc.assign_add(losses.l1_loss(batch, prediction))
74
+ self.count.assign_add(1)
75
+
76
+ def result(self):
77
+ return self.acc / self.count
78
+
79
+ def reset_states(self):
80
+ self.acc.assign(0)
81
+ self.count.assign(0)
82
+
83
+
84
+ class GenericLossMetric(tf.keras.metrics.Metric):
85
+ """Metric based on any loss function."""
86
+
87
+ def __init__(self, name: str, loss: Callable[..., tf.Tensor],
88
+ weight: Callable[..., tf.Tensor], **kwargs):
89
+ """Initializes a metric based on a loss function and a weight schedule.
90
+
91
+ Args:
92
+ name: The name of the metric.
93
+ loss: The callable loss that calculates a loss value for a (prediction,
94
+ target) pair.
95
+ weight: The callable weight scheduling function that samples a weight
96
+ based on iteration.
97
+ **kwargs: Any additional keyword arguments to be passed.
98
+ """
99
+ super(GenericLossMetric, self).__init__(name=name, **kwargs)
100
+ self.acc = self.add_weight(name='loss_metric_acc', initializer='zeros')
101
+ self.count = self.add_weight(name='loss_metric_count', initializer='zeros')
102
+ self.loss = loss
103
+ self.weight = weight
104
+
105
+ def update_state(self,
106
+ batch,
107
+ predictions,
108
+ sample_weight=None,
109
+ checkpoint_step=0):
110
+ self.acc.assign_add(
111
+ self.loss(batch, predictions) * self.weight(checkpoint_step))
112
+ self.count.assign_add(1)
113
+
114
+ def result(self):
115
+ return self.acc / self.count
116
+
117
+ def reset_states(self):
118
+ self.acc.assign(0)
119
+ self.count.assign(0)
120
+
121
+
122
+ def create_metrics_fn() -> Dict[Text, tf.keras.metrics.Metric]:
123
+ """Create evaluation metrics.
124
+
125
+ L1 and total training loss are added by default.
126
+ The rest are the configured by the test_losses item via gin.
127
+
128
+ Returns:
129
+ A dictionary from metric name to Keras Metric object.
130
+ """
131
+ metrics = {}
132
+ # L1 is explicitly added just so we always have some consistent numbers around
133
+ # to compare across sessions.
134
+ metrics['l1'] = L1Metric()
135
+ # We also always include training loss for the eval set to detect overfitting:
136
+ metrics['training_loss'] = TrainLossMetric()
137
+
138
+ test_losses = losses.test_losses()
139
+ for loss_name, (loss_value, loss_weight) in test_losses.items():
140
+ metrics[loss_name] = GenericLossMetric(
141
+ name=loss_name, loss=loss_value, weight=loss_weight)
142
+ return metrics
training/model_lib.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """A library for instantiating the model for training frame interpolation.
16
+
17
+ All models are expected to use three inputs: input image batches 'x0' and 'x1'
18
+ and 'time', the fractional time where the output should be generated.
19
+
20
+ The models are expected to output the prediction as a dictionary that contains
21
+ at least the predicted image batch as 'image' plus optional data for debug,
22
+ analysis or custom losses.
23
+ """
24
+
25
+ import gin.tf
26
+ from ..models.film_net import interpolator as film_net_interpolator
27
+ from ..models.film_net import options as film_net_options
28
+
29
+ import tensorflow as tf
30
+
31
+
32
+ @gin.configurable('model')
33
+ def create_model(name: str) -> tf.keras.Model:
34
+ """Creates the frame interpolation model based on given model name."""
35
+ if name == 'film_net':
36
+ return _create_film_net_model() # pylint: disable=no-value-for-parameter
37
+ else:
38
+ raise ValueError(f'Model {name} not implemented.')
39
+
40
+
41
+ def _create_film_net_model() -> tf.keras.Model:
42
+ """Creates the film_net interpolator."""
43
+ # Options are gin-configured in the Options class directly.
44
+ options = film_net_options.Options()
45
+
46
+ x0 = tf.keras.Input(
47
+ shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x0')
48
+ x1 = tf.keras.Input(
49
+ shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x1')
50
+ time = tf.keras.Input(
51
+ shape=(1,), batch_size=None, dtype=tf.float32, name='time')
52
+
53
+ return film_net_interpolator.create_model(x0, x1, time, options)
training/train.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""The training loop for frame interpolation.
16
+
17
+ gin_config: The gin configuration file containing model, losses and datasets.
18
+
19
+ To run on GPUs:
20
+ python3 -m frame_interpolation.training.train \
21
+ --gin_config <path to network.gin> \
22
+ --base_folder <base folder for all training runs> \
23
+ --label <descriptive label for the run>
24
+
25
+ To debug the training loop on CPU:
26
+ python3 -m frame_interpolation.training.train \
27
+ --gin_config <path to config.gin> \
28
+ --base_folder /tmp
29
+ --label test_run \
30
+ --mode cpu
31
+
32
+ The training output directory will be created at <base_folder>/<label>.
33
+ """
34
+ import os
35
+
36
+ from . import augmentation_lib
37
+ from . import data_lib
38
+ from . import eval_lib
39
+ from . import metrics_lib
40
+ from . import model_lib
41
+ from . import train_lib
42
+ from absl import app
43
+ from absl import flags
44
+ from absl import logging
45
+ import gin.tf
46
+ from ..losses import losses
47
+
48
+ # Reduce tensorflow logs to ERRORs only.
49
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
50
+ import tensorflow as tf # pylint: disable=g-import-not-at-top
51
+ tf.get_logger().setLevel('ERROR')
52
+
53
+
54
+ _GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
55
+ _LABEL = flags.DEFINE_string('label', 'run0',
56
+ 'Descriptive label for this run.')
57
+ _BASE_FOLDER = flags.DEFINE_string('base_folder', None,
58
+ 'Path to checkpoints/summaries.')
59
+ _MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
60
+ 'Distributed strategy approach.')
61
+
62
+
63
+ @gin.configurable('training')
64
+ class TrainingOptions(object):
65
+ """Training-related options."""
66
+
67
+ def __init__(self, learning_rate: float, learning_rate_decay_steps: int,
68
+ learning_rate_decay_rate: int, learning_rate_staircase: int,
69
+ num_steps: int):
70
+ self.learning_rate = learning_rate
71
+ self.learning_rate_decay_steps = learning_rate_decay_steps
72
+ self.learning_rate_decay_rate = learning_rate_decay_rate
73
+ self.learning_rate_staircase = learning_rate_staircase
74
+ self.num_steps = num_steps
75
+
76
+
77
+ def main(argv):
78
+ if len(argv) > 1:
79
+ raise app.UsageError('Too many command-line arguments.')
80
+
81
+ output_dir = os.path.join(_BASE_FOLDER.value, _LABEL.value)
82
+ logging.info('Creating output_dir @ %s ...', output_dir)
83
+
84
+ # Copy config file to <base_folder>/<label>/config.gin.
85
+ tf.io.gfile.makedirs(output_dir)
86
+ tf.io.gfile.copy(
87
+ _GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)
88
+
89
+ gin.external_configurable(
90
+ tf.keras.optimizers.schedules.PiecewiseConstantDecay,
91
+ module='tf.keras.optimizers.schedules')
92
+
93
+ gin_configs = [_GIN_CONFIG.value]
94
+ gin.parse_config_files_and_bindings(
95
+ config_files=gin_configs, bindings=None, skip_unknown=True)
96
+
97
+ training_options = TrainingOptions() # pylint: disable=no-value-for-parameter
98
+
99
+ learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
100
+ training_options.learning_rate,
101
+ training_options.learning_rate_decay_steps,
102
+ training_options.learning_rate_decay_rate,
103
+ training_options.learning_rate_staircase,
104
+ name='learning_rate')
105
+
106
+ # Initialize data augmentation functions
107
+ augmentation_fns = augmentation_lib.data_augmentations()
108
+
109
+ saved_model_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value,
110
+ 'saved_model')
111
+ train_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'train')
112
+ eval_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'eval')
113
+
114
+ train_lib.train(
115
+ strategy=train_lib.get_strategy(_MODE.value),
116
+ train_folder=train_folder,
117
+ saved_model_folder=saved_model_folder,
118
+ n_iterations=training_options.num_steps,
119
+ create_model_fn=model_lib.create_model,
120
+ create_losses_fn=losses.training_losses,
121
+ create_metrics_fn=metrics_lib.create_metrics_fn,
122
+ dataset=data_lib.create_training_dataset(
123
+ augmentation_fns=augmentation_fns),
124
+ learning_rate=learning_rate,
125
+ eval_loop_fn=eval_lib.eval_loop,
126
+ eval_folder=eval_folder,
127
+ eval_datasets=data_lib.create_eval_datasets() or None)
128
+
129
+
130
+ if __name__ == '__main__':
131
+ app.run(main)
training/train_lib.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Google LLC
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ r"""Training library for frame interpolation using distributed strategy."""
16
+ import functools
17
+ from typing import Any, Callable, Dict, Text, Tuple
18
+
19
+ from absl import logging
20
+ import tensorflow as tf
21
+
22
+
23
+ def _concat_tensors(tensors: tf.Tensor) -> tf.Tensor:
24
+ """Concat tensors of the different replicas."""
25
+ return tf.concat(tf.nest.flatten(tensors, expand_composites=True), axis=0)
26
+
27
+
28
+ @tf.function
29
+ def _distributed_train_step(strategy: tf.distribute.Strategy,
30
+ batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
31
+ loss_functions: Dict[Text,
32
+ Tuple[Callable[..., tf.Tensor],
33
+ Callable[...,
34
+ tf.Tensor]]],
35
+ optimizer: tf.keras.optimizers.Optimizer,
36
+ iterations: int) -> Dict[Text, Any]:
37
+ """Distributed training step.
38
+
39
+ Args:
40
+ strategy: A Tensorflow distribution strategy.
41
+ batch: A batch of training examples.
42
+ model: The Keras model to train.
43
+ loss_functions: The list of Keras losses used to train the model.
44
+ optimizer: The Keras optimizer used to train the model.
45
+ iterations: Iteration number used to sample weights to each loss.
46
+
47
+ Returns:
48
+ A dictionary of train step outputs.
49
+ """
50
+
51
+ def _train_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
52
+ """Train for one step."""
53
+ with tf.GradientTape() as tape:
54
+ predictions = model(batch, training=True)
55
+ losses = []
56
+ for (loss_value, loss_weight) in loss_functions.values():
57
+ losses.append(loss_value(batch, predictions) * loss_weight(iterations))
58
+ loss = tf.add_n(losses)
59
+ grads = tape.gradient(loss, model.trainable_variables)
60
+ optimizer.apply_gradients(zip(grads, model.trainable_variables))
61
+ # post process for visualization
62
+ all_data = {'loss': loss}
63
+ all_data.update(batch)
64
+ all_data.update(predictions)
65
+ return all_data
66
+
67
+ step_outputs = strategy.run(_train_step, args=(batch,))
68
+
69
+ loss = strategy.reduce(
70
+ tf.distribute.ReduceOp.MEAN, step_outputs['loss'], axis=None)
71
+
72
+ x0 = _concat_tensors(step_outputs['x0'])
73
+ x1 = _concat_tensors(step_outputs['x1'])
74
+ y = _concat_tensors(step_outputs['y'])
75
+ pred_y = _concat_tensors(step_outputs['image'])
76
+
77
+ scalar_summaries = {'training_loss': loss}
78
+
79
+ image_summaries = {
80
+ 'x0': x0,
81
+ 'x1': x1,
82
+ 'y': y,
83
+ 'pred_y': pred_y
84
+ }
85
+
86
+ extra_images = {
87
+ 'importance0', 'importance1', 'x0_warped', 'x1_warped', 'fg_image',
88
+ 'bg_image', 'fg_alpha', 'x1_unfiltered_warped'
89
+ }
90
+ for image in extra_images:
91
+ if image in step_outputs:
92
+ image_summaries[image] = _concat_tensors(step_outputs[image])
93
+
94
+ return {
95
+ 'loss': loss,
96
+ 'scalar_summaries': scalar_summaries,
97
+ 'image_summaries': {
98
+ f'training/{name}': value for name, value in image_summaries.items()
99
+ }
100
+ }
101
+
102
+
103
+ def _summary_writer(summaries_dict: Dict[Text, Any]) -> None:
104
+ """Adds scalar and image summaries."""
105
+ # Adds scalar summaries.
106
+ for key, scalars in summaries_dict['scalar_summaries'].items():
107
+ tf.summary.scalar(key, scalars)
108
+ # Adds image summaries.
109
+ for key, images in summaries_dict['image_summaries'].items():
110
+ tf.summary.image(key, tf.clip_by_value(images, 0.0, 1.0))
111
+ tf.summary.histogram(key + '_h', images)
112
+
113
+
114
+ def train_loop(
115
+ strategy: tf.distribute.Strategy,
116
+ train_set: tf.data.Dataset,
117
+ create_model_fn: Callable[..., tf.keras.Model],
118
+ create_losses_fn: Callable[..., Dict[str, Tuple[Callable[..., tf.Tensor],
119
+ Callable[..., tf.Tensor]]]],
120
+ create_optimizer_fn: Callable[..., tf.keras.optimizers.Optimizer],
121
+ distributed_train_step_fn: Callable[[
122
+ tf.distribute.Strategy, Dict[str, tf.Tensor], tf.keras.Model, Dict[
123
+ str,
124
+ Tuple[Callable[..., tf.Tensor],
125
+ Callable[..., tf.Tensor]]], tf.keras.optimizers.Optimizer, int
126
+ ], Dict[str, Any]],
127
+ eval_loop_fn: Callable[..., None],
128
+ create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
129
+ eval_folder: Dict[str, Any],
130
+ eval_datasets: Dict[str, tf.data.Dataset],
131
+ summary_writer_fn: Callable[[Dict[str, Any]], None],
132
+ train_folder: str,
133
+ saved_model_folder: str,
134
+ num_iterations: int,
135
+ save_summaries_frequency: int = 500,
136
+ save_checkpoint_frequency: int = 500,
137
+ checkpoint_max_to_keep: int = 10,
138
+ checkpoint_save_every_n_hours: float = 2.,
139
+ timing_frequency: int = 100,
140
+ logging_frequency: int = 10):
141
+ """A Tensorflow 2 eager mode training loop.
142
+
143
+ Args:
144
+ strategy: A Tensorflow distributed strategy.
145
+ train_set: A tf.data.Dataset to loop through for training.
146
+ create_model_fn: A callable that returns a tf.keras.Model.
147
+ create_losses_fn: A callable that returns a tf.keras.losses.Loss.
148
+ create_optimizer_fn: A callable that returns a
149
+ tf.keras.optimizers.Optimizer.
150
+ distributed_train_step_fn: A callable that takes a distribution strategy, a
151
+ Dict[Text, tf.Tensor] holding the batch of training data, a
152
+ tf.keras.Model, a tf.keras.losses.Loss, a tf.keras.optimizers.Optimizer,
153
+ iteartion number to sample a weight value to loos functions,
154
+ and returns a dictionary to be passed to the summary_writer_fn.
155
+ eval_loop_fn: Eval loop function.
156
+ create_metrics_fn: create_metric_fn.
157
+ eval_folder: A path to where the summary event files and checkpoints will be
158
+ saved.
159
+ eval_datasets: A dictionary of evalution tf.data.Dataset to loop through for
160
+ evaluation.
161
+ summary_writer_fn: A callable that takes the output of
162
+ distributed_train_step_fn and writes summaries to be visualized in
163
+ TensorBoard.
164
+ train_folder: A path to where the summaries event files and checkpoints
165
+ will be saved.
166
+ saved_model_folder: A path to where the saved models are stored.
167
+ num_iterations: An integer, the number of iterations to train for.
168
+ save_summaries_frequency: The iteration frequency with which summaries are
169
+ saved.
170
+ save_checkpoint_frequency: The iteration frequency with which model
171
+ checkpoints are saved.
172
+ checkpoint_max_to_keep: The maximum number of checkpoints to keep.
173
+ checkpoint_save_every_n_hours: The frequency in hours to keep checkpoints.
174
+ timing_frequency: The iteration frequency with which to log timing.
175
+ logging_frequency: How often to output with logging.info().
176
+ """
177
+ logging.info('Creating training tensorboard summaries ...')
178
+ summary_writer = tf.summary.create_file_writer(train_folder)
179
+
180
+ if eval_datasets is not None:
181
+ logging.info('Creating eval tensorboard summaries ...')
182
+ eval_summary_writer = tf.summary.create_file_writer(eval_folder)
183
+
184
+ train_set = strategy.experimental_distribute_dataset(train_set)
185
+ with strategy.scope():
186
+ logging.info('Building model ...')
187
+ model = create_model_fn()
188
+ loss_functions = create_losses_fn()
189
+ optimizer = create_optimizer_fn()
190
+ if eval_datasets is not None:
191
+ metrics = create_metrics_fn()
192
+
193
+ logging.info('Creating checkpoint ...')
194
+ checkpoint = tf.train.Checkpoint(
195
+ model=model,
196
+ optimizer=optimizer,
197
+ step=optimizer.iterations,
198
+ epoch=tf.Variable(0, dtype=tf.int64, trainable=False),
199
+ training_finished=tf.Variable(False, dtype=tf.bool, trainable=False))
200
+
201
+ logging.info('Restoring old model (if exists) ...')
202
+ checkpoint_manager = tf.train.CheckpointManager(
203
+ checkpoint,
204
+ directory=train_folder,
205
+ max_to_keep=checkpoint_max_to_keep,
206
+ keep_checkpoint_every_n_hours=checkpoint_save_every_n_hours)
207
+
208
+ with strategy.scope():
209
+ if checkpoint_manager.latest_checkpoint:
210
+ checkpoint.restore(checkpoint_manager.latest_checkpoint)
211
+
212
+ logging.info('Creating Timer ...')
213
+ timer = tf.estimator.SecondOrStepTimer(every_steps=timing_frequency)
214
+ timer.update_last_triggered_step(optimizer.iterations.numpy())
215
+
216
+ logging.info('Training on devices: %s.', [
217
+ el.name.split('/physical_device:')[-1]
218
+ for el in tf.config.get_visible_devices()
219
+ ])
220
+
221
+ # Re-assign training_finished=False, in case we restored a checkpoint.
222
+ checkpoint.training_finished.assign(False)
223
+ while optimizer.iterations.numpy() < num_iterations:
224
+ for i_batch, batch in enumerate(train_set):
225
+ summary_writer.set_as_default()
226
+ iterations = optimizer.iterations.numpy()
227
+
228
+ if iterations % logging_frequency == 0:
229
+ # Log epoch, total iterations and batch index.
230
+ logging.info('epoch %d; iterations %d; i_batch %d',
231
+ checkpoint.epoch.numpy(), iterations,
232
+ i_batch)
233
+
234
+ # Break if the number of iterations exceeds the max.
235
+ if iterations >= num_iterations:
236
+ break
237
+
238
+ # Compute distributed step outputs.
239
+ distributed_step_outputs = distributed_train_step_fn(
240
+ strategy, batch, model, loss_functions, optimizer, iterations)
241
+
242
+ # Save checkpoint, and optionally run the eval loops.
243
+ if iterations % save_checkpoint_frequency == 0:
244
+ checkpoint_manager.save(checkpoint_number=iterations)
245
+ if eval_datasets is not None:
246
+ eval_loop_fn(
247
+ strategy=strategy,
248
+ eval_base_folder=eval_folder,
249
+ model=model,
250
+ metrics=metrics,
251
+ datasets=eval_datasets,
252
+ summary_writer=eval_summary_writer,
253
+ checkpoint_step=iterations)
254
+
255
+ # Write summaries.
256
+ if iterations % save_summaries_frequency == 0:
257
+ tf.summary.experimental.set_step(step=iterations)
258
+ summary_writer_fn(distributed_step_outputs)
259
+ tf.summary.scalar('learning_rate',
260
+ optimizer.learning_rate(iterations).numpy())
261
+
262
+ # Log steps/sec.
263
+ if timer.should_trigger_for_step(iterations):
264
+ elapsed_time, elapsed_steps = timer.update_last_triggered_step(
265
+ iterations)
266
+ if elapsed_time is not None:
267
+ steps_per_second = elapsed_steps / elapsed_time
268
+ tf.summary.scalar(
269
+ 'steps/sec', steps_per_second, step=optimizer.iterations)
270
+
271
+ # Increment epoch.
272
+ checkpoint.epoch.assign_add(1)
273
+
274
+ # Assign training_finished variable to True after training is finished and
275
+ # save the last checkpoint.
276
+ checkpoint.training_finished.assign(True)
277
+ checkpoint_manager.save(checkpoint_number=optimizer.iterations.numpy())
278
+
279
+ # Generate a saved model.
280
+ model.save(saved_model_folder)
281
+
282
+
283
+ def train(strategy: tf.distribute.Strategy, train_folder: str,
284
+ saved_model_folder: str, n_iterations: int,
285
+ create_model_fn: Callable[..., tf.keras.Model],
286
+ create_losses_fn: Callable[..., Dict[str,
287
+ Tuple[Callable[..., tf.Tensor],
288
+ Callable[...,
289
+ tf.Tensor]]]],
290
+ create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
291
+ dataset: tf.data.Dataset,
292
+ learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
293
+ eval_loop_fn: Callable[..., None],
294
+ eval_folder: str,
295
+ eval_datasets: Dict[str, tf.data.Dataset]):
296
+ """Training function that is strategy agnostic.
297
+
298
+ Args:
299
+ strategy: A Tensorflow distributed strategy.
300
+ train_folder: A path to where the summaries event files and checkpoints
301
+ will be saved.
302
+ saved_model_folder: A path to where the saved models are stored.
303
+ n_iterations: An integer, the number of iterations to train for.
304
+ create_model_fn: A callable that returns tf.keras.Model.
305
+ create_losses_fn: A callable that returns the losses.
306
+ create_metrics_fn: A function that returns the metrics dictionary.
307
+ dataset: The tensorflow dataset object.
308
+ learning_rate: Keras learning rate schedule object.
309
+ eval_loop_fn: eval loop function.
310
+ eval_folder: A path to where eval summaries event files and checkpoints
311
+ will be saved.
312
+ eval_datasets: The tensorflow evaluation dataset objects.
313
+ """
314
+ train_loop(
315
+ strategy=strategy,
316
+ train_set=dataset,
317
+ create_model_fn=create_model_fn,
318
+ create_losses_fn=create_losses_fn,
319
+ create_optimizer_fn=functools.partial(
320
+ tf.keras.optimizers.Adam, learning_rate=learning_rate),
321
+ distributed_train_step_fn=_distributed_train_step,
322
+ eval_loop_fn=eval_loop_fn,
323
+ create_metrics_fn=create_metrics_fn,
324
+ eval_folder=eval_folder,
325
+ eval_datasets=eval_datasets,
326
+ summary_writer_fn=_summary_writer,
327
+ train_folder=train_folder,
328
+ saved_model_folder=saved_model_folder,
329
+ num_iterations=n_iterations,
330
+ save_summaries_frequency=3000,
331
+ save_checkpoint_frequency=3000)
332
+
333
+
334
+ def get_strategy(mode) -> tf.distribute.Strategy:
335
+ """Creates a distributed strategy."""
336
+ strategy = None
337
+ if mode == 'cpu':
338
+ strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
339
+ elif mode == 'gpu':
340
+ strategy = tf.distribute.MirroredStrategy()
341
+ else:
342
+ raise ValueError('Unsupported distributed mode.')
343
+ return strategy