meepmoo commited on
Commit
208b0eb
·
verified ·
1 Parent(s): 950911f

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +167 -0
  2. Dockerfile +45 -0
  3. LICENSE +201 -0
  4. README.md +376 -0
  5. README_zh-CN.md +375 -0
  6. __init__.py +3 -0
  7. app.py +49 -0
  8. cogvideox/__init__.py +0 -0
  9. cogvideox/api/api.py +173 -0
  10. cogvideox/api/post_infer.py +89 -0
  11. cogvideox/data/bucket_sampler.py +379 -0
  12. cogvideox/data/dataset_image.py +76 -0
  13. cogvideox/data/dataset_image_video.py +550 -0
  14. cogvideox/data/dataset_video.py +262 -0
  15. cogvideox/pipeline/pipeline_cogvideox.py +751 -0
  16. cogvideox/pipeline/pipeline_cogvideox_control.py +843 -0
  17. cogvideox/pipeline/pipeline_cogvideox_inpaint.py +1020 -0
  18. cogvideox/ui/ui.py +1614 -0
  19. cogvideox/utils/__init__.py +0 -0
  20. cogvideox/utils/lora_utils.py +477 -0
  21. cogvideox/utils/utils.py +208 -0
  22. cogvideox/video_caption/README.md +174 -0
  23. cogvideox/video_caption/README_zh-CN.md +159 -0
  24. cogvideox/video_caption/beautiful_prompt.py +103 -0
  25. cogvideox/video_caption/caption_rewrite.py +224 -0
  26. cogvideox/video_caption/compute_motion_score.py +186 -0
  27. cogvideox/video_caption/compute_text_score.py +214 -0
  28. cogvideox/video_caption/compute_video_quality.py +201 -0
  29. cogvideox/video_caption/cutscene_detect.py +97 -0
  30. cogvideox/video_caption/filter_meta_train.py +88 -0
  31. cogvideox/video_caption/package_patches/easyocr_detection_patched.py +114 -0
  32. cogvideox/video_caption/package_patches/vila_siglip_encoder_patched.py +42 -0
  33. cogvideox/video_caption/prompt/beautiful_prompt.txt +9 -0
  34. cogvideox/video_caption/prompt/rewrite.txt +9 -0
  35. cogvideox/video_caption/requirements.txt +9 -0
  36. cogvideox/video_caption/scripts/stage_1_video_splitting.sh +39 -0
  37. cogvideox/video_caption/scripts/stage_2_video_filtering.sh +41 -0
  38. cogvideox/video_caption/scripts/stage_3_video_recaptioning.sh +52 -0
  39. cogvideox/video_caption/utils/filter.py +162 -0
  40. cogvideox/video_caption/utils/gather_jsonl.py +55 -0
  41. cogvideox/video_caption/utils/get_meta_file.py +74 -0
  42. cogvideox/video_caption/utils/image_evaluator.py +248 -0
  43. cogvideox/video_caption/utils/logger.py +36 -0
  44. cogvideox/video_caption/utils/longclip/README.md +19 -0
  45. cogvideox/video_caption/utils/longclip/__init__.py +1 -0
  46. cogvideox/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz +3 -0
  47. cogvideox/video_caption/utils/longclip/longclip.py +353 -0
  48. cogvideox/video_caption/utils/longclip/model_longclip.py +471 -0
  49. cogvideox/video_caption/utils/longclip/simple_tokenizer.py +132 -0
  50. cogvideox/video_caption/utils/siglip_v2_5.py +127 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ models*
3
+ output*
4
+ logs*
5
+ taming*
6
+ samples*
7
+ datasets*
8
+ asset*
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+ cover/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ .pybuilder/
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ # For a library or package, you might want to ignore these files since the code is
94
+ # intended to run in multiple environments; otherwise, check them in:
95
+ # .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # poetry
105
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
107
+ # commonly ignored for libraries.
108
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109
+ #poetry.lock
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ #pdm.lock
114
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115
+ # in version control.
116
+ # https://pdm.fming.dev/#use-with-ide
117
+ .pdm.toml
118
+
119
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120
+ __pypackages__/
121
+
122
+ # Celery stuff
123
+ celerybeat-schedule
124
+ celerybeat.pid
125
+
126
+ # SageMath parsed files
127
+ *.sage.py
128
+
129
+ # Environments
130
+ .env
131
+ .venv
132
+ env/
133
+ venv/
134
+ ENV/
135
+ env.bak/
136
+ venv.bak/
137
+
138
+ # Spyder project settings
139
+ .spyderproject
140
+ .spyproject
141
+
142
+ # Rope project settings
143
+ .ropeproject
144
+
145
+ # mkdocs documentation
146
+ /site
147
+
148
+ # mypy
149
+ .mypy_cache/
150
+ .dmypy.json
151
+ dmypy.json
152
+
153
+ # Pyre type checker
154
+ .pyre/
155
+
156
+ # pytype static type analyzer
157
+ .pytype/
158
+
159
+ # Cython debug symbols
160
+ cython_debug/
161
+
162
+ # PyCharm
163
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
166
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM runpod/pytorch:2.2.1-py3.10-cuda12.1.1-devel-ubuntu22.04
2
+ WORKDIR /content
3
+ ENV PATH="/home/zebraslive/.local/bin:${PATH}"
4
+
5
+ RUN adduser --disabled-password --gecos '' zebraslive && \
6
+ adduser zebraslive sudo && \
7
+ echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \
8
+ chown -R zebraslive:zebraslive /content && \
9
+ chmod -R 777 /content && \
10
+ chown -R zebraslive:zebraslive /home && \
11
+ chmod -R 777 /home && \
12
+ apt update -y && add-apt-repository -y ppa:git-core/ppa && apt update -y && apt install -y aria2 git git-lfs unzip ffmpeg
13
+
14
+ USER zebraslive
15
+
16
+ RUN pip install -q torch==2.4.0+cu121 torchvision==0.19.0+cu121 torchaudio==2.4.0+cu121 torchtext==0.18.0 torchdata==0.8.0 --extra-index-url https://download.pytorch.org/whl/cu121 \
17
+ tqdm==4.66.5 numpy==1.26.3 imageio==2.35.1 imageio-ffmpeg==0.5.1 xformers==0.0.27.post2 diffusers==0.30.3 moviepy==1.0.3 transformers==4.44.2 accelerate==0.33.0 sentencepiece==0.2.0 pillow==9.5.0 runpod && \
18
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/scheduler/scheduler_config.json -d /content/model/scheduler -o scheduler_config.json && \
19
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/text_encoder/config.json -d /content/model/text_encoder -o config.json && \
20
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/resolve/main/text_encoder/model-00001-of-00002.safetensors -d /content/model/text_encoder -o model-00001-of-00002.safetensors && \
21
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/resolve/main/text_encoder/model-00002-of-00002.safetensors -d /content/model/text_encoder -o model-00002-of-00002.safetensors && \
22
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/text_encoder/model.safetensors.index.json -d /content/model/text_encoder -o model.safetensors.index.json && \
23
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/tokenizer/added_tokens.json -d /content/model/tokenizer -o added_tokens.json && \
24
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/tokenizer/special_tokens_map.json -d /content/model/tokenizer -o special_tokens_map.json && \
25
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/resolve/main/tokenizer/spiece.model -d /content/model/tokenizer -o spiece.model && \
26
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/tokenizer/tokenizer_config.json -d /content/model/tokenizer -o tokenizer_config.json && \
27
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/transformer/config.json -d /content/model/transformer -o config.json && \
28
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/resolve/main/transformer/diffusion_pytorch_model-00001-of-00003.safetensors -d /content/model/transformer -o diffusion_pytorch_model-00001-of-00003.safetensors && \
29
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/resolve/main/transformer/diffusion_pytorch_model-00002-of-00003.safetensors -d /content/model/transformer -o diffusion_pytorch_model-00002-of-00003.safetensors && \
30
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/resolve/main/transformer/diffusion_pytorch_model-00003-of-00003.safetensors -d /content/model/transformer -o diffusion_pytorch_model-00003-of-00003.safetensors && \
31
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/transformer/diffusion_pytorch_model.safetensors.index.json -d /content/model/transformer -o diffusion_pytorch_model.safetensors.index.json && \
32
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/vae/config.json -d /content/model/vae -o config.json && \
33
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/resolve/main/vae/diffusion_pytorch_model.safetensors -d /content/model/vae -o diffusion_pytorch_model.safetensors && \
34
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/configuration.json -d /content/model -o configuration.json && \
35
+ aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP/raw/main/model_index.json -d /content/model -o model_index.json
36
+
37
+ COPY ./worker_runpod.py /content/worker_runpod.py
38
+ COPY ./cogvideox /content/cogvideox
39
+ COPY ./asset /content/asset
40
+ COPY ./config /content/config
41
+ COPY ./datasets /content/datasets
42
+ COPY ./reports /content/reports
43
+ COPY ./requirements.txt /content/requirements.txt
44
+ RUN pip install -r /content/requirements.txt
45
+ WORKDIR /content
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CogVideoX-Fun
2
+
3
+ 😊 Welcome!
4
+
5
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/alibaba-pai/CogVideoX-Fun-5b)
6
+
7
+ English | [简体中文](./README_zh-CN.md)
8
+
9
+ # Table of Contents
10
+ - [Table of Contents](#table-of-contents)
11
+ - [Introduction](#introduction)
12
+ - [Quick Start](#quick-start)
13
+ - [Video Result](#video-result)
14
+ - [How to use](#how-to-use)
15
+ - [Model zoo](#model-zoo)
16
+ - [TODO List](#todo-list)
17
+ - [Reference](#reference)
18
+ - [License](#license)
19
+
20
+ # Introduction
21
+ CogVideoX-Fun is a modified pipeline based on the CogVideoX structure, designed to provide more flexibility in generation. It can be used to create AI images and videos, as well as to train baseline models and Lora models for Diffusion Transformer. We support predictions directly from the already trained CogVideoX-Fun model, allowing the generation of videos at different resolutions, approximately 6 seconds long with 8 fps (1 to 49 frames). Users can also train their own baseline models and Lora models to achieve certain style transformations.
22
+
23
+ We will support quick pull-ups from different platforms, refer to [Quick Start](#quick-start).
24
+
25
+ What's New:
26
+ - CogVideoX-Fun Control is now supported in diffusers. Thanks to [a-r-r-o-w](https://github.com/a-r-r-o-w) who contributed the support in this [PR](https://github.com/huggingface/diffusers/pull/9671). Check out the [docs](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) to know more. [ 2024.10.16 ]
27
+ - Retrain the i2v model and add noise to increase the motion amplitude of the video. Upload the control model training code and control model. [ 2024.09.29 ]
28
+ - Create code! Now supporting Windows and Linux. Supports 2b and 5b models. Supports video generation at any resolution from 256x256x49 to 1024x1024x49. [ 2024.09.18 ]
29
+
30
+ Function:
31
+ - [Data Preprocessing](#data-preprocess)
32
+ - [Train DiT](#dit-train)
33
+ - [Video Generation](#video-gen)
34
+
35
+ Our UI interface is as follows:
36
+ ![ui](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/ui.jpg)
37
+
38
+ # Quick Start
39
+ ### 1. Cloud usage: AliyunDSW/Docker
40
+ #### a. From AliyunDSW
41
+ DSW has free GPU time, which can be applied once by a user and is valid for 3 months after applying.
42
+
43
+ Aliyun provide free GPU time in [Freetier](https://free.aliyun.com/?product=9602825&crowd=enterprise&spm=5176.28055625.J_5831864660.1.e939154aRgha4e&scm=20140722.M_9974135.P_110.MO_1806-ID_9974135-MID_9974135-CID_30683-ST_8512-V_1), get it and use in Aliyun PAI-DSW to start CogVideoX-Fun within 5min!
44
+
45
+ [![DSW Notebook](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/dsw.png)](https://gallery.pai-ml.com/#/preview/deepLearning/cv/cogvideox_fun)
46
+
47
+ #### b. From ComfyUI
48
+ Our ComfyUI is as follows, please refer to [ComfyUI README](comfyui/README.md) for details.
49
+ ![workflow graph](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/cogvideoxfunv1_workflow_i2v.jpg)
50
+
51
+ #### c. From docker
52
+ If you are using docker, please make sure that the graphics card driver and CUDA environment have been installed correctly in your machine.
53
+
54
+ Then execute the following commands in this way:
55
+
56
+ ```
57
+ # pull image
58
+ docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
59
+
60
+ # enter image
61
+ docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
62
+
63
+ # clone code
64
+ git clone https://github.com/aigc-apps/CogVideoX-Fun.git
65
+
66
+ # enter CogVideoX-Fun's dir
67
+ cd CogVideoX-Fun
68
+
69
+ # download weights
70
+ mkdir models/Diffusion_Transformer
71
+ mkdir models/Personalized_Model
72
+
73
+ wget https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP.tar.gz -O models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP.tar.gz
74
+
75
+ cd models/Diffusion_Transformer/
76
+ tar -xvf CogVideoX-Fun-V1.1-2b-InP.tar.gz
77
+ cd ../../
78
+ ```
79
+
80
+ ### 2. Local install: Environment Check/Downloading/Installation
81
+ #### a. Environment Check
82
+ We have verified CogVideoX-Fun execution on the following environment:
83
+
84
+ The detailed of Windows:
85
+ - OS: Windows 10
86
+ - python: python3.10 & python3.11
87
+ - pytorch: torch2.2.0
88
+ - CUDA: 11.8 & 12.1
89
+ - CUDNN: 8+
90
+ - GPU: Nvidia-3060 12G & Nvidia-3090 24G
91
+
92
+ The detailed of Linux:
93
+ - OS: Ubuntu 20.04, CentOS
94
+ - python: python3.10 & python3.11
95
+ - pytorch: torch2.2.0
96
+ - CUDA: 11.8 & 12.1
97
+ - CUDNN: 8+
98
+ - GPU:Nvidia-V100 16G & Nvidia-A10 24G & Nvidia-A100 40G & Nvidia-A100 80G
99
+
100
+ We need about 60GB available on disk (for saving weights), please check!
101
+
102
+ #### b. Weights
103
+ We'd better place the [weights](#model-zoo) along the specified path:
104
+
105
+ ```
106
+ 📦 models/
107
+ ├── 📂 Diffusion_Transformer/
108
+ │ ├── 📂 CogVideoX-Fun-V1.1-2b-InP/
109
+ │ └── 📂 CogVideoX-Fun-V1.1-5b-InP/
110
+ ├── 📂 Personalized_Model/
111
+ │ └── your trained trainformer model / your trained lora model (for UI load)
112
+ ```
113
+
114
+ # Video Result
115
+ The results displayed are all based on image.
116
+
117
+ ### CogVideoX-Fun-V1.1-5B
118
+
119
+ Resolution-1024
120
+
121
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
122
+ <tr>
123
+ <td>
124
+ <video src="https://github.com/user-attachments/assets/34e7ec8f-293e-4655-bb14-5e1ee476f788" width="100%" controls autoplay loop></video>
125
+ </td>
126
+ <td>
127
+ <video src="https://github.com/user-attachments/assets/7809c64f-eb8c-48a9-8bdc-ca9261fd5434" width="100%" controls autoplay loop></video>
128
+ </td>
129
+ <td>
130
+ <video src="https://github.com/user-attachments/assets/8e76aaa4-c602-44ac-bcb4-8b24b72c386c" width="100%" controls autoplay loop></video>
131
+ </td>
132
+ <td>
133
+ <video src="https://github.com/user-attachments/assets/19dba894-7c35-4f25-b15c-384167ab3b03" width="100%" controls autoplay loop></video>
134
+ </td>
135
+ </tr>
136
+ </table>
137
+
138
+
139
+ Resolution-768
140
+
141
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
142
+ <tr>
143
+ <td>
144
+ <video src="https://github.com/user-attachments/assets/0bc339b9-455b-44fd-8917-80272d702737" width="100%" controls autoplay loop></video>
145
+ </td>
146
+ <td>
147
+ <video src="https://github.com/user-attachments/assets/70a043b9-6721-4bd9-be47-78b7ec5c27e9" width="100%" controls autoplay loop></video>
148
+ </td>
149
+ <td>
150
+ <video src="https://github.com/user-attachments/assets/d5dd6c09-14f3-40f8-8b6d-91e26519b8ac" width="100%" controls autoplay loop></video>
151
+ </td>
152
+ <td>
153
+ <video src="https://github.com/user-attachments/assets/9327e8bc-4f17-46b0-b50d-38c250a9483a" width="100%" controls autoplay loop></video>
154
+ </td>
155
+ </tr>
156
+ </table>
157
+
158
+ Resolution-512
159
+
160
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
161
+ <tr>
162
+ <td>
163
+ <video src="https://github.com/user-attachments/assets/ef407030-8062-454d-aba3-131c21e6b58c" width="100%" controls autoplay loop></video>
164
+ </td>
165
+ <td>
166
+ <video src="https://github.com/user-attachments/assets/7610f49e-38b6-4214-aa48-723ae4d1b07e" width="100%" controls autoplay loop></video>
167
+ </td>
168
+ <td>
169
+ <video src="https://github.com/user-attachments/assets/1fff0567-1e15-415c-941e-53ee8ae2c841" width="100%" controls autoplay loop></video>
170
+ </td>
171
+ <td>
172
+ <video src="https://github.com/user-attachments/assets/bcec48da-b91b-43a0-9d50-cf026e00fa4f" width="100%" controls autoplay loop></video>
173
+ </td>
174
+ </tr>
175
+ </table>
176
+
177
+ ### CogVideoX-Fun-V1.1-5B-Pose
178
+
179
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
180
+ <tr>
181
+ <td>
182
+ Resolution-512
183
+ </td>
184
+ <td>
185
+ Resolution-768
186
+ </td>
187
+ <td>
188
+ Resolution-1024
189
+ </td>
190
+ <tr>
191
+ <td>
192
+ <video src="https://github.com/user-attachments/assets/a746df51-9eb7-4446-bee5-2ee30285c143" width="100%" controls autoplay loop></video>
193
+ </td>
194
+ <td>
195
+ <video src="https://github.com/user-attachments/assets/db295245-e6aa-43be-8c81-32cb411f1473" width="100%" controls autoplay loop></video>
196
+ </td>
197
+ <td>
198
+ <video src="https://github.com/user-attachments/assets/ec9875b2-fde0-48e1-ab7e-490cee51ef40" width="100%" controls autoplay loop></video>
199
+ </td>
200
+ </tr>
201
+ </table>
202
+
203
+ ### CogVideoX-Fun-V1.1-2B
204
+
205
+ Resolution-768
206
+
207
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
208
+ <tr>
209
+ <td>
210
+ <video src="https://github.com/user-attachments/assets/03235dea-980e-4fc5-9c41-e40a5bc1b6d0" width="100%" controls autoplay loop></video>
211
+ </td>
212
+ <td>
213
+ <video src="https://github.com/user-attachments/assets/f7302648-5017-47db-bdeb-4d893e620b37" width="100%" controls autoplay loop></video>
214
+ </td>
215
+ <td>
216
+ <video src="https://github.com/user-attachments/assets/cbadf411-28fa-4b87-813d-da63ff481904" width="100%" controls autoplay loop></video>
217
+ </td>
218
+ <td>
219
+ <video src="https://github.com/user-attachments/assets/87cc9d0b-b6fe-4d2d-b447-174513d169ab" width="100%" controls autoplay loop></video>
220
+ </td>
221
+ </tr>
222
+ </table>
223
+
224
+ ### CogVideoX-Fun-V1.1-2B-Pose
225
+
226
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
227
+ <tr>
228
+ <td>
229
+ Resolution-512
230
+ </td>
231
+ <td>
232
+ Resolution-768
233
+ </td>
234
+ <td>
235
+ Resolution-1024
236
+ </td>
237
+ <tr>
238
+ <td>
239
+ <video src="https://github.com/user-attachments/assets/487bcd7b-1b7f-4bb4-95b5-96a6b6548b3e" width="100%" controls autoplay loop></video>
240
+ </td>
241
+ <td>
242
+ <video src="https://github.com/user-attachments/assets/2710fd18-8489-46e4-8086-c237309ae7f6" width="100%" controls autoplay loop></video>
243
+ </td>
244
+ <td>
245
+ <video src="https://github.com/user-attachments/assets/b79513db-7747-4512-b86c-94f9ca447fe2" width="100%" controls autoplay loop></video>
246
+ </td>
247
+ </tr>
248
+ </table>
249
+
250
+ # How to use
251
+
252
+ <h3 id="video-gen">1. Inference </h3>
253
+
254
+ #### a. Using Python Code
255
+ - Step 1: Download the corresponding [weights](#model-zoo) and place them in the models folder.
256
+ - Step 2: Modify prompt, neg_prompt, guidance_scale, and seed in the predict_t2v.py file.
257
+ - Step 3: Run the predict_t2v.py file, wait for the generated results, and save the results in the samples/cogvideox-fun-videos-t2v folder.
258
+ - Step 4: If you want to combine other backbones you have trained with Lora, modify the predict_t2v.py and Lora_path in predict_t2v.py depending on the situation.
259
+
260
+ #### b. Using webui
261
+ - Step 1: Download the corresponding [weights](#model-zoo) and place them in the models folder.
262
+ - Step 2: Run the app.py file to enter the graph page.
263
+ - Step 3: Select the generated model based on the page, fill in prompt, neg_prompt, guidance_scale, and seed, click on generate, wait for the generated result, and save the result in the samples folder.
264
+
265
+ #### c. From ComfyUI
266
+ Please refer to [ComfyUI README](comfyui/README.md) for details.
267
+
268
+ ### 2. Model Training
269
+ A complete CogVideoX-Fun training pipeline should include data preprocessing, and Video DiT training.
270
+
271
+ <h4 id="data-preprocess">a. data preprocessing</h4>
272
+
273
+ We have provided a simple demo of training the Lora model through image data, which can be found in the [wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora) for details.
274
+
275
+ A complete data preprocessing link for long video segmentation, cleaning, and description can refer to [README](cogvideox/video_caption/README.md) in the video captions section.
276
+
277
+ If you want to train a text to image and video generation model. You need to arrange the dataset in this format.
278
+
279
+ ```
280
+ 📦 project/
281
+ ├── 📂 datasets/
282
+ │ ├── 📂 internal_datasets/
283
+ │ ├── 📂 train/
284
+ │ │ ├── 📄 00000001.mp4
285
+ │ │ ├── 📄 00000002.jpg
286
+ │ │ └── 📄 .....
287
+ │ └── 📄 json_of_internal_datasets.json
288
+ ```
289
+
290
+ The json_of_internal_datasets.json is a standard JSON file. The file_path in the json can to be set as relative path, as shown in below:
291
+ ```json
292
+ [
293
+ {
294
+ "file_path": "train/00000001.mp4",
295
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
296
+ "type": "video"
297
+ },
298
+ {
299
+ "file_path": "train/00000002.jpg",
300
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
301
+ "type": "image"
302
+ },
303
+ .....
304
+ ]
305
+ ```
306
+
307
+ You can also set the path as absolute path as follow:
308
+ ```json
309
+ [
310
+ {
311
+ "file_path": "/mnt/data/videos/00000001.mp4",
312
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
313
+ "type": "video"
314
+ },
315
+ {
316
+ "file_path": "/mnt/data/train/00000001.jpg",
317
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
318
+ "type": "image"
319
+ },
320
+ .....
321
+ ]
322
+ ```
323
+
324
+ <h4 id="dit-train">b. Video DiT training </h4>
325
+
326
+ If the data format is relative path during data preprocessing, please set ```scripts/train.sh``` as follow.
327
+ ```
328
+ export DATASET_NAME="datasets/internal_datasets/"
329
+ export DATASET_META_NAME="datasets/internal_datasets/json_of_internal_datasets.json"
330
+ ```
331
+
332
+ If the data format is absolute path during data preprocessing, please set ```scripts/train.sh``` as follow.
333
+ ```
334
+ export DATASET_NAME=""
335
+ export DATASET_META_NAME="/mnt/data/json_of_internal_datasets.json"
336
+ ```
337
+
338
+ Then, we run scripts/train.sh.
339
+ ```sh
340
+ sh scripts/train.sh
341
+ ```
342
+
343
+ For details on setting some parameters, please refer to [Readme Train](scripts/README_TRAIN.md), [Readme Lora](scripts/README_TRAIN_LORA.md) and [Readme Control](scripts/README_TRAIN_CONTROL.md).
344
+
345
+
346
+ # Model zoo
347
+
348
+ V1.1:
349
+
350
+ | 名称 | 存储空间 | Hugging Face | Model Scope | 描述 |
351
+ |--|--|--|--|--|
352
+ | CogVideoX-Fun-V1.1-2b-InP.tar.gz | Before extraction:9.7 GB \/ After extraction: 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-2b-InP) | Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. Noise has been added to the reference image, and the amplitude of motion is greater compared to V1.0. |
353
+ | CogVideoX-Fun-V1.1-5b-InP.tar.gz | Before extraction:16.0 GB \/ After extraction: 20.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-5b-InP) | Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. Noise has been added to the reference image, and the amplitude of motion is greater compared to V1.0. |
354
+ | CogVideoX-Fun-V1.1-2b-Pose.tar.gz | Before extraction:9.7 GB \/ After extraction: 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-2b-Pose) | Our official pose-control video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second.|
355
+ | CogVideoX-Fun-V1.1-5b-Pose.tar.gz | Before extraction:16.0 GB \/ After extraction: 20.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-5b-Pose) | Our official pose-control video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second.|
356
+
357
+ V1.0:
358
+
359
+ | Name | Storage Space | Hugging Face | Model Scope | Description |
360
+ |--|--|--|--|--|
361
+ | CogVideoX-Fun-2b-InP.tar.gz | Before extraction:9.7 GB \/ After extraction: 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-2b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-2b-InP) | Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. |
362
+ | CogVideoX-Fun-5b-InP.tar.gz | Before extraction:16.0 GB \/ After extraction: 20.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-5b-InP)| [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-5b-InP)| Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. |
363
+
364
+ # TODO List
365
+ - Support Chinese.
366
+
367
+ # Reference
368
+ - CogVideo: https://github.com/THUDM/CogVideo/
369
+ - EasyAnimate: https://github.com/aigc-apps/EasyAnimate
370
+
371
+ # License
372
+ This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
373
+
374
+ The CogVideoX-2B model (including its corresponding Transformers module and VAE module) is released under the [Apache 2.0 License](LICENSE).
375
+
376
+ The CogVideoX-5B model (Transformers module) is released under the [CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE).
README_zh-CN.md ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CogVideoX-Fun
2
+
3
+ 😊 Welcome!
4
+
5
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/alibaba-pai/CogVideoX-Fun-5b)
6
+
7
+ [English](./README.md) | 简体中文
8
+
9
+ # 目录
10
+ - [目录](#目录)
11
+ - [简介](#简介)
12
+ - [快速启动](#快速启动)
13
+ - [视频作品](#视频作品)
14
+ - [如何使用](#如何使用)
15
+ - [模型地址](#模型地址)
16
+ - [未来计划](#未来计划)
17
+ - [参考文献](#参考文献)
18
+ - [许可证](#许可证)
19
+
20
+ # 简介
21
+ CogVideoX-Fun是一个基于CogVideoX结构修改后的的pipeline,是一个生成条件更自由的CogVideoX,可用于生成AI图片与视频、训练Diffusion Transformer的基线模型与Lora模型,我们支持从已经训练好的CogVideoX-Fun模型直接进行预测,生成不同分辨率,6秒左右、fps8的视频(1 ~ 49帧),也支持用户训练自己的基线模型与Lora模型,进行一定的风格变换。
22
+
23
+ 我们会逐渐支持从不同平台快速启动,请参阅 [快速启动](#快速启动)。
24
+
25
+ 新特性:
26
+ - CogVideoX-Fun Control现在在diffusers中得到了支持。感谢 [a-r-r-o-w](https://github.com/a-r-r-o-w)在这个 [PR](https://github.com/huggingface/diffusers/pull/9671)中贡献了支持。查看[文档](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox)以了解更多信息。[2024.10.16]
27
+ - 重新训练i2v模型,添加Noise,使得视频的运动幅度更大。上传控制模型训练代码与Control模型。[ 2024.09.29 ]
28
+ - 创建代码!现在支持 Windows 和 Linux。支持2b与5b最大256x256x49到1024x1024x49的任意分辨率的视频生成。[ 2024.09.18 ]
29
+
30
+ 功能概览:
31
+ - [数据预处理](#data-preprocess)
32
+ - [训练DiT](#dit-train)
33
+ - [模型生成](#video-gen)
34
+
35
+ 我们的ui界面如下:
36
+ ![ui](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/ui.jpg)
37
+
38
+ # 快速启动
39
+ ### 1. 云使用: AliyunDSW/Docker
40
+ #### a. 通过阿里云 DSW
41
+ DSW 有免费 GPU 时间,用户可申请一次,申请后3个月内有效。
42
+
43
+ 阿里云在[Freetier](https://free.aliyun.com/?product=9602825&crowd=enterprise&spm=5176.28055625.J_5831864660.1.e939154aRgha4e&scm=20140722.M_9974135.P_110.MO_1806-ID_9974135-MID_9974135-CID_30683-ST_8512-V_1)提供免费GPU时间,获取并在阿里云PAI-DSW中使用,5分钟内即可启动CogVideoX-Fun。
44
+
45
+ [![DSW Notebook](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/dsw.png)](https://gallery.pai-ml.com/#/preview/deepLearning/cv/cogvideox_fun)
46
+
47
+ #### b. 通过ComfyUI
48
+ 我们的ComfyUI界面如下,具体查看[ComfyUI README](comfyui/README.md)。
49
+ ![workflow graph](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/cogvideoxfunv1_workflow_i2v.jpg)
50
+
51
+ #### c. 通过docker
52
+ 使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令:
53
+
54
+ ```
55
+ # pull image
56
+ docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
57
+
58
+ # enter image
59
+ docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
60
+
61
+ # clone code
62
+ git clone https://github.com/aigc-apps/CogVideoX-Fun.git
63
+
64
+ # enter CogVideoX-Fun's dir
65
+ cd CogVideoX-Fun
66
+
67
+ # download weights
68
+ mkdir models/Diffusion_Transformer
69
+ mkdir models/Personalized_Model
70
+
71
+ wget https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP.tar.gz -O models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP.tar.gz
72
+
73
+ cd models/Diffusion_Transformer/
74
+ tar -xvf CogVideoX-Fun-V1.1-2b-InP.tar.gz
75
+ cd ../../
76
+ ```
77
+
78
+ ### 2. 本地安装: 环境检查/下载/安装
79
+ #### a. 环境检查
80
+ 我们已验证CogVideoX-Fun可在以下环境中执行:
81
+
82
+ Windows 的详细信息:
83
+ - 操作系统 Windows 10
84
+ - python: python3.10 & python3.11
85
+ - pytorch: torch2.2.0
86
+ - CUDA: 11.8 & 12.1
87
+ - CUDNN: 8+
88
+ - GPU: Nvidia-3060 12G & Nvidia-3090 24G
89
+
90
+ Linux 的详细信息:
91
+ - 操作系统 Ubuntu 20.04, CentOS
92
+ - python: python3.10 & python3.11
93
+ - pytorch: torch2.2.0
94
+ - CUDA: 11.8 & 12.1
95
+ - CUDNN: 8+
96
+ - GPU:Nvidia-V100 16G & Nvidia-A10 24G & Nvidia-A100 40G & Nvidia-A100 80G
97
+
98
+ 我们需要大约 60GB 的可用磁盘空间,请检查!
99
+
100
+ #### b. 权重放置
101
+ 我们最好将[权重](#model-zoo)按照指定路径进行放置:
102
+
103
+ ```
104
+ 📦 models/
105
+ ├── 📂 Diffusion_Transformer/
106
+ │ ├── 📂 CogVideoX-Fun-V1.1-2b-InP/
107
+ │ └── 📂 CogVideoX-Fun-V1.1-5b-InP/
108
+ ├── 📂 Personalized_Model/
109
+ │ └── your trained trainformer model / your trained lora model (for UI load)
110
+ ```
111
+
112
+ # 视频作品
113
+ 所展示的结果都是图生视频获得。
114
+
115
+ ### CogVideoX-Fun-V1.1-5B
116
+
117
+ Resolution-1024
118
+
119
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
120
+ <tr>
121
+ <td>
122
+ <video src="https://github.com/user-attachments/assets/34e7ec8f-293e-4655-bb14-5e1ee476f788" width="100%" controls autoplay loop></video>
123
+ </td>
124
+ <td>
125
+ <video src="https://github.com/user-attachments/assets/7809c64f-eb8c-48a9-8bdc-ca9261fd5434" width="100%" controls autoplay loop></video>
126
+ </td>
127
+ <td>
128
+ <video src="https://github.com/user-attachments/assets/8e76aaa4-c602-44ac-bcb4-8b24b72c386c" width="100%" controls autoplay loop></video>
129
+ </td>
130
+ <td>
131
+ <video src="https://github.com/user-attachments/assets/19dba894-7c35-4f25-b15c-384167ab3b03" width="100%" controls autoplay loop></video>
132
+ </td>
133
+ </tr>
134
+ </table>
135
+
136
+
137
+ Resolution-768
138
+
139
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
140
+ <tr>
141
+ <td>
142
+ <video src="https://github.com/user-attachments/assets/0bc339b9-455b-44fd-8917-80272d702737" width="100%" controls autoplay loop></video>
143
+ </td>
144
+ <td>
145
+ <video src="https://github.com/user-attachments/assets/70a043b9-6721-4bd9-be47-78b7ec5c27e9" width="100%" controls autoplay loop></video>
146
+ </td>
147
+ <td>
148
+ <video src="https://github.com/user-attachments/assets/d5dd6c09-14f3-40f8-8b6d-91e26519b8ac" width="100%" controls autoplay loop></video>
149
+ </td>
150
+ <td>
151
+ <video src="https://github.com/user-attachments/assets/9327e8bc-4f17-46b0-b50d-38c250a9483a" width="100%" controls autoplay loop></video>
152
+ </td>
153
+ </tr>
154
+ </table>
155
+
156
+ Resolution-512
157
+
158
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
159
+ <tr>
160
+ <td>
161
+ <video src="https://github.com/user-attachments/assets/ef407030-8062-454d-aba3-131c21e6b58c" width="100%" controls autoplay loop></video>
162
+ </td>
163
+ <td>
164
+ <video src="https://github.com/user-attachments/assets/7610f49e-38b6-4214-aa48-723ae4d1b07e" width="100%" controls autoplay loop></video>
165
+ </td>
166
+ <td>
167
+ <video src="https://github.com/user-attachments/assets/1fff0567-1e15-415c-941e-53ee8ae2c841" width="100%" controls autoplay loop></video>
168
+ </td>
169
+ <td>
170
+ <video src="https://github.com/user-attachments/assets/bcec48da-b91b-43a0-9d50-cf026e00fa4f" width="100%" controls autoplay loop></video>
171
+ </td>
172
+ </tr>
173
+ </table>
174
+
175
+ ### CogVideoX-Fun-V1.1-5B-Pose
176
+
177
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
178
+ <tr>
179
+ <td>
180
+ Resolution-512
181
+ </td>
182
+ <td>
183
+ Resolution-768
184
+ </td>
185
+ <td>
186
+ Resolution-1024
187
+ </td>
188
+ <tr>
189
+ <td>
190
+ <video src="https://github.com/user-attachments/assets/a746df51-9eb7-4446-bee5-2ee30285c143" width="100%" controls autoplay loop></video>
191
+ </td>
192
+ <td>
193
+ <video src="https://github.com/user-attachments/assets/db295245-e6aa-43be-8c81-32cb411f1473" width="100%" controls autoplay loop></video>
194
+ </td>
195
+ <td>
196
+ <video src="https://github.com/user-attachments/assets/ec9875b2-fde0-48e1-ab7e-490cee51ef40" width="100%" controls autoplay loop></video>
197
+ </td>
198
+ </tr>
199
+ </table>
200
+
201
+ ### CogVideoX-Fun-V1.1-2B
202
+
203
+ Resolution-768
204
+
205
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
206
+ <tr>
207
+ <td>
208
+ <video src="https://github.com/user-attachments/assets/03235dea-980e-4fc5-9c41-e40a5bc1b6d0" width="100%" controls autoplay loop></video>
209
+ </td>
210
+ <td>
211
+ <video src="https://github.com/user-attachments/assets/f7302648-5017-47db-bdeb-4d893e620b37" width="100%" controls autoplay loop></video>
212
+ </td>
213
+ <td>
214
+ <video src="https://github.com/user-attachments/assets/cbadf411-28fa-4b87-813d-da63ff481904" width="100%" controls autoplay loop></video>
215
+ </td>
216
+ <td>
217
+ <video src="https://github.com/user-attachments/assets/87cc9d0b-b6fe-4d2d-b447-174513d169ab" width="100%" controls autoplay loop></video>
218
+ </td>
219
+ </tr>
220
+ </table>
221
+
222
+ ### CogVideoX-Fun-V1.1-2B-Pose
223
+
224
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
225
+ <tr>
226
+ <td>
227
+ Resolution-512
228
+ </td>
229
+ <td>
230
+ Resolution-768
231
+ </td>
232
+ <td>
233
+ Resolution-1024
234
+ </td>
235
+ <tr>
236
+ <td>
237
+ <video src="https://github.com/user-attachments/assets/487bcd7b-1b7f-4bb4-95b5-96a6b6548b3e" width="100%" controls autoplay loop></video>
238
+ </td>
239
+ <td>
240
+ <video src="https://github.com/user-attachments/assets/2710fd18-8489-46e4-8086-c237309ae7f6" width="100%" controls autoplay loop></video>
241
+ </td>
242
+ <td>
243
+ <video src="https://github.com/user-attachments/assets/b79513db-7747-4512-b86c-94f9ca447fe2" width="100%" controls autoplay loop></video>
244
+ </td>
245
+ </tr>
246
+ </table>
247
+
248
+ # 如何使用
249
+
250
+ <h3 id="video-gen">1. 生成 </h3>
251
+
252
+ #### a. 视频生成
253
+ ##### i、运行python文件
254
+ - 步骤1:下载对应[权重](#model-zoo)放入models文件夹。
255
+ - 步骤2:在predict_t2v.py文件中修改prompt、neg_prompt、guidance_scale和seed。
256
+ - 步骤3:运行predict_t2v.py文件,等待生成结果,结果保存在samples/cogvideox-fun-videos-t2v文件夹中。
257
+ - 步骤4:如果想结合自己训练的其他backbone与Lora,则看情况修改predict_t2v.py中的predict_t2v.py和lora_path。
258
+
259
+ ##### ii、通过ui界面
260
+ - 步骤1:下载对应[权重](#model-zoo)放入models文件夹。
261
+ - 步骤2:运行app.py文件,进入gradio页面。
262
+ - 步骤3:根据页面选择生成模型,填入prompt、neg_prompt、guidance_scale和seed等,点击生成,等待生成结果,结果保存在sample文件夹中。
263
+
264
+ ##### iii、通过comfyui
265
+ 具体查看[ComfyUI README](comfyui/README.md)。
266
+
267
+ ### 2. 模型训练
268
+ 一个完整的CogVideoX-Fun训练链路应该包括数据预处理和Video DiT训练。
269
+
270
+ <h4 id="data-preprocess">a.数据预处理</h4>
271
+ 我们给出了一个简单的demo通过图片数据训练lora模型,详情可以查看[wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora)。
272
+
273
+ 一个完整的长视频切分、清洗、描述的数据预处理链路可以参考video caption部分的[README](cogvideox/video_caption/README.md)进行。
274
+
275
+ 如果期望训练一个文生图视频的生成模型,您需要以这种格式排列数据集。
276
+ ```
277
+ 📦 project/
278
+ ├── 📂 datasets/
279
+ │ ├── 📂 internal_datasets/
280
+ │ ├── 📂 train/
281
+ │ │ ├── 📄 00000001.mp4
282
+ │ │ ├── 📄 00000002.jpg
283
+ │ │ └── 📄 .....
284
+ │ └── 📄 json_of_internal_datasets.json
285
+ ```
286
+
287
+ json_of_internal_datasets.json是一个标准的json文件。json中的file_path可以被设置为相对路径,如下所示:
288
+ ```json
289
+ [
290
+ {
291
+ "file_path": "train/00000001.mp4",
292
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
293
+ "type": "video"
294
+ },
295
+ {
296
+ "file_path": "train/00000002.jpg",
297
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
298
+ "type": "image"
299
+ },
300
+ .....
301
+ ]
302
+ ```
303
+
304
+ 你也可以将路径设置为绝对路径:
305
+ ```json
306
+ [
307
+ {
308
+ "file_path": "/mnt/data/videos/00000001.mp4",
309
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
310
+ "type": "video"
311
+ },
312
+ {
313
+ "file_path": "/mnt/data/train/00000001.jpg",
314
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
315
+ "type": "image"
316
+ },
317
+ .....
318
+ ]
319
+ ```
320
+ <h4 id="dit-train">b. Video DiT训练 </h4>
321
+
322
+ 如果数据预处理时,数据的格式为相对路径,则进入scripts/train.sh进行如下设置。
323
+ ```
324
+ export DATASET_NAME="datasets/internal_datasets/"
325
+ export DATASET_META_NAME="datasets/internal_datasets/json_of_internal_datasets.json"
326
+
327
+ ...
328
+
329
+ train_data_format="normal"
330
+ ```
331
+
332
+ 如果数据的格式为绝对路径,则进入scripts/train.sh进行如下设置。
333
+ ```
334
+ export DATASET_NAME=""
335
+ export DATASET_META_NAME="/mnt/data/json_of_internal_datasets.json"
336
+ ```
337
+
338
+ 最后运行scripts/train.sh。
339
+ ```sh
340
+ sh scripts/train.sh
341
+ ```
342
+
343
+ 关于一些参数的设置细节,可以查看[Readme Train](scripts/README_TRAIN.md)与[Readme Lora](scripts/README_TRAIN_LORA.md)
344
+
345
+ # 模型地址
346
+
347
+ V1.1:
348
+
349
+ | 名称 | 存储空间 | Hugging Face | Model Scope | 描述 |
350
+ |--|--|--|--|--|
351
+ | CogVideoX-Fun-V1.1-2b-InP.tar.gz | 解压前 9.7 GB / 解压后 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-2b-InP) | 官方的图生视频权重。添加了Noise,运动幅度相比于V1.0更大。支持多分辨率(512,768,1024,1280)的视频预测,以49帧、每秒8帧进行训练 |
352
+ | CogVideoX-Fun-V1.1-5b-InP.tar.gz | 解压前 16.0GB / 解压后 20.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-5b-InP) | 官方的图生视频权重。添加了Noise,运动幅度相比于V1.0更大。支持多分辨率(512,768,1024,1280)的视频预测,以49帧、每秒8帧进行训练 |
353
+ | CogVideoX-Fun-V1.1-2b-Pose.tar.gz | 解压前 9.7 GB / 解压后 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-2b-Pose) | 官方的姿态控制生视频权重。支持多分辨率(512,768,1024,1280)的视频预测,以49帧、每秒8帧进行训练 |
354
+ | CogVideoX-Fun-V1.1-5b-Pose.tar.gz | 解压前 16.0GB / 解压后 20.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-5b-Pose) | 官方的姿态控制生视频权重。支持多分辨率(512,768,1024,1280)的视频预测,以49帧、每秒8帧进行训练 |
355
+
356
+ V1.0:
357
+
358
+ | 名称 | 存储空间 | Hugging Face | Model Scope | 描述 |
359
+ |--|--|--|--|--|
360
+ | CogVideoX-Fun-2b-InP.tar.gz | 解压前 9.7 GB / 解压后 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-2b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-2b-InP) | 官方的图生视频权重。支持多分辨率(512,768,1024,1280)的视频预测,以49帧、每秒8帧进行训练 |
361
+ | CogVideoX-Fun-5b-InP.tar.gz | 解压前 16.0GB / 解压后 20.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-5b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-5b-InP) | 官方的图生视频权重。支持多分辨率(512,768,1024,1280)的视频预测,以49帧、每秒8帧进行训练 |
362
+
363
+ # 未来计划
364
+ - 支持中文。
365
+
366
+ # 参考文献
367
+ - CogVideo: https://github.com/THUDM/CogVideo/
368
+ - EasyAnimate: https://github.com/aigc-apps/EasyAnimate
369
+
370
+ # 许可证
371
+ 本项目采用 [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
372
+
373
+ CogVideoX-2B 模型 (包括其对应的Transformers模块,VAE模块) 根据 [Apache 2.0 协议](LICENSE) 许可证发布。
374
+
375
+ CogVideoX-5B 模型(Transformer 模块)在[CogVideoX许可证](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE)下发布.
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .comfyui.comfyui_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2
+
3
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+
4
+ from cogvideox.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api
5
+ from cogvideox.ui.ui import ui_modelscope, ui_eas, ui
6
+
7
+ if __name__ == "__main__":
8
+ # Choose the ui mode
9
+ ui_mode = "normal"
10
+
11
+ # Low gpu memory mode, this is used when the GPU memory is under 16GB
12
+ low_gpu_memory_mode = False
13
+ # Use torch.float16 if GPU does not support torch.bfloat16
14
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
15
+ weight_dtype = torch.bfloat16
16
+
17
+ # Server ip
18
+ server_name = "0.0.0.0"
19
+ server_port = 7860
20
+
21
+ # Params below is used when ui_mode = "modelscope"
22
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
23
+ # "Inpaint" or "Control"
24
+ model_type = "Inpaint"
25
+ # Save dir of this model
26
+ savedir_sample = "samples"
27
+
28
+ if ui_mode == "modelscope":
29
+ demo, controller = ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
30
+ elif ui_mode == "eas":
31
+ demo, controller = ui_eas(model_name, savedir_sample)
32
+ else:
33
+ demo, controller = ui(low_gpu_memory_mode, weight_dtype)
34
+
35
+ # launch gradio
36
+ app, _, _ = demo.queue(status_update_rate=1).launch(
37
+ server_name=server_name,
38
+ server_port=server_port,
39
+ prevent_thread_lock=True
40
+ )
41
+
42
+ # launch api
43
+ infer_forward_api(None, app, controller)
44
+ update_diffusion_transformer_api(None, app, controller)
45
+ update_edition_api(None, app, controller)
46
+
47
+ # not close the python
48
+ while True:
49
+ time.sleep(5)
cogvideox/__init__.py ADDED
File without changes
cogvideox/api/api.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import gc
3
+ import base64
4
+ import torch
5
+ import gradio as gr
6
+ import tempfile
7
+ import hashlib
8
+ import os
9
+
10
+ from fastapi import FastAPI
11
+ from io import BytesIO
12
+ from PIL import Image
13
+
14
+ # Function to encode a file to Base64
15
+ def encode_file_to_base64(file_path):
16
+ with open(file_path, "rb") as file:
17
+ # Encode the data to Base64
18
+ file_base64 = base64.b64encode(file.read())
19
+ return file_base64
20
+
21
+ def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
22
+ @app.post("/cogvideox_fun/update_edition")
23
+ def _update_edition_api(
24
+ datas: dict,
25
+ ):
26
+ edition = datas.get('edition', 'v2')
27
+
28
+ try:
29
+ controller.update_edition(
30
+ edition
31
+ )
32
+ comment = "Success"
33
+ except Exception as e:
34
+ torch.cuda.empty_cache()
35
+ comment = f"Error. error information is {str(e)}"
36
+
37
+ return {"message": comment}
38
+
39
+ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
40
+ @app.post("/cogvideox_fun/update_diffusion_transformer")
41
+ def _update_diffusion_transformer_api(
42
+ datas: dict,
43
+ ):
44
+ diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
45
+
46
+ try:
47
+ controller.update_diffusion_transformer(
48
+ diffusion_transformer_path
49
+ )
50
+ comment = "Success"
51
+ except Exception as e:
52
+ torch.cuda.empty_cache()
53
+ comment = f"Error. error information is {str(e)}"
54
+
55
+ return {"message": comment}
56
+
57
+ def save_base64_video(base64_string):
58
+ video_data = base64.b64decode(base64_string)
59
+
60
+ md5_hash = hashlib.md5(video_data).hexdigest()
61
+ filename = f"{md5_hash}.mp4"
62
+
63
+ temp_dir = tempfile.gettempdir()
64
+ file_path = os.path.join(temp_dir, filename)
65
+
66
+ with open(file_path, 'wb') as video_file:
67
+ video_file.write(video_data)
68
+
69
+ return file_path
70
+
71
+ def save_base64_image(base64_string):
72
+ video_data = base64.b64decode(base64_string)
73
+
74
+ md5_hash = hashlib.md5(video_data).hexdigest()
75
+ filename = f"{md5_hash}.jpg"
76
+
77
+ temp_dir = tempfile.gettempdir()
78
+ file_path = os.path.join(temp_dir, filename)
79
+
80
+ with open(file_path, 'wb') as video_file:
81
+ video_file.write(video_data)
82
+
83
+ return file_path
84
+
85
+ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
86
+ @app.post("/cogvideox_fun/infer_forward")
87
+ def _infer_forward_api(
88
+ datas: dict,
89
+ ):
90
+ base_model_path = datas.get('base_model_path', 'none')
91
+ lora_model_path = datas.get('lora_model_path', 'none')
92
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
93
+ prompt_textbox = datas.get('prompt_textbox', None)
94
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
95
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
96
+ sample_step_slider = datas.get('sample_step_slider', 30)
97
+ resize_method = datas.get('resize_method', "Generate by")
98
+ width_slider = datas.get('width_slider', 672)
99
+ height_slider = datas.get('height_slider', 384)
100
+ base_resolution = datas.get('base_resolution', 512)
101
+ is_image = datas.get('is_image', False)
102
+ generation_method = datas.get('generation_method', False)
103
+ length_slider = datas.get('length_slider', 49)
104
+ overlap_video_length = datas.get('overlap_video_length', 4)
105
+ partial_video_length = datas.get('partial_video_length', 72)
106
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
107
+ start_image = datas.get('start_image', None)
108
+ end_image = datas.get('end_image', None)
109
+ validation_video = datas.get('validation_video', None)
110
+ validation_video_mask = datas.get('validation_video_mask', None)
111
+ control_video = datas.get('control_video', None)
112
+ denoise_strength = datas.get('denoise_strength', 0.70)
113
+ seed_textbox = datas.get("seed_textbox", 43)
114
+
115
+ generation_method = "Image Generation" if is_image else generation_method
116
+
117
+ if start_image is not None:
118
+ start_image = base64.b64decode(start_image)
119
+ start_image = [Image.open(BytesIO(start_image))]
120
+
121
+ if end_image is not None:
122
+ end_image = base64.b64decode(end_image)
123
+ end_image = [Image.open(BytesIO(end_image))]
124
+
125
+ if validation_video is not None:
126
+ validation_video = save_base64_video(validation_video)
127
+
128
+ if validation_video_mask is not None:
129
+ validation_video_mask = save_base64_image(validation_video_mask)
130
+
131
+ if control_video is not None:
132
+ control_video = save_base64_video(control_video)
133
+
134
+ try:
135
+ save_sample_path, comment = controller.generate(
136
+ "",
137
+ base_model_path,
138
+ lora_model_path,
139
+ lora_alpha_slider,
140
+ prompt_textbox,
141
+ negative_prompt_textbox,
142
+ sampler_dropdown,
143
+ sample_step_slider,
144
+ resize_method,
145
+ width_slider,
146
+ height_slider,
147
+ base_resolution,
148
+ generation_method,
149
+ length_slider,
150
+ overlap_video_length,
151
+ partial_video_length,
152
+ cfg_scale_slider,
153
+ start_image,
154
+ end_image,
155
+ validation_video,
156
+ validation_video_mask,
157
+ control_video,
158
+ denoise_strength,
159
+ seed_textbox,
160
+ is_api = True,
161
+ )
162
+ except Exception as e:
163
+ gc.collect()
164
+ torch.cuda.empty_cache()
165
+ torch.cuda.ipc_collect()
166
+ save_sample_path = ""
167
+ comment = f"Error. error information is {str(e)}"
168
+ return {"message": comment}
169
+
170
+ if save_sample_path != "":
171
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
172
+ else:
173
+ return {"message": comment, "save_sample_path": save_sample_path}
cogvideox/api/post_infer.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import sys
4
+ import time
5
+ from datetime import datetime
6
+ from io import BytesIO
7
+
8
+ import cv2
9
+ import requests
10
+ import base64
11
+
12
+
13
+ def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
14
+ datas = json.dumps({
15
+ "diffusion_transformer_path": diffusion_transformer_path
16
+ })
17
+ r = requests.post(f'{url}/cogvideox_fun/update_diffusion_transformer', data=datas, timeout=1500)
18
+ data = r.content.decode('utf-8')
19
+ return data
20
+
21
+ def post_update_edition(edition, url='http://0.0.0.0:7860'):
22
+ datas = json.dumps({
23
+ "edition": edition
24
+ })
25
+ r = requests.post(f'{url}/cogvideox_fun/update_edition', data=datas, timeout=1500)
26
+ data = r.content.decode('utf-8')
27
+ return data
28
+
29
+ def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
30
+ datas = json.dumps({
31
+ "base_model_path": "none",
32
+ "motion_module_path": "none",
33
+ "lora_model_path": "none",
34
+ "lora_alpha_slider": 0.55,
35
+ "prompt_textbox": "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
36
+ "negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ",
37
+ "sampler_dropdown": "Euler",
38
+ "sample_step_slider": 50,
39
+ "width_slider": 672,
40
+ "height_slider": 384,
41
+ "generation_method": "Video Generation",
42
+ "length_slider": length_slider,
43
+ "cfg_scale_slider": 6,
44
+ "seed_textbox": 43,
45
+ })
46
+ r = requests.post(f'{url}/cogvideox_fun/infer_forward', data=datas, timeout=1500)
47
+ data = r.content.decode('utf-8')
48
+ return data
49
+
50
+ if __name__ == '__main__':
51
+ # initiate time
52
+ now_date = datetime.now()
53
+ time_start = time.time()
54
+
55
+ # -------------------------- #
56
+ # Step 1: update edition
57
+ # -------------------------- #
58
+ diffusion_transformer_path = "models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
59
+ outputs = post_diffusion_transformer(diffusion_transformer_path)
60
+ print('Output update edition: ', outputs)
61
+
62
+ # -------------------------- #
63
+ # Step 2: infer
64
+ # -------------------------- #
65
+ # "Video Generation" and "Image Generation"
66
+ generation_method = "Video Generation"
67
+ length_slider = 49
68
+ outputs = post_infer(generation_method, length_slider)
69
+
70
+ # Get decoded data
71
+ outputs = json.loads(outputs)
72
+ base64_encoding = outputs["base64_encoding"]
73
+ decoded_data = base64.b64decode(base64_encoding)
74
+
75
+ is_image = True if generation_method == "Image Generation" else False
76
+ if is_image or length_slider == 1:
77
+ file_path = "1.png"
78
+ else:
79
+ file_path = "1.mp4"
80
+ with open(file_path, "wb") as file:
81
+ file.write(decoded_data)
82
+
83
+ # End of record time
84
+ # The calculated time difference is the execution time of the program, expressed in seconds / s
85
+ time_end = time.time()
86
+ time_sum = (time_end - time_start) % 60
87
+ print('# --------------------------------------------------------- #')
88
+ print(f'# Total expenditure: {time_sum}s')
89
+ print('# --------------------------------------------------------- #')
cogvideox/data/bucket_sampler.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
+ ASPECT_RATIO_512 = {
13
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
14
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
15
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
16
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
17
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
18
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
19
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
20
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
21
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
22
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
23
+ }
24
+ ASPECT_RATIO_RANDOM_CROP_512 = {
25
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
26
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
27
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
28
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
29
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
30
+ }
31
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
32
+ 1, 2,
33
+ 4, 4, 4, 4,
34
+ 8, 8, 8,
35
+ 4, 4, 4, 4,
36
+ 2, 1
37
+ ]
38
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
+
40
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
41
+ aspect_ratio = height / width
42
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
43
+ return ratios[closest_ratio], float(closest_ratio)
44
+
45
+ def get_image_size_without_loading(path):
46
+ with Image.open(path) as img:
47
+ return img.size # (width, height)
48
+
49
+ class RandomSampler(Sampler[int]):
50
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
51
+
52
+ If with replacement, then user can specify :attr:`num_samples` to draw.
53
+
54
+ Args:
55
+ data_source (Dataset): dataset to sample from
56
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
57
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
58
+ generator (Generator): Generator used in sampling.
59
+ """
60
+
61
+ data_source: Sized
62
+ replacement: bool
63
+
64
+ def __init__(self, data_source: Sized, replacement: bool = False,
65
+ num_samples: Optional[int] = None, generator=None) -> None:
66
+ self.data_source = data_source
67
+ self.replacement = replacement
68
+ self._num_samples = num_samples
69
+ self.generator = generator
70
+ self._pos_start = 0
71
+
72
+ if not isinstance(self.replacement, bool):
73
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
74
+
75
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
76
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
77
+
78
+ @property
79
+ def num_samples(self) -> int:
80
+ # dataset size might change at runtime
81
+ if self._num_samples is None:
82
+ return len(self.data_source)
83
+ return self._num_samples
84
+
85
+ def __iter__(self) -> Iterator[int]:
86
+ n = len(self.data_source)
87
+ if self.generator is None:
88
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
89
+ generator = torch.Generator()
90
+ generator.manual_seed(seed)
91
+ else:
92
+ generator = self.generator
93
+
94
+ if self.replacement:
95
+ for _ in range(self.num_samples // 32):
96
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
97
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
98
+ else:
99
+ for _ in range(self.num_samples // n):
100
+ xx = torch.randperm(n, generator=generator).tolist()
101
+ if self._pos_start >= n:
102
+ self._pos_start = 0
103
+ print("xx top 10", xx[:10], self._pos_start)
104
+ for idx in range(self._pos_start, n):
105
+ yield xx[idx]
106
+ self._pos_start = (self._pos_start + 1) % n
107
+ self._pos_start = 0
108
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
109
+
110
+ def __len__(self) -> int:
111
+ return self.num_samples
112
+
113
+ class AspectRatioBatchImageSampler(BatchSampler):
114
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
115
+
116
+ Args:
117
+ sampler (Sampler): Base sampler.
118
+ dataset (Dataset): Dataset providing data information.
119
+ batch_size (int): Size of mini-batch.
120
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
121
+ its size would be less than ``batch_size``.
122
+ aspect_ratios (dict): The predefined aspect ratios.
123
+ """
124
+ def __init__(
125
+ self,
126
+ sampler: Sampler,
127
+ dataset: Dataset,
128
+ batch_size: int,
129
+ train_folder: str = None,
130
+ aspect_ratios: dict = ASPECT_RATIO_512,
131
+ drop_last: bool = False,
132
+ config=None,
133
+ **kwargs
134
+ ) -> None:
135
+ if not isinstance(sampler, Sampler):
136
+ raise TypeError('sampler should be an instance of ``Sampler``, '
137
+ f'but got {sampler}')
138
+ if not isinstance(batch_size, int) or batch_size <= 0:
139
+ raise ValueError('batch_size should be a positive integer value, '
140
+ f'but got batch_size={batch_size}')
141
+ self.sampler = sampler
142
+ self.dataset = dataset
143
+ self.train_folder = train_folder
144
+ self.batch_size = batch_size
145
+ self.aspect_ratios = aspect_ratios
146
+ self.drop_last = drop_last
147
+ self.config = config
148
+ # buckets for each aspect ratio
149
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
150
+ # [str(k) for k, v in aspect_ratios]
151
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
152
+
153
+ def __iter__(self):
154
+ for idx in self.sampler:
155
+ try:
156
+ image_dict = self.dataset[idx]
157
+
158
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
159
+ if width is None or height is None:
160
+ image_id, name = image_dict['file_path'], image_dict['text']
161
+ if self.train_folder is None:
162
+ image_dir = image_id
163
+ else:
164
+ image_dir = os.path.join(self.train_folder, image_id)
165
+
166
+ width, height = get_image_size_without_loading(image_dir)
167
+
168
+ ratio = height / width # self.dataset[idx]
169
+ else:
170
+ height = int(height)
171
+ width = int(width)
172
+ ratio = height / width # self.dataset[idx]
173
+ except Exception as e:
174
+ print(e)
175
+ continue
176
+ # find the closest aspect ratio
177
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
178
+ if closest_ratio not in self.current_available_bucket_keys:
179
+ continue
180
+ bucket = self._aspect_ratio_buckets[closest_ratio]
181
+ bucket.append(idx)
182
+ # yield a batch of indices in the same aspect ratio group
183
+ if len(bucket) == self.batch_size:
184
+ yield bucket[:]
185
+ del bucket[:]
186
+
187
+ class AspectRatioBatchSampler(BatchSampler):
188
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
189
+
190
+ Args:
191
+ sampler (Sampler): Base sampler.
192
+ dataset (Dataset): Dataset providing data information.
193
+ batch_size (int): Size of mini-batch.
194
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
195
+ its size would be less than ``batch_size``.
196
+ aspect_ratios (dict): The predefined aspect ratios.
197
+ """
198
+ def __init__(
199
+ self,
200
+ sampler: Sampler,
201
+ dataset: Dataset,
202
+ batch_size: int,
203
+ video_folder: str = None,
204
+ train_data_format: str = "webvid",
205
+ aspect_ratios: dict = ASPECT_RATIO_512,
206
+ drop_last: bool = False,
207
+ config=None,
208
+ **kwargs
209
+ ) -> None:
210
+ if not isinstance(sampler, Sampler):
211
+ raise TypeError('sampler should be an instance of ``Sampler``, '
212
+ f'but got {sampler}')
213
+ if not isinstance(batch_size, int) or batch_size <= 0:
214
+ raise ValueError('batch_size should be a positive integer value, '
215
+ f'but got batch_size={batch_size}')
216
+ self.sampler = sampler
217
+ self.dataset = dataset
218
+ self.video_folder = video_folder
219
+ self.train_data_format = train_data_format
220
+ self.batch_size = batch_size
221
+ self.aspect_ratios = aspect_ratios
222
+ self.drop_last = drop_last
223
+ self.config = config
224
+ # buckets for each aspect ratio
225
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
226
+ # [str(k) for k, v in aspect_ratios]
227
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
228
+
229
+ def __iter__(self):
230
+ for idx in self.sampler:
231
+ try:
232
+ video_dict = self.dataset[idx]
233
+ width, more = video_dict.get("width", None), video_dict.get("height", None)
234
+
235
+ if width is None or height is None:
236
+ if self.train_data_format == "normal":
237
+ video_id, name = video_dict['file_path'], video_dict['text']
238
+ if self.video_folder is None:
239
+ video_dir = video_id
240
+ else:
241
+ video_dir = os.path.join(self.video_folder, video_id)
242
+ else:
243
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
244
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
245
+ cap = cv2.VideoCapture(video_dir)
246
+
247
+ # 获取视频尺寸
248
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
249
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
250
+
251
+ ratio = height / width # self.dataset[idx]
252
+ else:
253
+ height = int(height)
254
+ width = int(width)
255
+ ratio = height / width # self.dataset[idx]
256
+ except Exception as e:
257
+ print(e)
258
+ continue
259
+ # find the closest aspect ratio
260
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
261
+ if closest_ratio not in self.current_available_bucket_keys:
262
+ continue
263
+ bucket = self._aspect_ratio_buckets[closest_ratio]
264
+ bucket.append(idx)
265
+ # yield a batch of indices in the same aspect ratio group
266
+ if len(bucket) == self.batch_size:
267
+ yield bucket[:]
268
+ del bucket[:]
269
+
270
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
271
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
272
+
273
+ Args:
274
+ sampler (Sampler): Base sampler.
275
+ dataset (Dataset): Dataset providing data information.
276
+ batch_size (int): Size of mini-batch.
277
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
278
+ its size would be less than ``batch_size``.
279
+ aspect_ratios (dict): The predefined aspect ratios.
280
+ """
281
+
282
+ def __init__(self,
283
+ sampler: Sampler,
284
+ dataset: Dataset,
285
+ batch_size: int,
286
+ train_folder: str = None,
287
+ aspect_ratios: dict = ASPECT_RATIO_512,
288
+ drop_last: bool = False
289
+ ) -> None:
290
+ if not isinstance(sampler, Sampler):
291
+ raise TypeError('sampler should be an instance of ``Sampler``, '
292
+ f'but got {sampler}')
293
+ if not isinstance(batch_size, int) or batch_size <= 0:
294
+ raise ValueError('batch_size should be a positive integer value, '
295
+ f'but got batch_size={batch_size}')
296
+ self.sampler = sampler
297
+ self.dataset = dataset
298
+ self.train_folder = train_folder
299
+ self.batch_size = batch_size
300
+ self.aspect_ratios = aspect_ratios
301
+ self.drop_last = drop_last
302
+
303
+ # buckets for each aspect ratio
304
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
305
+ self.bucket = {
306
+ 'image':{ratio: [] for ratio in aspect_ratios},
307
+ 'video':{ratio: [] for ratio in aspect_ratios}
308
+ }
309
+
310
+ def __iter__(self):
311
+ for idx in self.sampler:
312
+ content_type = self.dataset[idx].get('type', 'image')
313
+ if content_type == 'image':
314
+ try:
315
+ image_dict = self.dataset[idx]
316
+
317
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
318
+ if width is None or height is None:
319
+ image_id, name = image_dict['file_path'], image_dict['text']
320
+ if self.train_folder is None:
321
+ image_dir = image_id
322
+ else:
323
+ image_dir = os.path.join(self.train_folder, image_id)
324
+
325
+ width, height = get_image_size_without_loading(image_dir)
326
+
327
+ ratio = height / width # self.dataset[idx]
328
+ else:
329
+ height = int(height)
330
+ width = int(width)
331
+ ratio = height / width # self.dataset[idx]
332
+ except Exception as e:
333
+ print(e)
334
+ continue
335
+ # find the closest aspect ratio
336
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
337
+ if closest_ratio not in self.current_available_bucket_keys:
338
+ continue
339
+ bucket = self.bucket['image'][closest_ratio]
340
+ bucket.append(idx)
341
+ # yield a batch of indices in the same aspect ratio group
342
+ if len(bucket) == self.batch_size:
343
+ yield bucket[:]
344
+ del bucket[:]
345
+ else:
346
+ try:
347
+ video_dict = self.dataset[idx]
348
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
349
+
350
+ if width is None or height is None:
351
+ video_id, name = video_dict['file_path'], video_dict['text']
352
+ if self.train_folder is None:
353
+ video_dir = video_id
354
+ else:
355
+ video_dir = os.path.join(self.train_folder, video_id)
356
+ cap = cv2.VideoCapture(video_dir)
357
+
358
+ # 获取视频尺寸
359
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
360
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
361
+
362
+ ratio = height / width # self.dataset[idx]
363
+ else:
364
+ height = int(height)
365
+ width = int(width)
366
+ ratio = height / width # self.dataset[idx]
367
+ except Exception as e:
368
+ print(e)
369
+ continue
370
+ # find the closest aspect ratio
371
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
372
+ if closest_ratio not in self.current_available_bucket_keys:
373
+ continue
374
+ bucket = self.bucket['video'][closest_ratio]
375
+ bucket.append(idx)
376
+ # yield a batch of indices in the same aspect ratio group
377
+ if len(bucket) == self.batch_size:
378
+ yield bucket[:]
379
+ del bucket[:]
cogvideox/data/dataset_image.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
+ class CC15M(Dataset):
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ video_folder=None,
17
+ resolution=512,
18
+ enable_bucket=False,
19
+ ):
20
+ print(f"loading annotations from {json_path} ...")
21
+ self.dataset = json.load(open(json_path, 'r'))
22
+ self.length = len(self.dataset)
23
+ print(f"data scale: {self.length}")
24
+
25
+ self.enable_bucket = enable_bucket
26
+ self.video_folder = video_folder
27
+
28
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
+ self.pixel_transforms = transforms.Compose([
30
+ transforms.Resize(resolution[0]),
31
+ transforms.CenterCrop(resolution),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
+ ])
35
+
36
+ def get_batch(self, idx):
37
+ video_dict = self.dataset[idx]
38
+ video_id, name = video_dict['file_path'], video_dict['text']
39
+
40
+ if self.video_folder is None:
41
+ video_dir = video_id
42
+ else:
43
+ video_dir = os.path.join(self.video_folder, video_id)
44
+
45
+ pixel_values = Image.open(video_dir).convert("RGB")
46
+ return pixel_values, name
47
+
48
+ def __len__(self):
49
+ return self.length
50
+
51
+ def __getitem__(self, idx):
52
+ while True:
53
+ try:
54
+ pixel_values, name = self.get_batch(idx)
55
+ break
56
+ except Exception as e:
57
+ print(e)
58
+ idx = random.randint(0, self.length-1)
59
+
60
+ if not self.enable_bucket:
61
+ pixel_values = self.pixel_transforms(pixel_values)
62
+ else:
63
+ pixel_values = np.array(pixel_values)
64
+
65
+ sample = dict(pixel_values=pixel_values, text=name)
66
+ return sample
67
+
68
+ if __name__ == "__main__":
69
+ dataset = CC15M(
70
+ csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
71
+ resolution=512,
72
+ )
73
+
74
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
75
+ for idx, batch in enumerate(dataloader):
76
+ print(batch["pixel_values"].shape, len(batch["text"]))
cogvideox/data/dataset_image_video.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ from threading import Thread
8
+
9
+ import albumentations
10
+ import cv2
11
+ import gc
12
+ import numpy as np
13
+ import torch
14
+ import torchvision.transforms as transforms
15
+
16
+ from func_timeout import func_timeout, FunctionTimedOut
17
+ from decord import VideoReader
18
+ from PIL import Image
19
+ from torch.utils.data import BatchSampler, Sampler
20
+ from torch.utils.data.dataset import Dataset
21
+ from contextlib import contextmanager
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ if f != 1:
29
+ mask_index = np.random.choice([0, 1, 2, 3, 4], p = [0.05, 0.3, 0.3, 0.3, 0.05]) # np.random.randint(0, 5)
30
+ else:
31
+ mask_index = np.random.choice([0, 1], p = [0.2, 0.8]) # np.random.randint(0, 2)
32
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
33
+
34
+ if mask_index == 0:
35
+ center_x = torch.randint(0, w, (1,)).item()
36
+ center_y = torch.randint(0, h, (1,)).item()
37
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
38
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
39
+
40
+ start_x = max(center_x - block_size_x // 2, 0)
41
+ end_x = min(center_x + block_size_x // 2, w)
42
+ start_y = max(center_y - block_size_y // 2, 0)
43
+ end_y = min(center_y + block_size_y // 2, h)
44
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
45
+ elif mask_index == 1:
46
+ mask[:, :, :, :] = 1
47
+ elif mask_index == 2:
48
+ mask_frame_index = np.random.randint(1, 5)
49
+ mask[mask_frame_index:, :, :, :] = 1
50
+ elif mask_index == 3:
51
+ mask_frame_index = np.random.randint(1, 5)
52
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
53
+ elif mask_index == 4:
54
+ center_x = torch.randint(0, w, (1,)).item()
55
+ center_y = torch.randint(0, h, (1,)).item()
56
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
57
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
58
+
59
+ start_x = max(center_x - block_size_x // 2, 0)
60
+ end_x = min(center_x + block_size_x // 2, w)
61
+ start_y = max(center_y - block_size_y // 2, 0)
62
+ end_y = min(center_y + block_size_y // 2, h)
63
+
64
+ mask_frame_before = np.random.randint(0, f // 2)
65
+ mask_frame_after = np.random.randint(f // 2, f)
66
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
67
+ else:
68
+ raise ValueError(f"The mask_index {mask_index} is not define")
69
+ return mask
70
+
71
+ class ImageVideoSampler(BatchSampler):
72
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
73
+
74
+ Args:
75
+ sampler (Sampler): Base sampler.
76
+ dataset (Dataset): Dataset providing data information.
77
+ batch_size (int): Size of mini-batch.
78
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
79
+ its size would be less than ``batch_size``.
80
+ aspect_ratios (dict): The predefined aspect ratios.
81
+ """
82
+
83
+ def __init__(self,
84
+ sampler: Sampler,
85
+ dataset: Dataset,
86
+ batch_size: int,
87
+ drop_last: bool = False
88
+ ) -> None:
89
+ if not isinstance(sampler, Sampler):
90
+ raise TypeError('sampler should be an instance of ``Sampler``, '
91
+ f'but got {sampler}')
92
+ if not isinstance(batch_size, int) or batch_size <= 0:
93
+ raise ValueError('batch_size should be a positive integer value, '
94
+ f'but got batch_size={batch_size}')
95
+ self.sampler = sampler
96
+ self.dataset = dataset
97
+ self.batch_size = batch_size
98
+ self.drop_last = drop_last
99
+
100
+ # buckets for each aspect ratio
101
+ self.bucket = {'image':[], 'video':[]}
102
+
103
+ def __iter__(self):
104
+ for idx in self.sampler:
105
+ content_type = self.dataset.dataset[idx].get('type', 'image')
106
+ self.bucket[content_type].append(idx)
107
+
108
+ # yield a batch of indices in the same aspect ratio group
109
+ if len(self.bucket['video']) == self.batch_size:
110
+ bucket = self.bucket['video']
111
+ yield bucket[:]
112
+ del bucket[:]
113
+ elif len(self.bucket['image']) == self.batch_size:
114
+ bucket = self.bucket['image']
115
+ yield bucket[:]
116
+ del bucket[:]
117
+
118
+ @contextmanager
119
+ def VideoReader_contextmanager(*args, **kwargs):
120
+ vr = VideoReader(*args, **kwargs)
121
+ try:
122
+ yield vr
123
+ finally:
124
+ del vr
125
+ gc.collect()
126
+
127
+ def get_video_reader_batch(video_reader, batch_index):
128
+ frames = video_reader.get_batch(batch_index).asnumpy()
129
+ return frames
130
+
131
+ def resize_frame(frame, target_short_side):
132
+ h, w, _ = frame.shape
133
+ if h < w:
134
+ if target_short_side > h:
135
+ return frame
136
+ new_h = target_short_side
137
+ new_w = int(target_short_side * w / h)
138
+ else:
139
+ if target_short_side > w:
140
+ return frame
141
+ new_w = target_short_side
142
+ new_h = int(target_short_side * h / w)
143
+
144
+ resized_frame = cv2.resize(frame, (new_w, new_h))
145
+ return resized_frame
146
+
147
+ class ImageVideoDataset(Dataset):
148
+ def __init__(
149
+ self,
150
+ ann_path, data_root=None,
151
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
152
+ image_sample_size=512,
153
+ video_repeat=0,
154
+ text_drop_ratio=-1,
155
+ enable_bucket=False,
156
+ video_length_drop_start=0.1,
157
+ video_length_drop_end=0.9,
158
+ enable_inpaint=False,
159
+ ):
160
+ # Loading annotations from files
161
+ print(f"loading annotations from {ann_path} ...")
162
+ if ann_path.endswith('.csv'):
163
+ with open(ann_path, 'r') as csvfile:
164
+ dataset = list(csv.DictReader(csvfile))
165
+ elif ann_path.endswith('.json'):
166
+ dataset = json.load(open(ann_path))
167
+
168
+ self.data_root = data_root
169
+
170
+ # It's used to balance num of images and videos.
171
+ self.dataset = []
172
+ for data in dataset:
173
+ if data.get('type', 'image') != 'video':
174
+ self.dataset.append(data)
175
+ if video_repeat > 0:
176
+ for _ in range(video_repeat):
177
+ for data in dataset:
178
+ if data.get('type', 'image') == 'video':
179
+ self.dataset.append(data)
180
+ del dataset
181
+
182
+ self.length = len(self.dataset)
183
+ print(f"data scale: {self.length}")
184
+ # TODO: enable bucket training
185
+ self.enable_bucket = enable_bucket
186
+ self.text_drop_ratio = text_drop_ratio
187
+ self.enable_inpaint = enable_inpaint
188
+
189
+ self.video_length_drop_start = video_length_drop_start
190
+ self.video_length_drop_end = video_length_drop_end
191
+
192
+ # Video params
193
+ self.video_sample_stride = video_sample_stride
194
+ self.video_sample_n_frames = video_sample_n_frames
195
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
196
+ self.video_transforms = transforms.Compose(
197
+ [
198
+ transforms.Resize(min(self.video_sample_size)),
199
+ transforms.CenterCrop(self.video_sample_size),
200
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
201
+ ]
202
+ )
203
+
204
+ # Image params
205
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
206
+ self.image_transforms = transforms.Compose([
207
+ transforms.Resize(min(self.image_sample_size)),
208
+ transforms.CenterCrop(self.image_sample_size),
209
+ transforms.ToTensor(),
210
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
211
+ ])
212
+
213
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
214
+
215
+ def get_batch(self, idx):
216
+ data_info = self.dataset[idx % len(self.dataset)]
217
+
218
+ if data_info.get('type', 'image')=='video':
219
+ video_id, text = data_info['file_path'], data_info['text']
220
+
221
+ if self.data_root is None:
222
+ video_dir = video_id
223
+ else:
224
+ video_dir = os.path.join(self.data_root, video_id)
225
+
226
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
227
+ min_sample_n_frames = min(
228
+ self.video_sample_n_frames,
229
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
230
+ )
231
+ if min_sample_n_frames == 0:
232
+ raise ValueError(f"No Frames in video.")
233
+
234
+ video_length = int(self.video_length_drop_end * len(video_reader))
235
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
236
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
237
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
238
+
239
+ try:
240
+ sample_args = (video_reader, batch_index)
241
+ pixel_values = func_timeout(
242
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
243
+ )
244
+ resized_frames = []
245
+ for i in range(len(pixel_values)):
246
+ frame = pixel_values[i]
247
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
248
+ resized_frames.append(resized_frame)
249
+ pixel_values = np.array(resized_frames)
250
+ except FunctionTimedOut:
251
+ raise ValueError(f"Read {idx} timeout.")
252
+ except Exception as e:
253
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
254
+
255
+ if not self.enable_bucket:
256
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
257
+ pixel_values = pixel_values / 255.
258
+ del video_reader
259
+ else:
260
+ pixel_values = pixel_values
261
+
262
+ if not self.enable_bucket:
263
+ pixel_values = self.video_transforms(pixel_values)
264
+
265
+ # Random use no text generation
266
+ if random.random() < self.text_drop_ratio:
267
+ text = ''
268
+ return pixel_values, text, 'video'
269
+ else:
270
+ image_path, text = data_info['file_path'], data_info['text']
271
+ if self.data_root is not None:
272
+ image_path = os.path.join(self.data_root, image_path)
273
+ image = Image.open(image_path).convert('RGB')
274
+ if not self.enable_bucket:
275
+ image = self.image_transforms(image).unsqueeze(0)
276
+ else:
277
+ image = np.expand_dims(np.array(image), 0)
278
+ if random.random() < self.text_drop_ratio:
279
+ text = ''
280
+ return image, text, 'image'
281
+
282
+ def __len__(self):
283
+ return self.length
284
+
285
+ def __getitem__(self, idx):
286
+ data_info = self.dataset[idx % len(self.dataset)]
287
+ data_type = data_info.get('type', 'image')
288
+ while True:
289
+ sample = {}
290
+ try:
291
+ data_info_local = self.dataset[idx % len(self.dataset)]
292
+ data_type_local = data_info_local.get('type', 'image')
293
+ if data_type_local != data_type:
294
+ raise ValueError("data_type_local != data_type")
295
+
296
+ pixel_values, name, data_type = self.get_batch(idx)
297
+ sample["pixel_values"] = pixel_values
298
+ sample["text"] = name
299
+ sample["data_type"] = data_type
300
+ sample["idx"] = idx
301
+
302
+ if len(sample) > 0:
303
+ break
304
+ except Exception as e:
305
+ print(e, self.dataset[idx % len(self.dataset)])
306
+ idx = random.randint(0, self.length-1)
307
+
308
+ if self.enable_inpaint and not self.enable_bucket:
309
+ mask = get_random_mask(pixel_values.size())
310
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
311
+ sample["mask_pixel_values"] = mask_pixel_values
312
+ sample["mask"] = mask
313
+
314
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
315
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
316
+ sample["clip_pixel_values"] = clip_pixel_values
317
+
318
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
319
+ if (mask == 1).all():
320
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
321
+ sample["ref_pixel_values"] = ref_pixel_values
322
+
323
+ return sample
324
+
325
+
326
+ class ImageVideoControlDataset(Dataset):
327
+ def __init__(
328
+ self,
329
+ ann_path, data_root=None,
330
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
331
+ image_sample_size=512,
332
+ video_repeat=0,
333
+ text_drop_ratio=-1,
334
+ enable_bucket=False,
335
+ video_length_drop_start=0.1,
336
+ video_length_drop_end=0.9,
337
+ enable_inpaint=False,
338
+ ):
339
+ # Loading annotations from files
340
+ print(f"loading annotations from {ann_path} ...")
341
+ if ann_path.endswith('.csv'):
342
+ with open(ann_path, 'r') as csvfile:
343
+ dataset = list(csv.DictReader(csvfile))
344
+ elif ann_path.endswith('.json'):
345
+ dataset = json.load(open(ann_path))
346
+
347
+ self.data_root = data_root
348
+
349
+ # It's used to balance num of images and videos.
350
+ self.dataset = []
351
+ for data in dataset:
352
+ if data.get('type', 'image') != 'video':
353
+ self.dataset.append(data)
354
+ if video_repeat > 0:
355
+ for _ in range(video_repeat):
356
+ for data in dataset:
357
+ if data.get('type', 'image') == 'video':
358
+ self.dataset.append(data)
359
+ del dataset
360
+
361
+ self.length = len(self.dataset)
362
+ print(f"data scale: {self.length}")
363
+ # TODO: enable bucket training
364
+ self.enable_bucket = enable_bucket
365
+ self.text_drop_ratio = text_drop_ratio
366
+ self.enable_inpaint = enable_inpaint
367
+
368
+ self.video_length_drop_start = video_length_drop_start
369
+ self.video_length_drop_end = video_length_drop_end
370
+
371
+ # Video params
372
+ self.video_sample_stride = video_sample_stride
373
+ self.video_sample_n_frames = video_sample_n_frames
374
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
375
+ self.video_transforms = transforms.Compose(
376
+ [
377
+ transforms.Resize(min(self.video_sample_size)),
378
+ transforms.CenterCrop(self.video_sample_size),
379
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
380
+ ]
381
+ )
382
+
383
+ # Image params
384
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
385
+ self.image_transforms = transforms.Compose([
386
+ transforms.Resize(min(self.image_sample_size)),
387
+ transforms.CenterCrop(self.image_sample_size),
388
+ transforms.ToTensor(),
389
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
390
+ ])
391
+
392
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
393
+
394
+ def get_batch(self, idx):
395
+ data_info = self.dataset[idx % len(self.dataset)]
396
+ video_id, text = data_info['file_path'], data_info['text']
397
+
398
+ if data_info.get('type', 'image')=='video':
399
+ if self.data_root is None:
400
+ video_dir = video_id
401
+ else:
402
+ video_dir = os.path.join(self.data_root, video_id)
403
+
404
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
405
+ min_sample_n_frames = min(
406
+ self.video_sample_n_frames,
407
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
408
+ )
409
+ if min_sample_n_frames == 0:
410
+ raise ValueError(f"No Frames in video.")
411
+
412
+ video_length = int(self.video_length_drop_end * len(video_reader))
413
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
414
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
415
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
416
+
417
+ try:
418
+ sample_args = (video_reader, batch_index)
419
+ pixel_values = func_timeout(
420
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
421
+ )
422
+ resized_frames = []
423
+ for i in range(len(pixel_values)):
424
+ frame = pixel_values[i]
425
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
426
+ resized_frames.append(resized_frame)
427
+ pixel_values = np.array(resized_frames)
428
+ except FunctionTimedOut:
429
+ raise ValueError(f"Read {idx} timeout.")
430
+ except Exception as e:
431
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
432
+
433
+ if not self.enable_bucket:
434
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
435
+ pixel_values = pixel_values / 255.
436
+ del video_reader
437
+ else:
438
+ pixel_values = pixel_values
439
+
440
+ if not self.enable_bucket:
441
+ pixel_values = self.video_transforms(pixel_values)
442
+
443
+ # Random use no text generation
444
+ if random.random() < self.text_drop_ratio:
445
+ text = ''
446
+
447
+ control_video_id = data_info['control_file_path']
448
+
449
+ if self.data_root is None:
450
+ control_video_id = control_video_id
451
+ else:
452
+ control_video_id = os.path.join(self.data_root, control_video_id)
453
+
454
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
455
+ try:
456
+ sample_args = (control_video_reader, batch_index)
457
+ control_pixel_values = func_timeout(
458
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
459
+ )
460
+ resized_frames = []
461
+ for i in range(len(control_pixel_values)):
462
+ frame = control_pixel_values[i]
463
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
464
+ resized_frames.append(resized_frame)
465
+ control_pixel_values = np.array(resized_frames)
466
+ except FunctionTimedOut:
467
+ raise ValueError(f"Read {idx} timeout.")
468
+ except Exception as e:
469
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
470
+
471
+ if not self.enable_bucket:
472
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
473
+ control_pixel_values = control_pixel_values / 255.
474
+ del control_video_reader
475
+ else:
476
+ control_pixel_values = control_pixel_values
477
+
478
+ if not self.enable_bucket:
479
+ control_pixel_values = self.video_transforms(control_pixel_values)
480
+ return pixel_values, control_pixel_values, text, "video"
481
+ else:
482
+ image_path, text = data_info['file_path'], data_info['text']
483
+ if self.data_root is not None:
484
+ image_path = os.path.join(self.data_root, image_path)
485
+ image = Image.open(image_path).convert('RGB')
486
+ if not self.enable_bucket:
487
+ image = self.image_transforms(image).unsqueeze(0)
488
+ else:
489
+ image = np.expand_dims(np.array(image), 0)
490
+
491
+ if random.random() < self.text_drop_ratio:
492
+ text = ''
493
+
494
+ control_image_id = data_info['control_file_path']
495
+
496
+ if self.data_root is None:
497
+ control_image_id = control_image_id
498
+ else:
499
+ control_image_id = os.path.join(self.data_root, control_image_id)
500
+
501
+ control_image = Image.open(control_image_id).convert('RGB')
502
+ if not self.enable_bucket:
503
+ control_image = self.image_transforms(control_image).unsqueeze(0)
504
+ else:
505
+ control_image = np.expand_dims(np.array(control_image), 0)
506
+ return image, control_image, text, 'image'
507
+
508
+ def __len__(self):
509
+ return self.length
510
+
511
+ def __getitem__(self, idx):
512
+ data_info = self.dataset[idx % len(self.dataset)]
513
+ data_type = data_info.get('type', 'image')
514
+ while True:
515
+ sample = {}
516
+ try:
517
+ data_info_local = self.dataset[idx % len(self.dataset)]
518
+ data_type_local = data_info_local.get('type', 'image')
519
+ if data_type_local != data_type:
520
+ raise ValueError("data_type_local != data_type")
521
+
522
+ pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
523
+ sample["pixel_values"] = pixel_values
524
+ sample["control_pixel_values"] = control_pixel_values
525
+ sample["text"] = name
526
+ sample["data_type"] = data_type
527
+ sample["idx"] = idx
528
+
529
+ if len(sample) > 0:
530
+ break
531
+ except Exception as e:
532
+ print(e, self.dataset[idx % len(self.dataset)])
533
+ idx = random.randint(0, self.length-1)
534
+
535
+ if self.enable_inpaint and not self.enable_bucket:
536
+ mask = get_random_mask(pixel_values.size())
537
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
538
+ sample["mask_pixel_values"] = mask_pixel_values
539
+ sample["mask"] = mask
540
+
541
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
542
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
543
+ sample["clip_pixel_values"] = clip_pixel_values
544
+
545
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
546
+ if (mask == 1).all():
547
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
548
+ sample["ref_pixel_values"] = ref_pixel_values
549
+
550
+ return sample
cogvideox/data/dataset_video.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ from decord import VideoReader
17
+ from einops import rearrange
18
+ from func_timeout import FunctionTimedOut, func_timeout
19
+ from PIL import Image
20
+ from torch.utils.data import BatchSampler, Sampler
21
+ from torch.utils.data.dataset import Dataset
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ mask_index = np.random.randint(0, 4)
29
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
30
+ if mask_index == 0:
31
+ mask[1:, :, :, :] = 1
32
+ elif mask_index == 1:
33
+ mask_frame_index = 1
34
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
35
+ elif mask_index == 2:
36
+ center_x = torch.randint(0, w, (1,)).item()
37
+ center_y = torch.randint(0, h, (1,)).item()
38
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
39
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
40
+
41
+ start_x = max(center_x - block_size_x // 2, 0)
42
+ end_x = min(center_x + block_size_x // 2, w)
43
+ start_y = max(center_y - block_size_y // 2, 0)
44
+ end_y = min(center_y + block_size_y // 2, h)
45
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
46
+ elif mask_index == 3:
47
+ center_x = torch.randint(0, w, (1,)).item()
48
+ center_y = torch.randint(0, h, (1,)).item()
49
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
50
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
51
+
52
+ start_x = max(center_x - block_size_x // 2, 0)
53
+ end_x = min(center_x + block_size_x // 2, w)
54
+ start_y = max(center_y - block_size_y // 2, 0)
55
+ end_y = min(center_y + block_size_y // 2, h)
56
+
57
+ mask_frame_before = np.random.randint(0, f // 2)
58
+ mask_frame_after = np.random.randint(f // 2, f)
59
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
60
+ else:
61
+ raise ValueError(f"The mask_index {mask_index} is not define")
62
+ return mask
63
+
64
+
65
+ @contextmanager
66
+ def VideoReader_contextmanager(*args, **kwargs):
67
+ vr = VideoReader(*args, **kwargs)
68
+ try:
69
+ yield vr
70
+ finally:
71
+ del vr
72
+ gc.collect()
73
+
74
+
75
+ def get_video_reader_batch(video_reader, batch_index):
76
+ frames = video_reader.get_batch(batch_index).asnumpy()
77
+ return frames
78
+
79
+
80
+ class WebVid10M(Dataset):
81
+ def __init__(
82
+ self,
83
+ csv_path, video_folder,
84
+ sample_size=256, sample_stride=4, sample_n_frames=16,
85
+ enable_bucket=False, enable_inpaint=False, is_image=False,
86
+ ):
87
+ print(f"loading annotations from {csv_path} ...")
88
+ with open(csv_path, 'r') as csvfile:
89
+ self.dataset = list(csv.DictReader(csvfile))
90
+ self.length = len(self.dataset)
91
+ print(f"data scale: {self.length}")
92
+
93
+ self.video_folder = video_folder
94
+ self.sample_stride = sample_stride
95
+ self.sample_n_frames = sample_n_frames
96
+ self.enable_bucket = enable_bucket
97
+ self.enable_inpaint = enable_inpaint
98
+ self.is_image = is_image
99
+
100
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
101
+ self.pixel_transforms = transforms.Compose([
102
+ transforms.Resize(sample_size[0]),
103
+ transforms.CenterCrop(sample_size),
104
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
105
+ ])
106
+
107
+ def get_batch(self, idx):
108
+ video_dict = self.dataset[idx]
109
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
110
+
111
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
112
+ video_reader = VideoReader(video_dir)
113
+ video_length = len(video_reader)
114
+
115
+ if not self.is_image:
116
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
117
+ start_idx = random.randint(0, video_length - clip_length)
118
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
119
+ else:
120
+ batch_index = [random.randint(0, video_length - 1)]
121
+
122
+ if not self.enable_bucket:
123
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
124
+ pixel_values = pixel_values / 255.
125
+ del video_reader
126
+ else:
127
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
128
+
129
+ if self.is_image:
130
+ pixel_values = pixel_values[0]
131
+ return pixel_values, name
132
+
133
+ def __len__(self):
134
+ return self.length
135
+
136
+ def __getitem__(self, idx):
137
+ while True:
138
+ try:
139
+ pixel_values, name = self.get_batch(idx)
140
+ break
141
+
142
+ except Exception as e:
143
+ print("Error info:", e)
144
+ idx = random.randint(0, self.length-1)
145
+
146
+ if not self.enable_bucket:
147
+ pixel_values = self.pixel_transforms(pixel_values)
148
+ if self.enable_inpaint:
149
+ mask = get_random_mask(pixel_values.size())
150
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
151
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
152
+ else:
153
+ sample = dict(pixel_values=pixel_values, text=name)
154
+ return sample
155
+
156
+
157
+ class VideoDataset(Dataset):
158
+ def __init__(
159
+ self,
160
+ json_path, video_folder=None,
161
+ sample_size=256, sample_stride=4, sample_n_frames=16,
162
+ enable_bucket=False, enable_inpaint=False
163
+ ):
164
+ print(f"loading annotations from {json_path} ...")
165
+ self.dataset = json.load(open(json_path, 'r'))
166
+ self.length = len(self.dataset)
167
+ print(f"data scale: {self.length}")
168
+
169
+ self.video_folder = video_folder
170
+ self.sample_stride = sample_stride
171
+ self.sample_n_frames = sample_n_frames
172
+ self.enable_bucket = enable_bucket
173
+ self.enable_inpaint = enable_inpaint
174
+
175
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
176
+ self.pixel_transforms = transforms.Compose(
177
+ [
178
+ transforms.Resize(sample_size[0]),
179
+ transforms.CenterCrop(sample_size),
180
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
181
+ ]
182
+ )
183
+
184
+ def get_batch(self, idx):
185
+ video_dict = self.dataset[idx]
186
+ video_id, name = video_dict['file_path'], video_dict['text']
187
+
188
+ if self.video_folder is None:
189
+ video_dir = video_id
190
+ else:
191
+ video_dir = os.path.join(self.video_folder, video_id)
192
+
193
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
194
+ video_length = len(video_reader)
195
+
196
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
197
+ start_idx = random.randint(0, video_length - clip_length)
198
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
199
+
200
+ try:
201
+ sample_args = (video_reader, batch_index)
202
+ pixel_values = func_timeout(
203
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
204
+ )
205
+ except FunctionTimedOut:
206
+ raise ValueError(f"Read {idx} timeout.")
207
+ except Exception as e:
208
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
209
+
210
+ if not self.enable_bucket:
211
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
212
+ pixel_values = pixel_values / 255.
213
+ del video_reader
214
+ else:
215
+ pixel_values = pixel_values
216
+
217
+ return pixel_values, name
218
+
219
+ def __len__(self):
220
+ return self.length
221
+
222
+ def __getitem__(self, idx):
223
+ while True:
224
+ try:
225
+ pixel_values, name = self.get_batch(idx)
226
+ break
227
+
228
+ except Exception as e:
229
+ print("Error info:", e)
230
+ idx = random.randint(0, self.length-1)
231
+
232
+ if not self.enable_bucket:
233
+ pixel_values = self.pixel_transforms(pixel_values)
234
+ if self.enable_inpaint:
235
+ mask = get_random_mask(pixel_values.size())
236
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
237
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
238
+ else:
239
+ sample = dict(pixel_values=pixel_values, text=name)
240
+ return sample
241
+
242
+
243
+ if __name__ == "__main__":
244
+ if 1:
245
+ dataset = VideoDataset(
246
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
247
+ sample_size=256,
248
+ sample_stride=4, sample_n_frames=16,
249
+ )
250
+
251
+ if 0:
252
+ dataset = WebVid10M(
253
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
254
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
255
+ sample_size=256,
256
+ sample_stride=4, sample_n_frames=16,
257
+ is_image=False,
258
+ )
259
+
260
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
261
+ for idx, batch in enumerate(dataloader):
262
+ print(batch["pixel_values"].shape, len(batch["text"]))
cogvideox/pipeline/pipeline_cogvideox.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
26
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```python
40
+ >>> import torch
41
+ >>> from diffusers import CogVideoX_Fun_Pipeline
42
+ >>> from diffusers.utils import export_to_video
43
+
44
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
45
+ >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
46
+ >>> prompt = (
47
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
48
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
49
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
50
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
51
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
52
+ ... "atmosphere of this unique musical performance."
53
+ ... )
54
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
55
+ >>> export_to_video(video, "output.mp4", fps=8)
56
+ ```
57
+ """
58
+
59
+
60
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
61
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
62
+ tw = tgt_width
63
+ th = tgt_height
64
+ h, w = src
65
+ r = h / w
66
+ if r > (th / tw):
67
+ resize_height = th
68
+ resize_width = int(round(th / h * w))
69
+ else:
70
+ resize_width = tw
71
+ resize_height = int(round(tw / w * h))
72
+
73
+ crop_top = int(round((th - resize_height) / 2.0))
74
+ crop_left = int(round((tw - resize_width) / 2.0))
75
+
76
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
77
+
78
+
79
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
80
+ def retrieve_timesteps(
81
+ scheduler,
82
+ num_inference_steps: Optional[int] = None,
83
+ device: Optional[Union[str, torch.device]] = None,
84
+ timesteps: Optional[List[int]] = None,
85
+ sigmas: Optional[List[float]] = None,
86
+ **kwargs,
87
+ ):
88
+ """
89
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
90
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
91
+
92
+ Args:
93
+ scheduler (`SchedulerMixin`):
94
+ The scheduler to get timesteps from.
95
+ num_inference_steps (`int`):
96
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
97
+ must be `None`.
98
+ device (`str` or `torch.device`, *optional*):
99
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
100
+ timesteps (`List[int]`, *optional*):
101
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
102
+ `num_inference_steps` and `sigmas` must be `None`.
103
+ sigmas (`List[float]`, *optional*):
104
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
105
+ `num_inference_steps` and `timesteps` must be `None`.
106
+
107
+ Returns:
108
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
109
+ second element is the number of inference steps.
110
+ """
111
+ if timesteps is not None and sigmas is not None:
112
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
113
+ if timesteps is not None:
114
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
115
+ if not accepts_timesteps:
116
+ raise ValueError(
117
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
118
+ f" timestep schedules. Please check whether you are using the correct scheduler."
119
+ )
120
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
121
+ timesteps = scheduler.timesteps
122
+ num_inference_steps = len(timesteps)
123
+ elif sigmas is not None:
124
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
138
+
139
+ @dataclass
140
+ class CogVideoX_Fun_PipelineOutput(BaseOutput):
141
+ r"""
142
+ Output class for CogVideo pipelines.
143
+
144
+ Args:
145
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
146
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
147
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
148
+ `(batch_size, num_frames, channels, height, width)`.
149
+ """
150
+
151
+ videos: torch.Tensor
152
+
153
+
154
+ class CogVideoX_Fun_Pipeline(DiffusionPipeline):
155
+ r"""
156
+ Pipeline for text-to-video generation using CogVideoX_Fun.
157
+
158
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
159
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
160
+
161
+ Args:
162
+ vae ([`AutoencoderKL`]):
163
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
164
+ text_encoder ([`T5EncoderModel`]):
165
+ Frozen text-encoder. CogVideoX uses
166
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
167
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
168
+ tokenizer (`T5Tokenizer`):
169
+ Tokenizer of class
170
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
171
+ transformer ([`CogVideoXTransformer3DModel`]):
172
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
173
+ scheduler ([`SchedulerMixin`]):
174
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
175
+ """
176
+
177
+ _optional_components = []
178
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
179
+
180
+ _callback_tensor_inputs = [
181
+ "latents",
182
+ "prompt_embeds",
183
+ "negative_prompt_embeds",
184
+ ]
185
+
186
+ def __init__(
187
+ self,
188
+ tokenizer: T5Tokenizer,
189
+ text_encoder: T5EncoderModel,
190
+ vae: AutoencoderKLCogVideoX,
191
+ transformer: CogVideoXTransformer3DModel,
192
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
193
+ ):
194
+ super().__init__()
195
+
196
+ self.register_modules(
197
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
198
+ )
199
+ self.vae_scale_factor_spatial = (
200
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
201
+ )
202
+ self.vae_scale_factor_temporal = (
203
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
204
+ )
205
+
206
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
207
+
208
+ def _get_t5_prompt_embeds(
209
+ self,
210
+ prompt: Union[str, List[str]] = None,
211
+ num_videos_per_prompt: int = 1,
212
+ max_sequence_length: int = 226,
213
+ device: Optional[torch.device] = None,
214
+ dtype: Optional[torch.dtype] = None,
215
+ ):
216
+ device = device or self._execution_device
217
+ dtype = dtype or self.text_encoder.dtype
218
+
219
+ prompt = [prompt] if isinstance(prompt, str) else prompt
220
+ batch_size = len(prompt)
221
+
222
+ text_inputs = self.tokenizer(
223
+ prompt,
224
+ padding="max_length",
225
+ max_length=max_sequence_length,
226
+ truncation=True,
227
+ add_special_tokens=True,
228
+ return_tensors="pt",
229
+ )
230
+ text_input_ids = text_inputs.input_ids
231
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
232
+
233
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
234
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
235
+ logger.warning(
236
+ "The following part of your input was truncated because `max_sequence_length` is set to "
237
+ f" {max_sequence_length} tokens: {removed_text}"
238
+ )
239
+
240
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
241
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
242
+
243
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
244
+ _, seq_len, _ = prompt_embeds.shape
245
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
246
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
247
+
248
+ return prompt_embeds
249
+
250
+ def encode_prompt(
251
+ self,
252
+ prompt: Union[str, List[str]],
253
+ negative_prompt: Optional[Union[str, List[str]]] = None,
254
+ do_classifier_free_guidance: bool = True,
255
+ num_videos_per_prompt: int = 1,
256
+ prompt_embeds: Optional[torch.Tensor] = None,
257
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
258
+ max_sequence_length: int = 226,
259
+ device: Optional[torch.device] = None,
260
+ dtype: Optional[torch.dtype] = None,
261
+ ):
262
+ r"""
263
+ Encodes the prompt into text encoder hidden states.
264
+
265
+ Args:
266
+ prompt (`str` or `List[str]`, *optional*):
267
+ prompt to be encoded
268
+ negative_prompt (`str` or `List[str]`, *optional*):
269
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
270
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
271
+ less than `1`).
272
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
273
+ Whether to use classifier free guidance or not.
274
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
275
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
276
+ prompt_embeds (`torch.Tensor`, *optional*):
277
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
278
+ provided, text embeddings will be generated from `prompt` input argument.
279
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
280
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
281
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
282
+ argument.
283
+ device: (`torch.device`, *optional*):
284
+ torch device
285
+ dtype: (`torch.dtype`, *optional*):
286
+ torch dtype
287
+ """
288
+ device = device or self._execution_device
289
+
290
+ prompt = [prompt] if isinstance(prompt, str) else prompt
291
+ if prompt is not None:
292
+ batch_size = len(prompt)
293
+ else:
294
+ batch_size = prompt_embeds.shape[0]
295
+
296
+ if prompt_embeds is None:
297
+ prompt_embeds = self._get_t5_prompt_embeds(
298
+ prompt=prompt,
299
+ num_videos_per_prompt=num_videos_per_prompt,
300
+ max_sequence_length=max_sequence_length,
301
+ device=device,
302
+ dtype=dtype,
303
+ )
304
+
305
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
306
+ negative_prompt = negative_prompt or ""
307
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
308
+
309
+ if prompt is not None and type(prompt) is not type(negative_prompt):
310
+ raise TypeError(
311
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
312
+ f" {type(prompt)}."
313
+ )
314
+ elif batch_size != len(negative_prompt):
315
+ raise ValueError(
316
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
317
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
318
+ " the batch size of `prompt`."
319
+ )
320
+
321
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
322
+ prompt=negative_prompt,
323
+ num_videos_per_prompt=num_videos_per_prompt,
324
+ max_sequence_length=max_sequence_length,
325
+ device=device,
326
+ dtype=dtype,
327
+ )
328
+
329
+ return prompt_embeds, negative_prompt_embeds
330
+
331
+ def prepare_latents(
332
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
333
+ ):
334
+ shape = (
335
+ batch_size,
336
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
337
+ num_channels_latents,
338
+ height // self.vae_scale_factor_spatial,
339
+ width // self.vae_scale_factor_spatial,
340
+ )
341
+ if isinstance(generator, list) and len(generator) != batch_size:
342
+ raise ValueError(
343
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
344
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
345
+ )
346
+
347
+ if latents is None:
348
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
349
+ else:
350
+ latents = latents.to(device)
351
+
352
+ # scale the initial noise by the standard deviation required by the scheduler
353
+ latents = latents * self.scheduler.init_noise_sigma
354
+ return latents
355
+
356
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
357
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
358
+ latents = 1 / self.vae.config.scaling_factor * latents
359
+
360
+ frames = self.vae.decode(latents).sample
361
+ frames = (frames / 2 + 0.5).clamp(0, 1)
362
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
363
+ frames = frames.cpu().float().numpy()
364
+ return frames
365
+
366
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
367
+ def prepare_extra_step_kwargs(self, generator, eta):
368
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
369
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
370
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
371
+ # and should be between [0, 1]
372
+
373
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
374
+ extra_step_kwargs = {}
375
+ if accepts_eta:
376
+ extra_step_kwargs["eta"] = eta
377
+
378
+ # check if the scheduler accepts generator
379
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
380
+ if accepts_generator:
381
+ extra_step_kwargs["generator"] = generator
382
+ return extra_step_kwargs
383
+
384
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
385
+ def check_inputs(
386
+ self,
387
+ prompt,
388
+ height,
389
+ width,
390
+ negative_prompt,
391
+ callback_on_step_end_tensor_inputs,
392
+ prompt_embeds=None,
393
+ negative_prompt_embeds=None,
394
+ ):
395
+ if height % 8 != 0 or width % 8 != 0:
396
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
397
+
398
+ if callback_on_step_end_tensor_inputs is not None and not all(
399
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
400
+ ):
401
+ raise ValueError(
402
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
403
+ )
404
+ if prompt is not None and prompt_embeds is not None:
405
+ raise ValueError(
406
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
407
+ " only forward one of the two."
408
+ )
409
+ elif prompt is None and prompt_embeds is None:
410
+ raise ValueError(
411
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
412
+ )
413
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
414
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
415
+
416
+ if prompt is not None and negative_prompt_embeds is not None:
417
+ raise ValueError(
418
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
419
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
420
+ )
421
+
422
+ if negative_prompt is not None and negative_prompt_embeds is not None:
423
+ raise ValueError(
424
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
425
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
426
+ )
427
+
428
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
429
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
430
+ raise ValueError(
431
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
432
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
433
+ f" {negative_prompt_embeds.shape}."
434
+ )
435
+
436
+ def fuse_qkv_projections(self) -> None:
437
+ r"""Enables fused QKV projections."""
438
+ self.fusing_transformer = True
439
+ self.transformer.fuse_qkv_projections()
440
+
441
+ def unfuse_qkv_projections(self) -> None:
442
+ r"""Disable QKV projection fusion if enabled."""
443
+ if not self.fusing_transformer:
444
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
445
+ else:
446
+ self.transformer.unfuse_qkv_projections()
447
+ self.fusing_transformer = False
448
+
449
+ def _prepare_rotary_positional_embeddings(
450
+ self,
451
+ height: int,
452
+ width: int,
453
+ num_frames: int,
454
+ device: torch.device,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
457
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
458
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
459
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
460
+
461
+ grid_crops_coords = get_resize_crop_region_for_grid(
462
+ (grid_height, grid_width), base_size_width, base_size_height
463
+ )
464
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
465
+ embed_dim=self.transformer.config.attention_head_dim,
466
+ crops_coords=grid_crops_coords,
467
+ grid_size=(grid_height, grid_width),
468
+ temporal_size=num_frames,
469
+ use_real=True,
470
+ )
471
+
472
+ freqs_cos = freqs_cos.to(device=device)
473
+ freqs_sin = freqs_sin.to(device=device)
474
+ return freqs_cos, freqs_sin
475
+
476
+ @property
477
+ def guidance_scale(self):
478
+ return self._guidance_scale
479
+
480
+ @property
481
+ def num_timesteps(self):
482
+ return self._num_timesteps
483
+
484
+ @property
485
+ def interrupt(self):
486
+ return self._interrupt
487
+
488
+ @torch.no_grad()
489
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
490
+ def __call__(
491
+ self,
492
+ prompt: Optional[Union[str, List[str]]] = None,
493
+ negative_prompt: Optional[Union[str, List[str]]] = None,
494
+ height: int = 480,
495
+ width: int = 720,
496
+ num_frames: int = 49,
497
+ num_inference_steps: int = 50,
498
+ timesteps: Optional[List[int]] = None,
499
+ guidance_scale: float = 6,
500
+ use_dynamic_cfg: bool = False,
501
+ num_videos_per_prompt: int = 1,
502
+ eta: float = 0.0,
503
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
504
+ latents: Optional[torch.FloatTensor] = None,
505
+ prompt_embeds: Optional[torch.FloatTensor] = None,
506
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
507
+ output_type: str = "numpy",
508
+ return_dict: bool = False,
509
+ callback_on_step_end: Optional[
510
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
511
+ ] = None,
512
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
513
+ max_sequence_length: int = 226,
514
+ ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
515
+ """
516
+ Function invoked when calling the pipeline for generation.
517
+
518
+ Args:
519
+ prompt (`str` or `List[str]`, *optional*):
520
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
521
+ instead.
522
+ negative_prompt (`str` or `List[str]`, *optional*):
523
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
524
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
525
+ less than `1`).
526
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
527
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
528
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
529
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
530
+ num_frames (`int`, defaults to `48`):
531
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
532
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
533
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
534
+ needs to be satisfied is that of divisibility mentioned above.
535
+ num_inference_steps (`int`, *optional*, defaults to 50):
536
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
537
+ expense of slower inference.
538
+ timesteps (`List[int]`, *optional*):
539
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
540
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
541
+ passed will be used. Must be in descending order.
542
+ guidance_scale (`float`, *optional*, defaults to 7.0):
543
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
544
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
545
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
546
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
547
+ usually at the expense of lower image quality.
548
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
549
+ The number of videos to generate per prompt.
550
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
551
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
552
+ to make generation deterministic.
553
+ latents (`torch.FloatTensor`, *optional*):
554
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
555
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
556
+ tensor will ge generated by sampling using the supplied random `generator`.
557
+ prompt_embeds (`torch.FloatTensor`, *optional*):
558
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
559
+ provided, text embeddings will be generated from `prompt` input argument.
560
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
561
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
562
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
563
+ argument.
564
+ output_type (`str`, *optional*, defaults to `"pil"`):
565
+ The output format of the generate image. Choose between
566
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
567
+ return_dict (`bool`, *optional*, defaults to `True`):
568
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
569
+ of a plain tuple.
570
+ callback_on_step_end (`Callable`, *optional*):
571
+ A function that calls at the end of each denoising steps during the inference. The function is called
572
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
573
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
574
+ `callback_on_step_end_tensor_inputs`.
575
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
576
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
577
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
578
+ `._callback_tensor_inputs` attribute of your pipeline class.
579
+ max_sequence_length (`int`, defaults to `226`):
580
+ Maximum sequence length in encoded prompt. Must be consistent with
581
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
582
+
583
+ Examples:
584
+
585
+ Returns:
586
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
587
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
588
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
589
+ """
590
+
591
+ if num_frames > 49:
592
+ raise ValueError(
593
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
594
+ )
595
+
596
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
597
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
598
+
599
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
600
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
601
+ num_videos_per_prompt = 1
602
+
603
+ # 1. Check inputs. Raise error if not correct
604
+ self.check_inputs(
605
+ prompt,
606
+ height,
607
+ width,
608
+ negative_prompt,
609
+ callback_on_step_end_tensor_inputs,
610
+ prompt_embeds,
611
+ negative_prompt_embeds,
612
+ )
613
+ self._guidance_scale = guidance_scale
614
+ self._interrupt = False
615
+
616
+ # 2. Default call parameters
617
+ if prompt is not None and isinstance(prompt, str):
618
+ batch_size = 1
619
+ elif prompt is not None and isinstance(prompt, list):
620
+ batch_size = len(prompt)
621
+ else:
622
+ batch_size = prompt_embeds.shape[0]
623
+
624
+ device = self._execution_device
625
+
626
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
627
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
628
+ # corresponds to doing no classifier free guidance.
629
+ do_classifier_free_guidance = guidance_scale > 1.0
630
+
631
+ # 3. Encode input prompt
632
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
633
+ prompt,
634
+ negative_prompt,
635
+ do_classifier_free_guidance,
636
+ num_videos_per_prompt=num_videos_per_prompt,
637
+ prompt_embeds=prompt_embeds,
638
+ negative_prompt_embeds=negative_prompt_embeds,
639
+ max_sequence_length=max_sequence_length,
640
+ device=device,
641
+ )
642
+ if do_classifier_free_guidance:
643
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
644
+
645
+ # 4. Prepare timesteps
646
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
647
+ self._num_timesteps = len(timesteps)
648
+
649
+ # 5. Prepare latents.
650
+ latent_channels = self.transformer.config.in_channels
651
+ latents = self.prepare_latents(
652
+ batch_size * num_videos_per_prompt,
653
+ latent_channels,
654
+ num_frames,
655
+ height,
656
+ width,
657
+ prompt_embeds.dtype,
658
+ device,
659
+ generator,
660
+ latents,
661
+ )
662
+
663
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
664
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
665
+
666
+ # 7. Create rotary embeds if required
667
+ image_rotary_emb = (
668
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
669
+ if self.transformer.config.use_rotary_positional_embeddings
670
+ else None
671
+ )
672
+
673
+ # 8. Denoising loop
674
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
675
+
676
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
677
+ # for DPM-solver++
678
+ old_pred_original_sample = None
679
+ for i, t in enumerate(timesteps):
680
+ if self.interrupt:
681
+ continue
682
+
683
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
684
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
685
+
686
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
687
+ timestep = t.expand(latent_model_input.shape[0])
688
+
689
+ # predict noise model_output
690
+ noise_pred = self.transformer(
691
+ hidden_states=latent_model_input,
692
+ encoder_hidden_states=prompt_embeds,
693
+ timestep=timestep,
694
+ image_rotary_emb=image_rotary_emb,
695
+ return_dict=False,
696
+ )[0]
697
+ noise_pred = noise_pred.float()
698
+
699
+ # perform guidance
700
+ if use_dynamic_cfg:
701
+ self._guidance_scale = 1 + guidance_scale * (
702
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
703
+ )
704
+ if do_classifier_free_guidance:
705
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
706
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
707
+
708
+ # compute the previous noisy sample x_t -> x_t-1
709
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
710
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
711
+ else:
712
+ latents, old_pred_original_sample = self.scheduler.step(
713
+ noise_pred,
714
+ old_pred_original_sample,
715
+ t,
716
+ timesteps[i - 1] if i > 0 else None,
717
+ latents,
718
+ **extra_step_kwargs,
719
+ return_dict=False,
720
+ )
721
+ latents = latents.to(prompt_embeds.dtype)
722
+
723
+ # call the callback, if provided
724
+ if callback_on_step_end is not None:
725
+ callback_kwargs = {}
726
+ for k in callback_on_step_end_tensor_inputs:
727
+ callback_kwargs[k] = locals()[k]
728
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
729
+
730
+ latents = callback_outputs.pop("latents", latents)
731
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
732
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
733
+
734
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
735
+ progress_bar.update()
736
+
737
+ if output_type == "numpy":
738
+ video = self.decode_latents(latents)
739
+ elif not output_type == "latent":
740
+ video = self.decode_latents(latents)
741
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
742
+ else:
743
+ video = latents
744
+
745
+ # Offload all models
746
+ self.maybe_free_model_hooks()
747
+
748
+ if not return_dict:
749
+ video = torch.from_numpy(video)
750
+
751
+ return CogVideoX_Fun_PipelineOutput(videos=video)
cogvideox/pipeline/pipeline_cogvideox_control.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange
24
+ from transformers import T5EncoderModel, T5Tokenizer
25
+
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
28
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
31
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.video_processor import VideoProcessor
34
+ from diffusers.image_processor import VaeImageProcessor
35
+ from einops import rearrange
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```python
44
+ >>> import torch
45
+ >>> from diffusers import CogVideoX_Fun_Pipeline
46
+ >>> from diffusers.utils import export_to_video
47
+
48
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
49
+ >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
50
+ >>> prompt = (
51
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
52
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
53
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
54
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
55
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
56
+ ... "atmosphere of this unique musical performance."
57
+ ... )
58
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
59
+ >>> export_to_video(video, "output.mp4", fps=8)
60
+ ```
61
+ """
62
+
63
+
64
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
65
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
66
+ tw = tgt_width
67
+ th = tgt_height
68
+ h, w = src
69
+ r = h / w
70
+ if r > (th / tw):
71
+ resize_height = th
72
+ resize_width = int(round(th / h * w))
73
+ else:
74
+ resize_width = tw
75
+ resize_height = int(round(tw / w * h))
76
+
77
+ crop_top = int(round((th - resize_height) / 2.0))
78
+ crop_left = int(round((tw - resize_width) / 2.0))
79
+
80
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
81
+
82
+
83
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
84
+ def retrieve_timesteps(
85
+ scheduler,
86
+ num_inference_steps: Optional[int] = None,
87
+ device: Optional[Union[str, torch.device]] = None,
88
+ timesteps: Optional[List[int]] = None,
89
+ sigmas: Optional[List[float]] = None,
90
+ **kwargs,
91
+ ):
92
+ """
93
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
94
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
95
+
96
+ Args:
97
+ scheduler (`SchedulerMixin`):
98
+ The scheduler to get timesteps from.
99
+ num_inference_steps (`int`):
100
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
101
+ must be `None`.
102
+ device (`str` or `torch.device`, *optional*):
103
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
104
+ timesteps (`List[int]`, *optional*):
105
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
106
+ `num_inference_steps` and `sigmas` must be `None`.
107
+ sigmas (`List[float]`, *optional*):
108
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
109
+ `num_inference_steps` and `timesteps` must be `None`.
110
+
111
+ Returns:
112
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
113
+ second element is the number of inference steps.
114
+ """
115
+ if timesteps is not None and sigmas is not None:
116
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
117
+ if timesteps is not None:
118
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
119
+ if not accepts_timesteps:
120
+ raise ValueError(
121
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
122
+ f" timestep schedules. Please check whether you are using the correct scheduler."
123
+ )
124
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
125
+ timesteps = scheduler.timesteps
126
+ num_inference_steps = len(timesteps)
127
+ elif sigmas is not None:
128
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
129
+ if not accept_sigmas:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ num_inference_steps = len(timesteps)
137
+ else:
138
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ return timesteps, num_inference_steps
141
+
142
+
143
+ @dataclass
144
+ class CogVideoX_Fun_PipelineOutput(BaseOutput):
145
+ r"""
146
+ Output class for CogVideo pipelines.
147
+
148
+ Args:
149
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
150
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
151
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
152
+ `(batch_size, num_frames, channels, height, width)`.
153
+ """
154
+
155
+ videos: torch.Tensor
156
+
157
+
158
+ class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
159
+ r"""
160
+ Pipeline for text-to-video generation using CogVideoX.
161
+
162
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
163
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
164
+
165
+ Args:
166
+ vae ([`AutoencoderKL`]):
167
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
168
+ text_encoder ([`T5EncoderModel`]):
169
+ Frozen text-encoder. CogVideoX_Fun uses
170
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
171
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
172
+ tokenizer (`T5Tokenizer`):
173
+ Tokenizer of class
174
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
175
+ transformer ([`CogVideoXTransformer3DModel`]):
176
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
177
+ scheduler ([`SchedulerMixin`]):
178
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
179
+ """
180
+
181
+ _optional_components = []
182
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
183
+
184
+ _callback_tensor_inputs = [
185
+ "latents",
186
+ "prompt_embeds",
187
+ "negative_prompt_embeds",
188
+ ]
189
+
190
+ def __init__(
191
+ self,
192
+ tokenizer: T5Tokenizer,
193
+ text_encoder: T5EncoderModel,
194
+ vae: AutoencoderKLCogVideoX,
195
+ transformer: CogVideoXTransformer3DModel,
196
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
197
+ ):
198
+ super().__init__()
199
+
200
+ self.register_modules(
201
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
202
+ )
203
+ self.vae_scale_factor_spatial = (
204
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
205
+ )
206
+ self.vae_scale_factor_temporal = (
207
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
208
+ )
209
+
210
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
211
+
212
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
213
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
214
+ self.mask_processor = VaeImageProcessor(
215
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
216
+ )
217
+
218
+ def _get_t5_prompt_embeds(
219
+ self,
220
+ prompt: Union[str, List[str]] = None,
221
+ num_videos_per_prompt: int = 1,
222
+ max_sequence_length: int = 226,
223
+ device: Optional[torch.device] = None,
224
+ dtype: Optional[torch.dtype] = None,
225
+ ):
226
+ device = device or self._execution_device
227
+ dtype = dtype or self.text_encoder.dtype
228
+
229
+ prompt = [prompt] if isinstance(prompt, str) else prompt
230
+ batch_size = len(prompt)
231
+
232
+ text_inputs = self.tokenizer(
233
+ prompt,
234
+ padding="max_length",
235
+ max_length=max_sequence_length,
236
+ truncation=True,
237
+ add_special_tokens=True,
238
+ return_tensors="pt",
239
+ )
240
+ text_input_ids = text_inputs.input_ids
241
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
242
+
243
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
244
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
245
+ logger.warning(
246
+ "The following part of your input was truncated because `max_sequence_length` is set to "
247
+ f" {max_sequence_length} tokens: {removed_text}"
248
+ )
249
+
250
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
251
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
252
+
253
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
254
+ _, seq_len, _ = prompt_embeds.shape
255
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
256
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
257
+
258
+ return prompt_embeds
259
+
260
+ def encode_prompt(
261
+ self,
262
+ prompt: Union[str, List[str]],
263
+ negative_prompt: Optional[Union[str, List[str]]] = None,
264
+ do_classifier_free_guidance: bool = True,
265
+ num_videos_per_prompt: int = 1,
266
+ prompt_embeds: Optional[torch.Tensor] = None,
267
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
268
+ max_sequence_length: int = 226,
269
+ device: Optional[torch.device] = None,
270
+ dtype: Optional[torch.dtype] = None,
271
+ ):
272
+ r"""
273
+ Encodes the prompt into text encoder hidden states.
274
+
275
+ Args:
276
+ prompt (`str` or `List[str]`, *optional*):
277
+ prompt to be encoded
278
+ negative_prompt (`str` or `List[str]`, *optional*):
279
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
280
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
281
+ less than `1`).
282
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
283
+ Whether to use classifier free guidance or not.
284
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
285
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
286
+ prompt_embeds (`torch.Tensor`, *optional*):
287
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
288
+ provided, text embeddings will be generated from `prompt` input argument.
289
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
290
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
291
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
292
+ argument.
293
+ device: (`torch.device`, *optional*):
294
+ torch device
295
+ dtype: (`torch.dtype`, *optional*):
296
+ torch dtype
297
+ """
298
+ device = device or self._execution_device
299
+
300
+ prompt = [prompt] if isinstance(prompt, str) else prompt
301
+ if prompt is not None:
302
+ batch_size = len(prompt)
303
+ else:
304
+ batch_size = prompt_embeds.shape[0]
305
+
306
+ if prompt_embeds is None:
307
+ prompt_embeds = self._get_t5_prompt_embeds(
308
+ prompt=prompt,
309
+ num_videos_per_prompt=num_videos_per_prompt,
310
+ max_sequence_length=max_sequence_length,
311
+ device=device,
312
+ dtype=dtype,
313
+ )
314
+
315
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
316
+ negative_prompt = negative_prompt or ""
317
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
318
+
319
+ if prompt is not None and type(prompt) is not type(negative_prompt):
320
+ raise TypeError(
321
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
322
+ f" {type(prompt)}."
323
+ )
324
+ elif batch_size != len(negative_prompt):
325
+ raise ValueError(
326
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
327
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
328
+ " the batch size of `prompt`."
329
+ )
330
+
331
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
332
+ prompt=negative_prompt,
333
+ num_videos_per_prompt=num_videos_per_prompt,
334
+ max_sequence_length=max_sequence_length,
335
+ device=device,
336
+ dtype=dtype,
337
+ )
338
+
339
+ return prompt_embeds, negative_prompt_embeds
340
+
341
+ def prepare_latents(
342
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
343
+ ):
344
+ shape = (
345
+ batch_size,
346
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
347
+ num_channels_latents,
348
+ height // self.vae_scale_factor_spatial,
349
+ width // self.vae_scale_factor_spatial,
350
+ )
351
+ if isinstance(generator, list) and len(generator) != batch_size:
352
+ raise ValueError(
353
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
354
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
355
+ )
356
+
357
+ if latents is None:
358
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
359
+ else:
360
+ latents = latents.to(device)
361
+
362
+ # scale the initial noise by the standard deviation required by the scheduler
363
+ latents = latents * self.scheduler.init_noise_sigma
364
+ return latents
365
+
366
+ def prepare_control_latents(
367
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
368
+ ):
369
+ # resize the mask to latents shape as we concatenate the mask to the latents
370
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
371
+ # and half precision
372
+
373
+ if mask is not None:
374
+ mask = mask.to(device=device, dtype=self.vae.dtype)
375
+ bs = 1
376
+ new_mask = []
377
+ for i in range(0, mask.shape[0], bs):
378
+ mask_bs = mask[i : i + bs]
379
+ mask_bs = self.vae.encode(mask_bs)[0]
380
+ mask_bs = mask_bs.mode()
381
+ new_mask.append(mask_bs)
382
+ mask = torch.cat(new_mask, dim = 0)
383
+ mask = mask * self.vae.config.scaling_factor
384
+
385
+ if masked_image is not None:
386
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
387
+ bs = 1
388
+ new_mask_pixel_values = []
389
+ for i in range(0, masked_image.shape[0], bs):
390
+ mask_pixel_values_bs = masked_image[i : i + bs]
391
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
392
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
393
+ new_mask_pixel_values.append(mask_pixel_values_bs)
394
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
395
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
396
+ else:
397
+ masked_image_latents = None
398
+
399
+ return mask, masked_image_latents
400
+
401
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
402
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
403
+ latents = 1 / self.vae.config.scaling_factor * latents
404
+
405
+ frames = self.vae.decode(latents).sample
406
+ frames = (frames / 2 + 0.5).clamp(0, 1)
407
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
408
+ frames = frames.cpu().float().numpy()
409
+ return frames
410
+
411
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
412
+ def prepare_extra_step_kwargs(self, generator, eta):
413
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
414
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
415
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
416
+ # and should be between [0, 1]
417
+
418
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
419
+ extra_step_kwargs = {}
420
+ if accepts_eta:
421
+ extra_step_kwargs["eta"] = eta
422
+
423
+ # check if the scheduler accepts generator
424
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
425
+ if accepts_generator:
426
+ extra_step_kwargs["generator"] = generator
427
+ return extra_step_kwargs
428
+
429
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
430
+ def check_inputs(
431
+ self,
432
+ prompt,
433
+ height,
434
+ width,
435
+ negative_prompt,
436
+ callback_on_step_end_tensor_inputs,
437
+ prompt_embeds=None,
438
+ negative_prompt_embeds=None,
439
+ ):
440
+ if height % 8 != 0 or width % 8 != 0:
441
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
442
+
443
+ if callback_on_step_end_tensor_inputs is not None and not all(
444
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
445
+ ):
446
+ raise ValueError(
447
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
448
+ )
449
+ if prompt is not None and prompt_embeds is not None:
450
+ raise ValueError(
451
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
452
+ " only forward one of the two."
453
+ )
454
+ elif prompt is None and prompt_embeds is None:
455
+ raise ValueError(
456
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
457
+ )
458
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
459
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
460
+
461
+ if prompt is not None and negative_prompt_embeds is not None:
462
+ raise ValueError(
463
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
464
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
465
+ )
466
+
467
+ if negative_prompt is not None and negative_prompt_embeds is not None:
468
+ raise ValueError(
469
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
470
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
471
+ )
472
+
473
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
474
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
475
+ raise ValueError(
476
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
477
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
478
+ f" {negative_prompt_embeds.shape}."
479
+ )
480
+
481
+ def fuse_qkv_projections(self) -> None:
482
+ r"""Enables fused QKV projections."""
483
+ self.fusing_transformer = True
484
+ self.transformer.fuse_qkv_projections()
485
+
486
+ def unfuse_qkv_projections(self) -> None:
487
+ r"""Disable QKV projection fusion if enabled."""
488
+ if not self.fusing_transformer:
489
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
490
+ else:
491
+ self.transformer.unfuse_qkv_projections()
492
+ self.fusing_transformer = False
493
+
494
+ def _prepare_rotary_positional_embeddings(
495
+ self,
496
+ height: int,
497
+ width: int,
498
+ num_frames: int,
499
+ device: torch.device,
500
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
501
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
502
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
503
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
504
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
505
+
506
+ grid_crops_coords = get_resize_crop_region_for_grid(
507
+ (grid_height, grid_width), base_size_width, base_size_height
508
+ )
509
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
510
+ embed_dim=self.transformer.config.attention_head_dim,
511
+ crops_coords=grid_crops_coords,
512
+ grid_size=(grid_height, grid_width),
513
+ temporal_size=num_frames,
514
+ use_real=True,
515
+ )
516
+
517
+ freqs_cos = freqs_cos.to(device=device)
518
+ freqs_sin = freqs_sin.to(device=device)
519
+ return freqs_cos, freqs_sin
520
+
521
+ @property
522
+ def guidance_scale(self):
523
+ return self._guidance_scale
524
+
525
+ @property
526
+ def num_timesteps(self):
527
+ return self._num_timesteps
528
+
529
+ @property
530
+ def interrupt(self):
531
+ return self._interrupt
532
+
533
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
534
+ def get_timesteps(self, num_inference_steps, strength, device):
535
+ # get the original timestep using init_timestep
536
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
537
+
538
+ t_start = max(num_inference_steps - init_timestep, 0)
539
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
540
+
541
+ return timesteps, num_inference_steps - t_start
542
+
543
+ @torch.no_grad()
544
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
545
+ def __call__(
546
+ self,
547
+ prompt: Optional[Union[str, List[str]]] = None,
548
+ negative_prompt: Optional[Union[str, List[str]]] = None,
549
+ height: int = 480,
550
+ width: int = 720,
551
+ video: Union[torch.FloatTensor] = None,
552
+ control_video: Union[torch.FloatTensor] = None,
553
+ num_frames: int = 49,
554
+ num_inference_steps: int = 50,
555
+ timesteps: Optional[List[int]] = None,
556
+ guidance_scale: float = 6,
557
+ use_dynamic_cfg: bool = False,
558
+ num_videos_per_prompt: int = 1,
559
+ eta: float = 0.0,
560
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
561
+ latents: Optional[torch.FloatTensor] = None,
562
+ prompt_embeds: Optional[torch.FloatTensor] = None,
563
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
564
+ output_type: str = "numpy",
565
+ return_dict: bool = False,
566
+ callback_on_step_end: Optional[
567
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
568
+ ] = None,
569
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
570
+ max_sequence_length: int = 226,
571
+ comfyui_progressbar: bool = False,
572
+ ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
573
+ """
574
+ Function invoked when calling the pipeline for generation.
575
+
576
+ Args:
577
+ prompt (`str` or `List[str]`, *optional*):
578
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
579
+ instead.
580
+ negative_prompt (`str` or `List[str]`, *optional*):
581
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
582
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
583
+ less than `1`).
584
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
585
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
586
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
587
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
588
+ num_frames (`int`, defaults to `48`):
589
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
590
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
591
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
592
+ needs to be satisfied is that of divisibility mentioned above.
593
+ num_inference_steps (`int`, *optional*, defaults to 50):
594
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
595
+ expense of slower inference.
596
+ timesteps (`List[int]`, *optional*):
597
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
598
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
599
+ passed will be used. Must be in descending order.
600
+ guidance_scale (`float`, *optional*, defaults to 7.0):
601
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
602
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
603
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
604
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
605
+ usually at the expense of lower image quality.
606
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
607
+ The number of videos to generate per prompt.
608
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
609
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
610
+ to make generation deterministic.
611
+ latents (`torch.FloatTensor`, *optional*):
612
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
613
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
614
+ tensor will ge generated by sampling using the supplied random `generator`.
615
+ prompt_embeds (`torch.FloatTensor`, *optional*):
616
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
617
+ provided, text embeddings will be generated from `prompt` input argument.
618
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
619
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
620
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
621
+ argument.
622
+ output_type (`str`, *optional*, defaults to `"pil"`):
623
+ The output format of the generate image. Choose between
624
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
625
+ return_dict (`bool`, *optional*, defaults to `True`):
626
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
627
+ of a plain tuple.
628
+ callback_on_step_end (`Callable`, *optional*):
629
+ A function that calls at the end of each denoising steps during the inference. The function is called
630
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
631
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
632
+ `callback_on_step_end_tensor_inputs`.
633
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
634
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
635
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
636
+ `._callback_tensor_inputs` attribute of your pipeline class.
637
+ max_sequence_length (`int`, defaults to `226`):
638
+ Maximum sequence length in encoded prompt. Must be consistent with
639
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
640
+
641
+ Examples:
642
+
643
+ Returns:
644
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
645
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
646
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
647
+ """
648
+
649
+ if num_frames > 49:
650
+ raise ValueError(
651
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
652
+ )
653
+
654
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
655
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
656
+
657
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
658
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
659
+ num_videos_per_prompt = 1
660
+
661
+ # 1. Check inputs. Raise error if not correct
662
+ self.check_inputs(
663
+ prompt,
664
+ height,
665
+ width,
666
+ negative_prompt,
667
+ callback_on_step_end_tensor_inputs,
668
+ prompt_embeds,
669
+ negative_prompt_embeds,
670
+ )
671
+ self._guidance_scale = guidance_scale
672
+ self._interrupt = False
673
+
674
+ # 2. Default call parameters
675
+ if prompt is not None and isinstance(prompt, str):
676
+ batch_size = 1
677
+ elif prompt is not None and isinstance(prompt, list):
678
+ batch_size = len(prompt)
679
+ else:
680
+ batch_size = prompt_embeds.shape[0]
681
+
682
+ device = self._execution_device
683
+
684
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
685
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
686
+ # corresponds to doing no classifier free guidance.
687
+ do_classifier_free_guidance = guidance_scale > 1.0
688
+
689
+ # 3. Encode input prompt
690
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
691
+ prompt,
692
+ negative_prompt,
693
+ do_classifier_free_guidance,
694
+ num_videos_per_prompt=num_videos_per_prompt,
695
+ prompt_embeds=prompt_embeds,
696
+ negative_prompt_embeds=negative_prompt_embeds,
697
+ max_sequence_length=max_sequence_length,
698
+ device=device,
699
+ )
700
+ if do_classifier_free_guidance:
701
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
702
+
703
+ # 4. Prepare timesteps
704
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
705
+ self._num_timesteps = len(timesteps)
706
+ if comfyui_progressbar:
707
+ from comfy.utils import ProgressBar
708
+ pbar = ProgressBar(num_inference_steps + 2)
709
+
710
+ # 5. Prepare latents.
711
+ latent_channels = self.vae.config.latent_channels
712
+ latents = self.prepare_latents(
713
+ batch_size * num_videos_per_prompt,
714
+ latent_channels,
715
+ num_frames,
716
+ height,
717
+ width,
718
+ prompt_embeds.dtype,
719
+ device,
720
+ generator,
721
+ latents,
722
+ )
723
+ if comfyui_progressbar:
724
+ pbar.update(1)
725
+
726
+ if control_video is not None:
727
+ video_length = control_video.shape[2]
728
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
729
+ control_video = control_video.to(dtype=torch.float32)
730
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
731
+ else:
732
+ control_video = None
733
+ control_video_latents = self.prepare_control_latents(
734
+ None,
735
+ control_video,
736
+ batch_size,
737
+ height,
738
+ width,
739
+ prompt_embeds.dtype,
740
+ device,
741
+ generator,
742
+ do_classifier_free_guidance
743
+ )[1]
744
+ control_video_latents_input = (
745
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
746
+ )
747
+ control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
748
+
749
+ if comfyui_progressbar:
750
+ pbar.update(1)
751
+
752
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
753
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
754
+
755
+ # 7. Create rotary embeds if required
756
+ image_rotary_emb = (
757
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
758
+ if self.transformer.config.use_rotary_positional_embeddings
759
+ else None
760
+ )
761
+
762
+ # 8. Denoising loop
763
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
764
+
765
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
766
+ # for DPM-solver++
767
+ old_pred_original_sample = None
768
+ for i, t in enumerate(timesteps):
769
+ if self.interrupt:
770
+ continue
771
+
772
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
773
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
774
+
775
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
776
+ timestep = t.expand(latent_model_input.shape[0])
777
+
778
+ # predict noise model_output
779
+ noise_pred = self.transformer(
780
+ hidden_states=latent_model_input,
781
+ encoder_hidden_states=prompt_embeds,
782
+ timestep=timestep,
783
+ image_rotary_emb=image_rotary_emb,
784
+ return_dict=False,
785
+ control_latents=control_latents,
786
+ )[0]
787
+ noise_pred = noise_pred.float()
788
+
789
+ # perform guidance
790
+ if use_dynamic_cfg:
791
+ self._guidance_scale = 1 + guidance_scale * (
792
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
793
+ )
794
+ if do_classifier_free_guidance:
795
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
796
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
797
+
798
+ # compute the previous noisy sample x_t -> x_t-1
799
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
800
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
801
+ else:
802
+ latents, old_pred_original_sample = self.scheduler.step(
803
+ noise_pred,
804
+ old_pred_original_sample,
805
+ t,
806
+ timesteps[i - 1] if i > 0 else None,
807
+ latents,
808
+ **extra_step_kwargs,
809
+ return_dict=False,
810
+ )
811
+ latents = latents.to(prompt_embeds.dtype)
812
+
813
+ # call the callback, if provided
814
+ if callback_on_step_end is not None:
815
+ callback_kwargs = {}
816
+ for k in callback_on_step_end_tensor_inputs:
817
+ callback_kwargs[k] = locals()[k]
818
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
819
+
820
+ latents = callback_outputs.pop("latents", latents)
821
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
822
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
823
+
824
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
825
+ progress_bar.update()
826
+ if comfyui_progressbar:
827
+ pbar.update(1)
828
+
829
+ if output_type == "numpy":
830
+ video = self.decode_latents(latents)
831
+ elif not output_type == "latent":
832
+ video = self.decode_latents(latents)
833
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
834
+ else:
835
+ video = latents
836
+
837
+ # Offload all models
838
+ self.maybe_free_model_hooks()
839
+
840
+ if not return_dict:
841
+ video = torch.from_numpy(video)
842
+
843
+ return CogVideoX_Fun_PipelineOutput(videos=video)
cogvideox/pipeline/pipeline_cogvideox_inpaint.py ADDED
@@ -0,0 +1,1020 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange
24
+ from transformers import T5EncoderModel, T5Tokenizer
25
+
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
28
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
31
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.video_processor import VideoProcessor
34
+ from diffusers.image_processor import VaeImageProcessor
35
+ from einops import rearrange
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```python
44
+ >>> import torch
45
+ >>> from diffusers import CogVideoX_Fun_Pipeline
46
+ >>> from diffusers.utils import export_to_video
47
+
48
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
49
+ >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
50
+ >>> prompt = (
51
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
52
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
53
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
54
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
55
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
56
+ ... "atmosphere of this unique musical performance."
57
+ ... )
58
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
59
+ >>> export_to_video(video, "output.mp4", fps=8)
60
+ ```
61
+ """
62
+
63
+
64
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
65
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
66
+ tw = tgt_width
67
+ th = tgt_height
68
+ h, w = src
69
+ r = h / w
70
+ if r > (th / tw):
71
+ resize_height = th
72
+ resize_width = int(round(th / h * w))
73
+ else:
74
+ resize_width = tw
75
+ resize_height = int(round(tw / w * h))
76
+
77
+ crop_top = int(round((th - resize_height) / 2.0))
78
+ crop_left = int(round((tw - resize_width) / 2.0))
79
+
80
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
81
+
82
+
83
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
84
+ def retrieve_timesteps(
85
+ scheduler,
86
+ num_inference_steps: Optional[int] = None,
87
+ device: Optional[Union[str, torch.device]] = None,
88
+ timesteps: Optional[List[int]] = None,
89
+ sigmas: Optional[List[float]] = None,
90
+ **kwargs,
91
+ ):
92
+ """
93
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
94
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
95
+
96
+ Args:
97
+ scheduler (`SchedulerMixin`):
98
+ The scheduler to get timesteps from.
99
+ num_inference_steps (`int`):
100
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
101
+ must be `None`.
102
+ device (`str` or `torch.device`, *optional*):
103
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
104
+ timesteps (`List[int]`, *optional*):
105
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
106
+ `num_inference_steps` and `sigmas` must be `None`.
107
+ sigmas (`List[float]`, *optional*):
108
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
109
+ `num_inference_steps` and `timesteps` must be `None`.
110
+
111
+ Returns:
112
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
113
+ second element is the number of inference steps.
114
+ """
115
+ if timesteps is not None and sigmas is not None:
116
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
117
+ if timesteps is not None:
118
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
119
+ if not accepts_timesteps:
120
+ raise ValueError(
121
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
122
+ f" timestep schedules. Please check whether you are using the correct scheduler."
123
+ )
124
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
125
+ timesteps = scheduler.timesteps
126
+ num_inference_steps = len(timesteps)
127
+ elif sigmas is not None:
128
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
129
+ if not accept_sigmas:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ num_inference_steps = len(timesteps)
137
+ else:
138
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ return timesteps, num_inference_steps
141
+
142
+
143
+ def resize_mask(mask, latent, process_first_frame_only=True):
144
+ latent_size = latent.size()
145
+ batch_size, channels, num_frames, height, width = mask.shape
146
+
147
+ if process_first_frame_only:
148
+ target_size = list(latent_size[2:])
149
+ target_size[0] = 1
150
+ first_frame_resized = F.interpolate(
151
+ mask[:, :, 0:1, :, :],
152
+ size=target_size,
153
+ mode='trilinear',
154
+ align_corners=False
155
+ )
156
+
157
+ target_size = list(latent_size[2:])
158
+ target_size[0] = target_size[0] - 1
159
+ if target_size[0] != 0:
160
+ remaining_frames_resized = F.interpolate(
161
+ mask[:, :, 1:, :, :],
162
+ size=target_size,
163
+ mode='trilinear',
164
+ align_corners=False
165
+ )
166
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
167
+ else:
168
+ resized_mask = first_frame_resized
169
+ else:
170
+ target_size = list(latent_size[2:])
171
+ resized_mask = F.interpolate(
172
+ mask,
173
+ size=target_size,
174
+ mode='trilinear',
175
+ align_corners=False
176
+ )
177
+ return resized_mask
178
+
179
+
180
+ def add_noise_to_reference_video(image, ratio=None):
181
+ if ratio is None:
182
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
183
+ sigma = torch.exp(sigma).to(image.dtype)
184
+ else:
185
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
186
+
187
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
188
+ image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
189
+ image = image + image_noise
190
+ return image
191
+
192
+
193
+ @dataclass
194
+ class CogVideoX_Fun_PipelineOutput(BaseOutput):
195
+ r"""
196
+ Output class for CogVideo pipelines.
197
+
198
+ Args:
199
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
200
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
201
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
202
+ `(batch_size, num_frames, channels, height, width)`.
203
+ """
204
+
205
+ videos: torch.Tensor
206
+
207
+
208
+ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
209
+ r"""
210
+ Pipeline for text-to-video generation using CogVideoX.
211
+
212
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
213
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
214
+
215
+ Args:
216
+ vae ([`AutoencoderKL`]):
217
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
218
+ text_encoder ([`T5EncoderModel`]):
219
+ Frozen text-encoder. CogVideoX_Fun uses
220
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
221
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
222
+ tokenizer (`T5Tokenizer`):
223
+ Tokenizer of class
224
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
225
+ transformer ([`CogVideoXTransformer3DModel`]):
226
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
227
+ scheduler ([`SchedulerMixin`]):
228
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
229
+ """
230
+
231
+ _optional_components = []
232
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
233
+
234
+ _callback_tensor_inputs = [
235
+ "latents",
236
+ "prompt_embeds",
237
+ "negative_prompt_embeds",
238
+ ]
239
+
240
+ def __init__(
241
+ self,
242
+ tokenizer: T5Tokenizer,
243
+ text_encoder: T5EncoderModel,
244
+ vae: AutoencoderKLCogVideoX,
245
+ transformer: CogVideoXTransformer3DModel,
246
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
247
+ ):
248
+ super().__init__()
249
+
250
+ self.register_modules(
251
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
252
+ )
253
+ self.vae_scale_factor_spatial = (
254
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
255
+ )
256
+ self.vae_scale_factor_temporal = (
257
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
258
+ )
259
+
260
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
261
+
262
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
263
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
264
+ self.mask_processor = VaeImageProcessor(
265
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
266
+ )
267
+
268
+ def _get_t5_prompt_embeds(
269
+ self,
270
+ prompt: Union[str, List[str]] = None,
271
+ num_videos_per_prompt: int = 1,
272
+ max_sequence_length: int = 226,
273
+ device: Optional[torch.device] = None,
274
+ dtype: Optional[torch.dtype] = None,
275
+ ):
276
+ device = device or self._execution_device
277
+ dtype = dtype or self.text_encoder.dtype
278
+
279
+ prompt = [prompt] if isinstance(prompt, str) else prompt
280
+ batch_size = len(prompt)
281
+
282
+ text_inputs = self.tokenizer(
283
+ prompt,
284
+ padding="max_length",
285
+ max_length=max_sequence_length,
286
+ truncation=True,
287
+ add_special_tokens=True,
288
+ return_tensors="pt",
289
+ )
290
+ text_input_ids = text_inputs.input_ids
291
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
292
+
293
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
294
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
295
+ logger.warning(
296
+ "The following part of your input was truncated because `max_sequence_length` is set to "
297
+ f" {max_sequence_length} tokens: {removed_text}"
298
+ )
299
+
300
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
301
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
302
+
303
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
304
+ _, seq_len, _ = prompt_embeds.shape
305
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
306
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
307
+
308
+ return prompt_embeds
309
+
310
+ def encode_prompt(
311
+ self,
312
+ prompt: Union[str, List[str]],
313
+ negative_prompt: Optional[Union[str, List[str]]] = None,
314
+ do_classifier_free_guidance: bool = True,
315
+ num_videos_per_prompt: int = 1,
316
+ prompt_embeds: Optional[torch.Tensor] = None,
317
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
318
+ max_sequence_length: int = 226,
319
+ device: Optional[torch.device] = None,
320
+ dtype: Optional[torch.dtype] = None,
321
+ ):
322
+ r"""
323
+ Encodes the prompt into text encoder hidden states.
324
+
325
+ Args:
326
+ prompt (`str` or `List[str]`, *optional*):
327
+ prompt to be encoded
328
+ negative_prompt (`str` or `List[str]`, *optional*):
329
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
330
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
331
+ less than `1`).
332
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
333
+ Whether to use classifier free guidance or not.
334
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
335
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
336
+ prompt_embeds (`torch.Tensor`, *optional*):
337
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
338
+ provided, text embeddings will be generated from `prompt` input argument.
339
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
340
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
341
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
342
+ argument.
343
+ device: (`torch.device`, *optional*):
344
+ torch device
345
+ dtype: (`torch.dtype`, *optional*):
346
+ torch dtype
347
+ """
348
+ device = device or self._execution_device
349
+
350
+ prompt = [prompt] if isinstance(prompt, str) else prompt
351
+ if prompt is not None:
352
+ batch_size = len(prompt)
353
+ else:
354
+ batch_size = prompt_embeds.shape[0]
355
+
356
+ if prompt_embeds is None:
357
+ prompt_embeds = self._get_t5_prompt_embeds(
358
+ prompt=prompt,
359
+ num_videos_per_prompt=num_videos_per_prompt,
360
+ max_sequence_length=max_sequence_length,
361
+ device=device,
362
+ dtype=dtype,
363
+ )
364
+
365
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
366
+ negative_prompt = negative_prompt or ""
367
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
368
+
369
+ if prompt is not None and type(prompt) is not type(negative_prompt):
370
+ raise TypeError(
371
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
372
+ f" {type(prompt)}."
373
+ )
374
+ elif batch_size != len(negative_prompt):
375
+ raise ValueError(
376
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
377
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
378
+ " the batch size of `prompt`."
379
+ )
380
+
381
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
382
+ prompt=negative_prompt,
383
+ num_videos_per_prompt=num_videos_per_prompt,
384
+ max_sequence_length=max_sequence_length,
385
+ device=device,
386
+ dtype=dtype,
387
+ )
388
+
389
+ return prompt_embeds, negative_prompt_embeds
390
+
391
+ def prepare_latents(
392
+ self,
393
+ batch_size,
394
+ num_channels_latents,
395
+ height,
396
+ width,
397
+ video_length,
398
+ dtype,
399
+ device,
400
+ generator,
401
+ latents=None,
402
+ video=None,
403
+ timestep=None,
404
+ is_strength_max=True,
405
+ return_noise=False,
406
+ return_video_latents=False,
407
+ ):
408
+ shape = (
409
+ batch_size,
410
+ (video_length - 1) // self.vae_scale_factor_temporal + 1,
411
+ num_channels_latents,
412
+ height // self.vae_scale_factor_spatial,
413
+ width // self.vae_scale_factor_spatial,
414
+ )
415
+ if isinstance(generator, list) and len(generator) != batch_size:
416
+ raise ValueError(
417
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
418
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
419
+ )
420
+
421
+ if return_video_latents or (latents is None and not is_strength_max):
422
+ video = video.to(device=device, dtype=self.vae.dtype)
423
+
424
+ bs = 1
425
+ new_video = []
426
+ for i in range(0, video.shape[0], bs):
427
+ video_bs = video[i : i + bs]
428
+ video_bs = self.vae.encode(video_bs)[0]
429
+ video_bs = video_bs.sample()
430
+ new_video.append(video_bs)
431
+ video = torch.cat(new_video, dim = 0)
432
+ video = video * self.vae.config.scaling_factor
433
+
434
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
435
+ video_latents = video_latents.to(device=device, dtype=dtype)
436
+ video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
437
+
438
+ if latents is None:
439
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
440
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
441
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
442
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
443
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
444
+ else:
445
+ noise = latents.to(device)
446
+ latents = noise * self.scheduler.init_noise_sigma
447
+
448
+ # scale the initial noise by the standard deviation required by the scheduler
449
+ outputs = (latents,)
450
+
451
+ if return_noise:
452
+ outputs += (noise,)
453
+
454
+ if return_video_latents:
455
+ outputs += (video_latents,)
456
+
457
+ return outputs
458
+
459
+ def prepare_mask_latents(
460
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
461
+ ):
462
+ # resize the mask to latents shape as we concatenate the mask to the latents
463
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
464
+ # and half precision
465
+
466
+ if mask is not None:
467
+ mask = mask.to(device=device, dtype=self.vae.dtype)
468
+ bs = 1
469
+ new_mask = []
470
+ for i in range(0, mask.shape[0], bs):
471
+ mask_bs = mask[i : i + bs]
472
+ mask_bs = self.vae.encode(mask_bs)[0]
473
+ mask_bs = mask_bs.mode()
474
+ new_mask.append(mask_bs)
475
+ mask = torch.cat(new_mask, dim = 0)
476
+ mask = mask * self.vae.config.scaling_factor
477
+
478
+ if masked_image is not None:
479
+ if self.transformer.config.add_noise_in_inpaint_model:
480
+ masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
481
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
482
+ bs = 1
483
+ new_mask_pixel_values = []
484
+ for i in range(0, masked_image.shape[0], bs):
485
+ mask_pixel_values_bs = masked_image[i : i + bs]
486
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
487
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
488
+ new_mask_pixel_values.append(mask_pixel_values_bs)
489
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
490
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
491
+ else:
492
+ masked_image_latents = None
493
+
494
+ return mask, masked_image_latents
495
+
496
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
497
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
498
+ latents = 1 / self.vae.config.scaling_factor * latents
499
+
500
+ frames = self.vae.decode(latents).sample
501
+ frames = (frames / 2 + 0.5).clamp(0, 1)
502
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
503
+ frames = frames.cpu().float().numpy()
504
+ return frames
505
+
506
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
507
+ def prepare_extra_step_kwargs(self, generator, eta):
508
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
509
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
510
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
511
+ # and should be between [0, 1]
512
+
513
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
514
+ extra_step_kwargs = {}
515
+ if accepts_eta:
516
+ extra_step_kwargs["eta"] = eta
517
+
518
+ # check if the scheduler accepts generator
519
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
520
+ if accepts_generator:
521
+ extra_step_kwargs["generator"] = generator
522
+ return extra_step_kwargs
523
+
524
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
525
+ def check_inputs(
526
+ self,
527
+ prompt,
528
+ height,
529
+ width,
530
+ negative_prompt,
531
+ callback_on_step_end_tensor_inputs,
532
+ prompt_embeds=None,
533
+ negative_prompt_embeds=None,
534
+ ):
535
+ if height % 8 != 0 or width % 8 != 0:
536
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
537
+
538
+ if callback_on_step_end_tensor_inputs is not None and not all(
539
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
540
+ ):
541
+ raise ValueError(
542
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
543
+ )
544
+ if prompt is not None and prompt_embeds is not None:
545
+ raise ValueError(
546
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
547
+ " only forward one of the two."
548
+ )
549
+ elif prompt is None and prompt_embeds is None:
550
+ raise ValueError(
551
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
552
+ )
553
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
554
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
555
+
556
+ if prompt is not None and negative_prompt_embeds is not None:
557
+ raise ValueError(
558
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
559
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
560
+ )
561
+
562
+ if negative_prompt is not None and negative_prompt_embeds is not None:
563
+ raise ValueError(
564
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
565
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
566
+ )
567
+
568
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
569
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
570
+ raise ValueError(
571
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
572
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
573
+ f" {negative_prompt_embeds.shape}."
574
+ )
575
+
576
+ def fuse_qkv_projections(self) -> None:
577
+ r"""Enables fused QKV projections."""
578
+ self.fusing_transformer = True
579
+ self.transformer.fuse_qkv_projections()
580
+
581
+ def unfuse_qkv_projections(self) -> None:
582
+ r"""Disable QKV projection fusion if enabled."""
583
+ if not self.fusing_transformer:
584
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
585
+ else:
586
+ self.transformer.unfuse_qkv_projections()
587
+ self.fusing_transformer = False
588
+
589
+ def _prepare_rotary_positional_embeddings(
590
+ self,
591
+ height: int,
592
+ width: int,
593
+ num_frames: int,
594
+ device: torch.device,
595
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
596
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
597
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
598
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
599
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
600
+
601
+ grid_crops_coords = get_resize_crop_region_for_grid(
602
+ (grid_height, grid_width), base_size_width, base_size_height
603
+ )
604
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
605
+ embed_dim=self.transformer.config.attention_head_dim,
606
+ crops_coords=grid_crops_coords,
607
+ grid_size=(grid_height, grid_width),
608
+ temporal_size=num_frames,
609
+ use_real=True,
610
+ )
611
+
612
+ freqs_cos = freqs_cos.to(device=device)
613
+ freqs_sin = freqs_sin.to(device=device)
614
+ return freqs_cos, freqs_sin
615
+
616
+ @property
617
+ def guidance_scale(self):
618
+ return self._guidance_scale
619
+
620
+ @property
621
+ def num_timesteps(self):
622
+ return self._num_timesteps
623
+
624
+ @property
625
+ def interrupt(self):
626
+ return self._interrupt
627
+
628
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
629
+ def get_timesteps(self, num_inference_steps, strength, device):
630
+ # get the original timestep using init_timestep
631
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
632
+
633
+ t_start = max(num_inference_steps - init_timestep, 0)
634
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
635
+
636
+ return timesteps, num_inference_steps - t_start
637
+
638
+ @torch.no_grad()
639
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
640
+ def __call__(
641
+ self,
642
+ prompt: Optional[Union[str, List[str]]] = None,
643
+ negative_prompt: Optional[Union[str, List[str]]] = None,
644
+ height: int = 480,
645
+ width: int = 720,
646
+ video: Union[torch.FloatTensor] = None,
647
+ mask_video: Union[torch.FloatTensor] = None,
648
+ masked_video_latents: Union[torch.FloatTensor] = None,
649
+ num_frames: int = 49,
650
+ num_inference_steps: int = 50,
651
+ timesteps: Optional[List[int]] = None,
652
+ guidance_scale: float = 6,
653
+ use_dynamic_cfg: bool = False,
654
+ num_videos_per_prompt: int = 1,
655
+ eta: float = 0.0,
656
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
657
+ latents: Optional[torch.FloatTensor] = None,
658
+ prompt_embeds: Optional[torch.FloatTensor] = None,
659
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
660
+ output_type: str = "numpy",
661
+ return_dict: bool = False,
662
+ callback_on_step_end: Optional[
663
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
664
+ ] = None,
665
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
666
+ max_sequence_length: int = 226,
667
+ strength: float = 1,
668
+ noise_aug_strength: float = 0.0563,
669
+ comfyui_progressbar: bool = False,
670
+ ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
671
+ """
672
+ Function invoked when calling the pipeline for generation.
673
+
674
+ Args:
675
+ prompt (`str` or `List[str]`, *optional*):
676
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
677
+ instead.
678
+ negative_prompt (`str` or `List[str]`, *optional*):
679
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
680
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
681
+ less than `1`).
682
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
683
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
684
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
685
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
686
+ num_frames (`int`, defaults to `48`):
687
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
688
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
689
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
690
+ needs to be satisfied is that of divisibility mentioned above.
691
+ num_inference_steps (`int`, *optional*, defaults to 50):
692
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
693
+ expense of slower inference.
694
+ timesteps (`List[int]`, *optional*):
695
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
696
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
697
+ passed will be used. Must be in descending order.
698
+ guidance_scale (`float`, *optional*, defaults to 7.0):
699
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
700
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
701
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
702
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
703
+ usually at the expense of lower image quality.
704
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
705
+ The number of videos to generate per prompt.
706
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
707
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
708
+ to make generation deterministic.
709
+ latents (`torch.FloatTensor`, *optional*):
710
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
711
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
712
+ tensor will ge generated by sampling using the supplied random `generator`.
713
+ prompt_embeds (`torch.FloatTensor`, *optional*):
714
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
715
+ provided, text embeddings will be generated from `prompt` input argument.
716
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
717
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
718
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
719
+ argument.
720
+ output_type (`str`, *optional*, defaults to `"pil"`):
721
+ The output format of the generate image. Choose between
722
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
723
+ return_dict (`bool`, *optional*, defaults to `True`):
724
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
725
+ of a plain tuple.
726
+ callback_on_step_end (`Callable`, *optional*):
727
+ A function that calls at the end of each denoising steps during the inference. The function is called
728
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
729
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
730
+ `callback_on_step_end_tensor_inputs`.
731
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
732
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
733
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
734
+ `._callback_tensor_inputs` attribute of your pipeline class.
735
+ max_sequence_length (`int`, defaults to `226`):
736
+ Maximum sequence length in encoded prompt. Must be consistent with
737
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
738
+
739
+ Examples:
740
+
741
+ Returns:
742
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
743
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
744
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
745
+ """
746
+
747
+ if num_frames > 49:
748
+ raise ValueError(
749
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
750
+ )
751
+
752
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
753
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
754
+
755
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
756
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
757
+ num_videos_per_prompt = 1
758
+
759
+ # 1. Check inputs. Raise error if not correct
760
+ self.check_inputs(
761
+ prompt,
762
+ height,
763
+ width,
764
+ negative_prompt,
765
+ callback_on_step_end_tensor_inputs,
766
+ prompt_embeds,
767
+ negative_prompt_embeds,
768
+ )
769
+ self._guidance_scale = guidance_scale
770
+ self._interrupt = False
771
+
772
+ # 2. Default call parameters
773
+ if prompt is not None and isinstance(prompt, str):
774
+ batch_size = 1
775
+ elif prompt is not None and isinstance(prompt, list):
776
+ batch_size = len(prompt)
777
+ else:
778
+ batch_size = prompt_embeds.shape[0]
779
+
780
+ device = self._execution_device
781
+
782
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
783
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
784
+ # corresponds to doing no classifier free guidance.
785
+ do_classifier_free_guidance = guidance_scale > 1.0
786
+
787
+ # 3. Encode input prompt
788
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
789
+ prompt,
790
+ negative_prompt,
791
+ do_classifier_free_guidance,
792
+ num_videos_per_prompt=num_videos_per_prompt,
793
+ prompt_embeds=prompt_embeds,
794
+ negative_prompt_embeds=negative_prompt_embeds,
795
+ max_sequence_length=max_sequence_length,
796
+ device=device,
797
+ )
798
+ if do_classifier_free_guidance:
799
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
800
+
801
+ # 4. set timesteps
802
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
803
+ timesteps, num_inference_steps = self.get_timesteps(
804
+ num_inference_steps=num_inference_steps, strength=strength, device=device
805
+ )
806
+ self._num_timesteps = len(timesteps)
807
+ if comfyui_progressbar:
808
+ from comfy.utils import ProgressBar
809
+ pbar = ProgressBar(num_inference_steps + 2)
810
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
811
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
812
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
813
+ is_strength_max = strength == 1.0
814
+
815
+ # 5. Prepare latents.
816
+ if video is not None:
817
+ video_length = video.shape[2]
818
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
819
+ init_video = init_video.to(dtype=torch.float32)
820
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
821
+ else:
822
+ init_video = None
823
+
824
+ num_channels_latents = self.vae.config.latent_channels
825
+ num_channels_transformer = self.transformer.config.in_channels
826
+ return_image_latents = num_channels_transformer == num_channels_latents
827
+
828
+ latents_outputs = self.prepare_latents(
829
+ batch_size * num_videos_per_prompt,
830
+ num_channels_latents,
831
+ height,
832
+ width,
833
+ video_length,
834
+ prompt_embeds.dtype,
835
+ device,
836
+ generator,
837
+ latents,
838
+ video=init_video,
839
+ timestep=latent_timestep,
840
+ is_strength_max=is_strength_max,
841
+ return_noise=True,
842
+ return_video_latents=return_image_latents,
843
+ )
844
+ if return_image_latents:
845
+ latents, noise, image_latents = latents_outputs
846
+ else:
847
+ latents, noise = latents_outputs
848
+ if comfyui_progressbar:
849
+ pbar.update(1)
850
+
851
+ if mask_video is not None:
852
+ if (mask_video == 255).all():
853
+ mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
854
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
855
+
856
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
857
+ masked_video_latents_input = (
858
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
859
+ )
860
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
861
+ else:
862
+ # Prepare mask latent variables
863
+ video_length = video.shape[2]
864
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
865
+ mask_condition = mask_condition.to(dtype=torch.float32)
866
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
867
+
868
+ if num_channels_transformer != num_channels_latents:
869
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
870
+ if masked_video_latents is None:
871
+ masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
872
+ else:
873
+ masked_video = masked_video_latents
874
+
875
+ _, masked_video_latents = self.prepare_mask_latents(
876
+ None,
877
+ masked_video,
878
+ batch_size,
879
+ height,
880
+ width,
881
+ prompt_embeds.dtype,
882
+ device,
883
+ generator,
884
+ do_classifier_free_guidance,
885
+ noise_aug_strength=noise_aug_strength,
886
+ )
887
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
888
+ mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
889
+
890
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
891
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
892
+
893
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
894
+ masked_video_latents_input = (
895
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
896
+ )
897
+
898
+ mask = rearrange(mask, "b c f h w -> b f c h w")
899
+ mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
900
+ masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
901
+
902
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
903
+ else:
904
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
905
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
906
+ mask = rearrange(mask, "b c f h w -> b f c h w")
907
+
908
+ inpaint_latents = None
909
+ else:
910
+ if num_channels_transformer != num_channels_latents:
911
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
912
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
913
+
914
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
915
+ masked_video_latents_input = (
916
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
917
+ )
918
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
919
+ else:
920
+ mask = torch.zeros_like(init_video[:, :1])
921
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
922
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
923
+ mask = rearrange(mask, "b c f h w -> b f c h w")
924
+
925
+ inpaint_latents = None
926
+ if comfyui_progressbar:
927
+ pbar.update(1)
928
+
929
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
930
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
931
+
932
+ # 7. Create rotary embeds if required
933
+ image_rotary_emb = (
934
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
935
+ if self.transformer.config.use_rotary_positional_embeddings
936
+ else None
937
+ )
938
+
939
+ # 8. Denoising loop
940
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
941
+
942
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
943
+ # for DPM-solver++
944
+ old_pred_original_sample = None
945
+ for i, t in enumerate(timesteps):
946
+ if self.interrupt:
947
+ continue
948
+
949
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
950
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
951
+
952
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
953
+ timestep = t.expand(latent_model_input.shape[0])
954
+
955
+ # predict noise model_output
956
+ noise_pred = self.transformer(
957
+ hidden_states=latent_model_input,
958
+ encoder_hidden_states=prompt_embeds,
959
+ timestep=timestep,
960
+ image_rotary_emb=image_rotary_emb,
961
+ return_dict=False,
962
+ inpaint_latents=inpaint_latents,
963
+ )[0]
964
+ noise_pred = noise_pred.float()
965
+
966
+ # perform guidance
967
+ if use_dynamic_cfg:
968
+ self._guidance_scale = 1 + guidance_scale * (
969
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
970
+ )
971
+ if do_classifier_free_guidance:
972
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
973
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
974
+
975
+ # compute the previous noisy sample x_t -> x_t-1
976
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
977
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
978
+ else:
979
+ latents, old_pred_original_sample = self.scheduler.step(
980
+ noise_pred,
981
+ old_pred_original_sample,
982
+ t,
983
+ timesteps[i - 1] if i > 0 else None,
984
+ latents,
985
+ **extra_step_kwargs,
986
+ return_dict=False,
987
+ )
988
+ latents = latents.to(prompt_embeds.dtype)
989
+
990
+ # call the callback, if provided
991
+ if callback_on_step_end is not None:
992
+ callback_kwargs = {}
993
+ for k in callback_on_step_end_tensor_inputs:
994
+ callback_kwargs[k] = locals()[k]
995
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
996
+
997
+ latents = callback_outputs.pop("latents", latents)
998
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
999
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1000
+
1001
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1002
+ progress_bar.update()
1003
+ if comfyui_progressbar:
1004
+ pbar.update(1)
1005
+
1006
+ if output_type == "numpy":
1007
+ video = self.decode_latents(latents)
1008
+ elif not output_type == "latent":
1009
+ video = self.decode_latents(latents)
1010
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
1011
+ else:
1012
+ video = latents
1013
+
1014
+ # Offload all models
1015
+ self.maybe_free_model_hooks()
1016
+
1017
+ if not return_dict:
1018
+ video = torch.from_numpy(video)
1019
+
1020
+ return CogVideoX_Fun_PipelineOutput(videos=video)
cogvideox/ui/ui.py ADDED
@@ -0,0 +1,1614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import base64
4
+ import gc
5
+ import json
6
+ import os
7
+ import random
8
+ from datetime import datetime
9
+ from glob import glob
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import numpy as np
14
+ import pkg_resources
15
+ import requests
16
+ import torch
17
+ from diffusers import (AutoencoderKL, AutoencoderKLCogVideoX,
18
+ CogVideoXDDIMScheduler, DDIMScheduler,
19
+ DPMSolverMultistepScheduler,
20
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
21
+ PNDMScheduler)
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from omegaconf import OmegaConf
24
+ from PIL import Image
25
+ from safetensors import safe_open
26
+ from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
27
+ T5EncoderModel, T5Tokenizer)
28
+
29
+ from cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
30
+ from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
31
+ from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
32
+ from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
33
+ from cogvideox.pipeline.pipeline_cogvideox_control import \
34
+ CogVideoX_Fun_Pipeline_Control
35
+ from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
36
+ CogVideoX_Fun_Pipeline_Inpaint
37
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
38
+ from cogvideox.utils.utils import (
39
+ get_image_to_video_latent, get_video_to_video_latent,
40
+ get_width_and_height_from_image_and_base_resolution, save_videos_grid)
41
+
42
+ scheduler_dict = {
43
+ "Euler": EulerDiscreteScheduler,
44
+ "Euler A": EulerAncestralDiscreteScheduler,
45
+ "DPM++": DPMSolverMultistepScheduler,
46
+ "PNDM": PNDMScheduler,
47
+ "DDIM_Cog": CogVideoXDDIMScheduler,
48
+ "DDIM_Origin": DDIMScheduler,
49
+ }
50
+
51
+ gradio_version = pkg_resources.get_distribution("gradio").version
52
+ gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
53
+
54
+ css = """
55
+ .toolbutton {
56
+ margin-buttom: 0em 0em 0em 0em;
57
+ max-width: 2.5em;
58
+ min-width: 2.5em !important;
59
+ height: 2.5em;
60
+ }
61
+ """
62
+
63
+ class CogVideoX_Fun_Controller:
64
+ def __init__(self, low_gpu_memory_mode, weight_dtype):
65
+ # config dirs
66
+ self.basedir = os.getcwd()
67
+ self.config_dir = os.path.join(self.basedir, "config")
68
+ self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
69
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
70
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
71
+ self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
72
+ self.savedir_sample = os.path.join(self.savedir, "sample")
73
+ self.model_type = "Inpaint"
74
+ os.makedirs(self.savedir, exist_ok=True)
75
+
76
+ self.diffusion_transformer_list = []
77
+ self.motion_module_list = []
78
+ self.personalized_model_list = []
79
+
80
+ self.refresh_diffusion_transformer()
81
+ self.refresh_motion_module()
82
+ self.refresh_personalized_model()
83
+
84
+ # config models
85
+ self.tokenizer = None
86
+ self.text_encoder = None
87
+ self.vae = None
88
+ self.transformer = None
89
+ self.pipeline = None
90
+ self.motion_module_path = "none"
91
+ self.base_model_path = "none"
92
+ self.lora_model_path = "none"
93
+ self.low_gpu_memory_mode = low_gpu_memory_mode
94
+
95
+ self.weight_dtype = weight_dtype
96
+
97
+ def refresh_diffusion_transformer(self):
98
+ self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
99
+
100
+ def refresh_motion_module(self):
101
+ motion_module_list = sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
102
+ self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
103
+
104
+ def refresh_personalized_model(self):
105
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
106
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
107
+
108
+ def update_model_type(self, model_type):
109
+ self.model_type = model_type
110
+
111
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
112
+ print("Update diffusion transformer")
113
+ if diffusion_transformer_dropdown == "none":
114
+ return gr.update()
115
+ self.vae = AutoencoderKLCogVideoX.from_pretrained(
116
+ diffusion_transformer_dropdown,
117
+ subfolder="vae",
118
+ ).to(self.weight_dtype)
119
+
120
+ # Get Transformer
121
+ self.transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
122
+ diffusion_transformer_dropdown,
123
+ subfolder="transformer",
124
+ ).to(self.weight_dtype)
125
+
126
+ # Get pipeline
127
+ if self.model_type == "Inpaint":
128
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
129
+ self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
130
+ diffusion_transformer_dropdown,
131
+ vae=self.vae,
132
+ transformer=self.transformer,
133
+ scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
134
+ torch_dtype=self.weight_dtype
135
+ )
136
+ else:
137
+ self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
138
+ diffusion_transformer_dropdown,
139
+ vae=self.vae,
140
+ transformer=self.transformer,
141
+ scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
142
+ torch_dtype=self.weight_dtype
143
+ )
144
+ else:
145
+ self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
146
+ diffusion_transformer_dropdown,
147
+ vae=self.vae,
148
+ transformer=self.transformer,
149
+ scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
150
+ torch_dtype=self.weight_dtype
151
+ )
152
+
153
+ if self.low_gpu_memory_mode:
154
+ self.pipeline.enable_sequential_cpu_offload()
155
+ else:
156
+ self.pipeline.enable_model_cpu_offload()
157
+ print("Update diffusion transformer done")
158
+ return gr.update()
159
+
160
+ def update_base_model(self, base_model_dropdown):
161
+ self.base_model_path = base_model_dropdown
162
+ print("Update base model")
163
+ if base_model_dropdown == "none":
164
+ return gr.update()
165
+ if self.transformer is None:
166
+ gr.Info(f"Please select a pretrained model path.")
167
+ return gr.update(value=None)
168
+ else:
169
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
170
+ base_model_state_dict = {}
171
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
172
+ for key in f.keys():
173
+ base_model_state_dict[key] = f.get_tensor(key)
174
+ self.transformer.load_state_dict(base_model_state_dict, strict=False)
175
+ print("Update base done")
176
+ return gr.update()
177
+
178
+ def update_lora_model(self, lora_model_dropdown):
179
+ print("Update lora model")
180
+ if lora_model_dropdown == "none":
181
+ self.lora_model_path = "none"
182
+ return gr.update()
183
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
184
+ self.lora_model_path = lora_model_dropdown
185
+ return gr.update()
186
+
187
+ def generate(
188
+ self,
189
+ diffusion_transformer_dropdown,
190
+ base_model_dropdown,
191
+ lora_model_dropdown,
192
+ lora_alpha_slider,
193
+ prompt_textbox,
194
+ negative_prompt_textbox,
195
+ sampler_dropdown,
196
+ sample_step_slider,
197
+ resize_method,
198
+ width_slider,
199
+ height_slider,
200
+ base_resolution,
201
+ generation_method,
202
+ length_slider,
203
+ overlap_video_length,
204
+ partial_video_length,
205
+ cfg_scale_slider,
206
+ start_image,
207
+ end_image,
208
+ validation_video,
209
+ validation_video_mask,
210
+ control_video,
211
+ denoise_strength,
212
+ seed_textbox,
213
+ is_api = False,
214
+ ):
215
+ gc.collect()
216
+ torch.cuda.empty_cache()
217
+ torch.cuda.ipc_collect()
218
+
219
+ if self.transformer is None:
220
+ raise gr.Error(f"Please select a pretrained model path.")
221
+
222
+ if self.base_model_path != base_model_dropdown:
223
+ self.update_base_model(base_model_dropdown)
224
+
225
+ if self.lora_model_path != lora_model_dropdown:
226
+ print("Update lora model")
227
+ self.update_lora_model(lora_model_dropdown)
228
+
229
+ if control_video is not None and self.model_type == "Inpaint":
230
+ if is_api:
231
+ return "", f"If specifying the control video, please set the model_type == \"Control\". "
232
+ else:
233
+ raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
234
+
235
+ if control_video is None and self.model_type == "Control":
236
+ if is_api:
237
+ return "", f"If set the model_type == \"Control\", please specifying the control video. "
238
+ else:
239
+ raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
240
+
241
+ if resize_method == "Resize according to Reference":
242
+ if start_image is None and validation_video is None and control_video is None:
243
+ if is_api:
244
+ return "", f"Please upload an image when using \"Resize according to Reference\"."
245
+ else:
246
+ raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
247
+
248
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
249
+ if self.model_type == "Inpaint":
250
+ if validation_video is not None:
251
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
252
+ else:
253
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
254
+ else:
255
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
256
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
257
+ height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
258
+
259
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
260
+ if is_api:
261
+ return "", f"Please select an image to video pretrained model while using image to video."
262
+ else:
263
+ raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
264
+
265
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation":
266
+ if is_api:
267
+ return "", f"Please select an image to video pretrained model while using long video generation."
268
+ else:
269
+ raise gr.Error(f"Please select an image to video pretrained model while using long video generation.")
270
+
271
+ if start_image is None and end_image is not None:
272
+ if is_api:
273
+ return "", f"If specifying the ending image of the video, please specify a starting image of the video."
274
+ else:
275
+ raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
276
+
277
+ is_image = True if generation_method == "Image Generation" else False
278
+
279
+ self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
280
+ if self.lora_model_path != "none":
281
+ # lora part
282
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
283
+
284
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
285
+ else: seed_textbox = np.random.randint(0, 1e10)
286
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
287
+
288
+ try:
289
+ if self.model_type == "Inpaint":
290
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
291
+ if generation_method == "Long Video Generation":
292
+ if validation_video is not None:
293
+ raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
294
+ init_frames = 0
295
+ last_frames = init_frames + partial_video_length
296
+ while init_frames < length_slider:
297
+ if last_frames >= length_slider:
298
+ _partial_video_length = length_slider - init_frames
299
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
300
+
301
+ if _partial_video_length <= 0:
302
+ break
303
+ else:
304
+ _partial_video_length = partial_video_length
305
+
306
+ if last_frames >= length_slider:
307
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
308
+ else:
309
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
310
+
311
+ with torch.no_grad():
312
+ sample = self.pipeline(
313
+ prompt_textbox,
314
+ negative_prompt = negative_prompt_textbox,
315
+ num_inference_steps = sample_step_slider,
316
+ guidance_scale = cfg_scale_slider,
317
+ width = width_slider,
318
+ height = height_slider,
319
+ num_frames = _partial_video_length,
320
+ generator = generator,
321
+
322
+ video = input_video,
323
+ mask_video = input_video_mask,
324
+ strength = 1,
325
+ ).videos
326
+
327
+ if init_frames != 0:
328
+ mix_ratio = torch.from_numpy(
329
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
330
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
331
+
332
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
333
+ sample[:, :, :overlap_video_length] * mix_ratio
334
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
335
+
336
+ sample = new_sample
337
+ else:
338
+ new_sample = sample
339
+
340
+ if last_frames >= length_slider:
341
+ break
342
+
343
+ start_image = [
344
+ Image.fromarray(
345
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
346
+ ) for _index in range(-overlap_video_length, 0)
347
+ ]
348
+
349
+ init_frames = init_frames + _partial_video_length - overlap_video_length
350
+ last_frames = init_frames + _partial_video_length
351
+ else:
352
+ if validation_video is not None:
353
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
354
+ strength = denoise_strength
355
+ else:
356
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
357
+ strength = 1
358
+
359
+ sample = self.pipeline(
360
+ prompt_textbox,
361
+ negative_prompt = negative_prompt_textbox,
362
+ num_inference_steps = sample_step_slider,
363
+ guidance_scale = cfg_scale_slider,
364
+ width = width_slider,
365
+ height = height_slider,
366
+ num_frames = length_slider if not is_image else 1,
367
+ generator = generator,
368
+
369
+ video = input_video,
370
+ mask_video = input_video_mask,
371
+ strength = strength,
372
+ ).videos
373
+ else:
374
+ sample = self.pipeline(
375
+ prompt_textbox,
376
+ negative_prompt = negative_prompt_textbox,
377
+ num_inference_steps = sample_step_slider,
378
+ guidance_scale = cfg_scale_slider,
379
+ width = width_slider,
380
+ height = height_slider,
381
+ num_frames = length_slider if not is_image else 1,
382
+ generator = generator
383
+ ).videos
384
+ else:
385
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
386
+
387
+ sample = self.pipeline(
388
+ prompt_textbox,
389
+ negative_prompt = negative_prompt_textbox,
390
+ num_inference_steps = sample_step_slider,
391
+ guidance_scale = cfg_scale_slider,
392
+ width = width_slider,
393
+ height = height_slider,
394
+ num_frames = length_slider if not is_image else 1,
395
+ generator = generator,
396
+
397
+ control_video = input_video,
398
+ ).videos
399
+ except Exception as e:
400
+ gc.collect()
401
+ torch.cuda.empty_cache()
402
+ torch.cuda.ipc_collect()
403
+ if self.lora_model_path != "none":
404
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
405
+ if is_api:
406
+ return "", f"Error. error information is {str(e)}"
407
+ else:
408
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
409
+
410
+ gc.collect()
411
+ torch.cuda.empty_cache()
412
+ torch.cuda.ipc_collect()
413
+
414
+ # lora part
415
+ if self.lora_model_path != "none":
416
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
417
+
418
+ sample_config = {
419
+ "prompt": prompt_textbox,
420
+ "n_prompt": negative_prompt_textbox,
421
+ "sampler": sampler_dropdown,
422
+ "num_inference_steps": sample_step_slider,
423
+ "guidance_scale": cfg_scale_slider,
424
+ "width": width_slider,
425
+ "height": height_slider,
426
+ "video_length": length_slider,
427
+ "seed_textbox": seed_textbox
428
+ }
429
+ json_str = json.dumps(sample_config, indent=4)
430
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
431
+ f.write(json_str)
432
+ f.write("\n\n")
433
+
434
+ if not os.path.exists(self.savedir_sample):
435
+ os.makedirs(self.savedir_sample, exist_ok=True)
436
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
437
+ prefix = str(index).zfill(3)
438
+
439
+ gc.collect()
440
+ torch.cuda.empty_cache()
441
+ torch.cuda.ipc_collect()
442
+ if is_image or length_slider == 1:
443
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
444
+
445
+ image = sample[0, :, 0]
446
+ image = image.transpose(0, 1).transpose(1, 2)
447
+ image = (image * 255).numpy().astype(np.uint8)
448
+ image = Image.fromarray(image)
449
+ image.save(save_sample_path)
450
+
451
+ if is_api:
452
+ return save_sample_path, "Success"
453
+ else:
454
+ if gradio_version_is_above_4:
455
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
456
+ else:
457
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
458
+ else:
459
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
460
+ save_videos_grid(sample, save_sample_path, fps=8)
461
+
462
+ if is_api:
463
+ return save_sample_path, "Success"
464
+ else:
465
+ if gradio_version_is_above_4:
466
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
467
+ else:
468
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
469
+
470
+
471
+ def ui(low_gpu_memory_mode, weight_dtype):
472
+ controller = CogVideoX_Fun_Controller(low_gpu_memory_mode, weight_dtype)
473
+
474
+ with gr.Blocks(css=css) as demo:
475
+ gr.Markdown(
476
+ """
477
+ # CogVideoX-Fun:
478
+
479
+ A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
480
+
481
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
482
+ """
483
+ )
484
+ with gr.Column(variant="panel"):
485
+ gr.Markdown(
486
+ """
487
+ ### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
488
+ """
489
+ )
490
+ with gr.Row():
491
+ model_type = gr.Dropdown(
492
+ label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
493
+ choices=["Inpaint", "Control"],
494
+ value="Inpaint",
495
+ interactive=True,
496
+ )
497
+
498
+ gr.Markdown(
499
+ """
500
+ ### 2. Model checkpoints (模型路径).
501
+ """
502
+ )
503
+ with gr.Row():
504
+ diffusion_transformer_dropdown = gr.Dropdown(
505
+ label="Pretrained Model Path (预训练模型路径)",
506
+ choices=controller.diffusion_transformer_list,
507
+ value="none",
508
+ interactive=True,
509
+ )
510
+ diffusion_transformer_dropdown.change(
511
+ fn=controller.update_diffusion_transformer,
512
+ inputs=[diffusion_transformer_dropdown],
513
+ outputs=[diffusion_transformer_dropdown]
514
+ )
515
+
516
+ diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
517
+ def refresh_diffusion_transformer():
518
+ controller.refresh_diffusion_transformer()
519
+ return gr.update(choices=controller.diffusion_transformer_list)
520
+ diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
521
+
522
+ with gr.Row():
523
+ base_model_dropdown = gr.Dropdown(
524
+ label="Select base Dreambooth model (选���基模型[非必需])",
525
+ choices=controller.personalized_model_list,
526
+ value="none",
527
+ interactive=True,
528
+ )
529
+
530
+ lora_model_dropdown = gr.Dropdown(
531
+ label="Select LoRA model (选择LoRA模型[非必需])",
532
+ choices=["none"] + controller.personalized_model_list,
533
+ value="none",
534
+ interactive=True,
535
+ )
536
+
537
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
538
+
539
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
540
+ def update_personalized_model():
541
+ controller.refresh_personalized_model()
542
+ return [
543
+ gr.update(choices=controller.personalized_model_list),
544
+ gr.update(choices=["none"] + controller.personalized_model_list)
545
+ ]
546
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
547
+
548
+ with gr.Column(variant="panel"):
549
+ gr.Markdown(
550
+ """
551
+ ### 3. Configs for Generation (生成参数配置).
552
+ """
553
+ )
554
+
555
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
556
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
557
+
558
+ with gr.Row():
559
+ with gr.Column():
560
+ with gr.Row():
561
+ sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
562
+ sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
563
+
564
+ resize_method = gr.Radio(
565
+ ["Generate by", "Resize according to Reference"],
566
+ value="Generate by",
567
+ show_label=False,
568
+ )
569
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1344, step=16)
570
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1344, step=16)
571
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)
572
+
573
+ with gr.Group():
574
+ generation_method = gr.Radio(
575
+ ["Video Generation", "Image Generation", "Long Video Generation"],
576
+ value="Video Generation",
577
+ show_label=False,
578
+ )
579
+ with gr.Row():
580
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=1, maximum=49, step=4)
581
+ overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
582
+ partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
583
+
584
+ source_method = gr.Radio(
585
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
586
+ value="Text to Video (文本到视频)",
587
+ show_label=False,
588
+ )
589
+ with gr.Column(visible = False) as image_to_video_col:
590
+ start_image = gr.Image(
591
+ label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True,
592
+ elem_id="i2v_start", sources="upload", type="filepath",
593
+ )
594
+
595
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
596
+ def select_template(evt: gr.SelectData):
597
+ text = {
598
+ "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
599
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
600
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
601
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
602
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
603
+ }[template_gallery_path[evt.index]]
604
+ return template_gallery_path[evt.index], text
605
+
606
+ template_gallery = gr.Gallery(
607
+ template_gallery_path,
608
+ columns=5, rows=1,
609
+ height=140,
610
+ allow_preview=False,
611
+ container=False,
612
+ label="Template Examples",
613
+ )
614
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
615
+
616
+ with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
617
+ end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
618
+
619
+ with gr.Column(visible = False) as video_to_video_col:
620
+ with gr.Row():
621
+ validation_video = gr.Video(
622
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
623
+ elem_id="v2v", sources="upload",
624
+ )
625
+ with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
626
+ gr.Markdown(
627
+ """
628
+ - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
629
+ - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
630
+ """
631
+ )
632
+ validation_video_mask = gr.Image(
633
+ label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
634
+ show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
635
+ )
636
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
637
+
638
+ with gr.Column(visible = False) as control_video_col:
639
+ gr.Markdown(
640
+ """
641
+ Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
642
+ """
643
+ )
644
+ control_video = gr.Video(
645
+ label="The control video (用于提供控制信号的video)", show_label=True,
646
+ elem_id="v2v_control", sources="upload",
647
+ )
648
+
649
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
650
+
651
+ with gr.Row():
652
+ seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
653
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
654
+ seed_button.click(
655
+ fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
656
+ inputs=[],
657
+ outputs=[seed_textbox]
658
+ )
659
+
660
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
661
+
662
+ with gr.Column():
663
+ result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
664
+ result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
665
+ infer_progress = gr.Textbox(
666
+ label="Generation Info (生成信息)",
667
+ value="No task currently",
668
+ interactive=False
669
+ )
670
+
671
+ model_type.change(
672
+ fn=controller.update_model_type,
673
+ inputs=[model_type],
674
+ outputs=[]
675
+ )
676
+
677
+ def upload_generation_method(generation_method):
678
+ if generation_method == "Video Generation":
679
+ return [gr.update(visible=True, maximum=49, value=49), gr.update(visible=False), gr.update(visible=False)]
680
+ elif generation_method == "Image Generation":
681
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
682
+ else:
683
+ return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
684
+ generation_method.change(
685
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
686
+ )
687
+
688
+ def upload_source_method(source_method):
689
+ if source_method == "Text to Video (文本到视频)":
690
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
691
+ elif source_method == "Image to Video (图片到视频)":
692
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
693
+ elif source_method == "Video to Video (视频到视频)":
694
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
695
+ else:
696
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
697
+ source_method.change(
698
+ upload_source_method, source_method, [
699
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
700
+ validation_video, validation_video_mask, control_video
701
+ ]
702
+ )
703
+
704
+ def upload_resize_method(resize_method):
705
+ if resize_method == "Generate by":
706
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
707
+ else:
708
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
709
+ resize_method.change(
710
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
711
+ )
712
+
713
+ generate_button.click(
714
+ fn=controller.generate,
715
+ inputs=[
716
+ diffusion_transformer_dropdown,
717
+ base_model_dropdown,
718
+ lora_model_dropdown,
719
+ lora_alpha_slider,
720
+ prompt_textbox,
721
+ negative_prompt_textbox,
722
+ sampler_dropdown,
723
+ sample_step_slider,
724
+ resize_method,
725
+ width_slider,
726
+ height_slider,
727
+ base_resolution,
728
+ generation_method,
729
+ length_slider,
730
+ overlap_video_length,
731
+ partial_video_length,
732
+ cfg_scale_slider,
733
+ start_image,
734
+ end_image,
735
+ validation_video,
736
+ validation_video_mask,
737
+ control_video,
738
+ denoise_strength,
739
+ seed_textbox,
740
+ ],
741
+ outputs=[result_image, result_video, infer_progress]
742
+ )
743
+ return demo, controller
744
+
745
+
746
+ class CogVideoX_Fun_Controller_Modelscope:
747
+ def __init__(self, model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
748
+ # Basic dir
749
+ self.basedir = os.getcwd()
750
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
751
+ self.lora_model_path = "none"
752
+ self.savedir_sample = savedir_sample
753
+ self.refresh_personalized_model()
754
+ os.makedirs(self.savedir_sample, exist_ok=True)
755
+
756
+ # model path
757
+ self.model_type = model_type
758
+ self.weight_dtype = weight_dtype
759
+
760
+ self.vae = AutoencoderKLCogVideoX.from_pretrained(
761
+ model_name,
762
+ subfolder="vae",
763
+ ).to(self.weight_dtype)
764
+
765
+ # Get Transformer
766
+ self.transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
767
+ model_name,
768
+ subfolder="transformer",
769
+ ).to(self.weight_dtype)
770
+
771
+ # Get pipeline
772
+ if model_type == "Inpaint":
773
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
774
+ self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
775
+ model_name,
776
+ vae=self.vae,
777
+ transformer=self.transformer,
778
+ scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
779
+ torch_dtype=self.weight_dtype
780
+ )
781
+ else:
782
+ self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
783
+ model_name,
784
+ vae=self.vae,
785
+ transformer=self.transformer,
786
+ scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
787
+ torch_dtype=self.weight_dtype
788
+ )
789
+ else:
790
+ self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
791
+ model_name,
792
+ vae=self.vae,
793
+ transformer=self.transformer,
794
+ scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
795
+ torch_dtype=self.weight_dtype
796
+ )
797
+
798
+ if low_gpu_memory_mode:
799
+ self.pipeline.enable_sequential_cpu_offload()
800
+ else:
801
+ self.pipeline.enable_model_cpu_offload()
802
+ print("Update diffusion transformer done")
803
+
804
+
805
+ def refresh_personalized_model(self):
806
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
807
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
808
+
809
+
810
+ def update_lora_model(self, lora_model_dropdown):
811
+ print("Update lora model")
812
+ if lora_model_dropdown == "none":
813
+ self.lora_model_path = "none"
814
+ return gr.update()
815
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
816
+ self.lora_model_path = lora_model_dropdown
817
+ return gr.update()
818
+
819
+
820
+ def generate(
821
+ self,
822
+ diffusion_transformer_dropdown,
823
+ base_model_dropdown,
824
+ lora_model_dropdown,
825
+ lora_alpha_slider,
826
+ prompt_textbox,
827
+ negative_prompt_textbox,
828
+ sampler_dropdown,
829
+ sample_step_slider,
830
+ resize_method,
831
+ width_slider,
832
+ height_slider,
833
+ base_resolution,
834
+ generation_method,
835
+ length_slider,
836
+ overlap_video_length,
837
+ partial_video_length,
838
+ cfg_scale_slider,
839
+ start_image,
840
+ end_image,
841
+ validation_video,
842
+ validation_video_mask,
843
+ control_video,
844
+ denoise_strength,
845
+ seed_textbox,
846
+ is_api = False,
847
+ ):
848
+ gc.collect()
849
+ torch.cuda.empty_cache()
850
+ torch.cuda.ipc_collect()
851
+
852
+ if self.transformer is None:
853
+ raise gr.Error(f"Please select a pretrained model path.")
854
+
855
+ if self.lora_model_path != lora_model_dropdown:
856
+ print("Update lora model")
857
+ self.update_lora_model(lora_model_dropdown)
858
+
859
+ if control_video is not None and self.model_type == "Inpaint":
860
+ if is_api:
861
+ return "", f"If specifying the control video, please set the model_type == \"Control\". "
862
+ else:
863
+ raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
864
+
865
+ if control_video is None and self.model_type == "Control":
866
+ if is_api:
867
+ return "", f"If set the model_type == \"Control\", please specifying the control video. "
868
+ else:
869
+ raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
870
+
871
+ if resize_method == "Resize according to Reference":
872
+ if start_image is None and validation_video is None and control_video is None:
873
+ if is_api:
874
+ return "", f"Please upload an image when using \"Resize according to Reference\"."
875
+ else:
876
+ raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
877
+
878
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
879
+ if self.model_type == "Inpaint":
880
+ if validation_video is not None:
881
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
882
+ else:
883
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
884
+ else:
885
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
886
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
887
+ height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
888
+
889
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
890
+ if is_api:
891
+ return "", f"Please select an image to video pretrained model while using image to video."
892
+ else:
893
+ raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
894
+
895
+ if start_image is None and end_image is not None:
896
+ if is_api:
897
+ return "", f"If specifying the ending image of the video, please specify a starting image of the video."
898
+ else:
899
+ raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
900
+
901
+ is_image = True if generation_method == "Image Generation" else False
902
+
903
+ self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
904
+ if self.lora_model_path != "none":
905
+ # lora part
906
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
907
+
908
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
909
+ else: seed_textbox = np.random.randint(0, 1e10)
910
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
911
+
912
+ try:
913
+ if self.model_type == "Inpaint":
914
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
915
+ if validation_video is not None:
916
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
917
+ strength = denoise_strength
918
+ else:
919
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
920
+ strength = 1
921
+
922
+ sample = self.pipeline(
923
+ prompt_textbox,
924
+ negative_prompt = negative_prompt_textbox,
925
+ num_inference_steps = sample_step_slider,
926
+ guidance_scale = cfg_scale_slider,
927
+ width = width_slider,
928
+ height = height_slider,
929
+ num_frames = length_slider if not is_image else 1,
930
+ generator = generator,
931
+
932
+ video = input_video,
933
+ mask_video = input_video_mask,
934
+ strength = strength,
935
+ ).videos
936
+ else:
937
+ sample = self.pipeline(
938
+ prompt_textbox,
939
+ negative_prompt = negative_prompt_textbox,
940
+ num_inference_steps = sample_step_slider,
941
+ guidance_scale = cfg_scale_slider,
942
+ width = width_slider,
943
+ height = height_slider,
944
+ num_frames = length_slider if not is_image else 1,
945
+ generator = generator
946
+ ).videos
947
+ else:
948
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
949
+
950
+ sample = self.pipeline(
951
+ prompt_textbox,
952
+ negative_prompt = negative_prompt_textbox,
953
+ num_inference_steps = sample_step_slider,
954
+ guidance_scale = cfg_scale_slider,
955
+ width = width_slider,
956
+ height = height_slider,
957
+ num_frames = length_slider if not is_image else 1,
958
+ generator = generator,
959
+
960
+ control_video = input_video,
961
+ ).videos
962
+ except Exception as e:
963
+ gc.collect()
964
+ torch.cuda.empty_cache()
965
+ torch.cuda.ipc_collect()
966
+ if self.lora_model_path != "none":
967
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
968
+ if is_api:
969
+ return "", f"Error. error information is {str(e)}"
970
+ else:
971
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
972
+
973
+ gc.collect()
974
+ torch.cuda.empty_cache()
975
+ torch.cuda.ipc_collect()
976
+
977
+ # lora part
978
+ if self.lora_model_path != "none":
979
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
980
+
981
+ if not os.path.exists(self.savedir_sample):
982
+ os.makedirs(self.savedir_sample, exist_ok=True)
983
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
984
+ prefix = str(index).zfill(3)
985
+
986
+ gc.collect()
987
+ torch.cuda.empty_cache()
988
+ torch.cuda.ipc_collect()
989
+ if is_image or length_slider == 1:
990
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
991
+
992
+ image = sample[0, :, 0]
993
+ image = image.transpose(0, 1).transpose(1, 2)
994
+ image = (image * 255).numpy().astype(np.uint8)
995
+ image = Image.fromarray(image)
996
+ image.save(save_sample_path)
997
+ if is_api:
998
+ return save_sample_path, "Success"
999
+ else:
1000
+ if gradio_version_is_above_4:
1001
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
1002
+ else:
1003
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
1004
+ else:
1005
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
1006
+ save_videos_grid(sample, save_sample_path, fps=8)
1007
+ if is_api:
1008
+ return save_sample_path, "Success"
1009
+ else:
1010
+ if gradio_version_is_above_4:
1011
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
1012
+ else:
1013
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
1014
+
1015
+
1016
+ def ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
1017
+ controller = CogVideoX_Fun_Controller_Modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
1018
+
1019
+ with gr.Blocks(css=css) as demo:
1020
+ gr.Markdown(
1021
+ """
1022
+ # CogVideoX-Fun
1023
+
1024
+ A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
1025
+
1026
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
1027
+ """
1028
+ )
1029
+ with gr.Column(variant="panel"):
1030
+ gr.Markdown(
1031
+ """
1032
+ ### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
1033
+ """
1034
+ )
1035
+ with gr.Row():
1036
+ model_type = gr.Dropdown(
1037
+ label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
1038
+ choices=[model_type],
1039
+ value=model_type,
1040
+ interactive=False,
1041
+ )
1042
+
1043
+ gr.Markdown(
1044
+ """
1045
+ ### 2. Model checkpoints (模型路径).
1046
+ """
1047
+ )
1048
+ with gr.Row():
1049
+ diffusion_transformer_dropdown = gr.Dropdown(
1050
+ label="Pretrained Model Path (预训练模型路径)",
1051
+ choices=[model_name],
1052
+ value=model_name,
1053
+ interactive=False,
1054
+ )
1055
+ with gr.Row():
1056
+ base_model_dropdown = gr.Dropdown(
1057
+ label="Select base Dreambooth model (选择基模型[非必需])",
1058
+ choices=["none"],
1059
+ value="none",
1060
+ interactive=False,
1061
+ visible=False
1062
+ )
1063
+ with gr.Column(visible=False):
1064
+ gr.Markdown(
1065
+ """
1066
+ ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
1067
+ """
1068
+ )
1069
+ with gr.Row():
1070
+ lora_model_dropdown = gr.Dropdown(
1071
+ label="Select LoRA model",
1072
+ choices=["none"],
1073
+ value="none",
1074
+ interactive=True,
1075
+ )
1076
+
1077
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
1078
+
1079
+ with gr.Column(variant="panel"):
1080
+ gr.Markdown(
1081
+ """
1082
+ ### 3. Configs for Generation (生成参数配置).
1083
+ """
1084
+ )
1085
+
1086
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1087
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
1088
+
1089
+ with gr.Row():
1090
+ with gr.Column():
1091
+ with gr.Row():
1092
+ sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1093
+ sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
1094
+
1095
+ resize_method = gr.Radio(
1096
+ ["Generate by", "Resize according to Reference"],
1097
+ value="Generate by",
1098
+ show_label=False,
1099
+ )
1100
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
1101
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
1102
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
1103
+
1104
+ with gr.Group():
1105
+ generation_method = gr.Radio(
1106
+ ["Video Generation", "Image Generation"],
1107
+ value="Video Generation",
1108
+ show_label=False,
1109
+ visible=True,
1110
+ )
1111
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=5, maximum=49, step=4)
1112
+ overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
1113
+ partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
1114
+
1115
+ source_method = gr.Radio(
1116
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
1117
+ value="Text to Video (文本到视频)",
1118
+ show_label=False,
1119
+ )
1120
+ with gr.Column(visible = False) as image_to_video_col:
1121
+ with gr.Row():
1122
+ start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
1123
+
1124
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1125
+ def select_template(evt: gr.SelectData):
1126
+ text = {
1127
+ "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1128
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1129
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1130
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1131
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1132
+ }[template_gallery_path[evt.index]]
1133
+ return template_gallery_path[evt.index], text
1134
+
1135
+ template_gallery = gr.Gallery(
1136
+ template_gallery_path,
1137
+ columns=5, rows=1,
1138
+ height=140,
1139
+ allow_preview=False,
1140
+ container=False,
1141
+ label="Template Examples",
1142
+ )
1143
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
1144
+
1145
+ with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
1146
+ end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
1147
+
1148
+ with gr.Column(visible = False) as video_to_video_col:
1149
+ with gr.Row():
1150
+ validation_video = gr.Video(
1151
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
1152
+ elem_id="v2v", sources="upload",
1153
+ )
1154
+ with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
1155
+ gr.Markdown(
1156
+ """
1157
+ - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
1158
+ - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
1159
+ """
1160
+ )
1161
+ validation_video_mask = gr.Image(
1162
+ label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
1163
+ show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
1164
+ )
1165
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
1166
+
1167
+ with gr.Column(visible = False) as control_video_col:
1168
+ gr.Markdown(
1169
+ """
1170
+ Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
1171
+ """
1172
+ )
1173
+ control_video = gr.Video(
1174
+ label="The control video (用于提供控制信号的video)", show_label=True,
1175
+ elem_id="v2v_control", sources="upload",
1176
+ )
1177
+
1178
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
1179
+
1180
+ with gr.Row():
1181
+ seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
1182
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
1183
+ seed_button.click(
1184
+ fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
1185
+ inputs=[],
1186
+ outputs=[seed_textbox]
1187
+ )
1188
+
1189
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
1190
+
1191
+ with gr.Column():
1192
+ result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
1193
+ result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
1194
+ infer_progress = gr.Textbox(
1195
+ label="Generation Info (生成信息)",
1196
+ value="No task currently",
1197
+ interactive=False
1198
+ )
1199
+
1200
+ def upload_generation_method(generation_method):
1201
+ if generation_method == "Video Generation":
1202
+ return gr.update(visible=True, minimum=8, maximum=49, value=49, interactive=True)
1203
+ elif generation_method == "Image Generation":
1204
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1205
+ generation_method.change(
1206
+ upload_generation_method, generation_method, [length_slider]
1207
+ )
1208
+
1209
+ def upload_source_method(source_method):
1210
+ if source_method == "Text to Video (文本到视频)":
1211
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1212
+ elif source_method == "Image to Video (图片到视频)":
1213
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1214
+ elif source_method == "Video to Video (视频到视频)":
1215
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
1216
+ else:
1217
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
1218
+ source_method.change(
1219
+ upload_source_method, source_method, [
1220
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
1221
+ validation_video, validation_video_mask, control_video
1222
+ ]
1223
+ )
1224
+
1225
+ def upload_resize_method(resize_method):
1226
+ if resize_method == "Generate by":
1227
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1228
+ else:
1229
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1230
+ resize_method.change(
1231
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1232
+ )
1233
+
1234
+ generate_button.click(
1235
+ fn=controller.generate,
1236
+ inputs=[
1237
+ diffusion_transformer_dropdown,
1238
+ base_model_dropdown,
1239
+ lora_model_dropdown,
1240
+ lora_alpha_slider,
1241
+ prompt_textbox,
1242
+ negative_prompt_textbox,
1243
+ sampler_dropdown,
1244
+ sample_step_slider,
1245
+ resize_method,
1246
+ width_slider,
1247
+ height_slider,
1248
+ base_resolution,
1249
+ generation_method,
1250
+ length_slider,
1251
+ overlap_video_length,
1252
+ partial_video_length,
1253
+ cfg_scale_slider,
1254
+ start_image,
1255
+ end_image,
1256
+ validation_video,
1257
+ validation_video_mask,
1258
+ control_video,
1259
+ denoise_strength,
1260
+ seed_textbox,
1261
+ ],
1262
+ outputs=[result_image, result_video, infer_progress]
1263
+ )
1264
+ return demo, controller
1265
+
1266
+
1267
+ def post_eas(
1268
+ diffusion_transformer_dropdown,
1269
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1270
+ prompt_textbox, negative_prompt_textbox,
1271
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1272
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
1273
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
1274
+ ):
1275
+ if start_image is not None:
1276
+ with open(start_image, 'rb') as file:
1277
+ file_content = file.read()
1278
+ start_image_encoded_content = base64.b64encode(file_content)
1279
+ start_image = start_image_encoded_content.decode('utf-8')
1280
+
1281
+ if end_image is not None:
1282
+ with open(end_image, 'rb') as file:
1283
+ file_content = file.read()
1284
+ end_image_encoded_content = base64.b64encode(file_content)
1285
+ end_image = end_image_encoded_content.decode('utf-8')
1286
+
1287
+ if validation_video is not None:
1288
+ with open(validation_video, 'rb') as file:
1289
+ file_content = file.read()
1290
+ validation_video_encoded_content = base64.b64encode(file_content)
1291
+ validation_video = validation_video_encoded_content.decode('utf-8')
1292
+
1293
+ if validation_video_mask is not None:
1294
+ with open(validation_video_mask, 'rb') as file:
1295
+ file_content = file.read()
1296
+ validation_video_mask_encoded_content = base64.b64encode(file_content)
1297
+ validation_video_mask = validation_video_mask_encoded_content.decode('utf-8')
1298
+
1299
+ datas = {
1300
+ "base_model_path": base_model_dropdown,
1301
+ "lora_model_path": lora_model_dropdown,
1302
+ "lora_alpha_slider": lora_alpha_slider,
1303
+ "prompt_textbox": prompt_textbox,
1304
+ "negative_prompt_textbox": negative_prompt_textbox,
1305
+ "sampler_dropdown": sampler_dropdown,
1306
+ "sample_step_slider": sample_step_slider,
1307
+ "resize_method": resize_method,
1308
+ "width_slider": width_slider,
1309
+ "height_slider": height_slider,
1310
+ "base_resolution": base_resolution,
1311
+ "generation_method": generation_method,
1312
+ "length_slider": length_slider,
1313
+ "cfg_scale_slider": cfg_scale_slider,
1314
+ "start_image": start_image,
1315
+ "end_image": end_image,
1316
+ "validation_video": validation_video,
1317
+ "validation_video_mask": validation_video_mask,
1318
+ "denoise_strength": denoise_strength,
1319
+ "seed_textbox": seed_textbox,
1320
+ }
1321
+
1322
+ session = requests.session()
1323
+ session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
1324
+
1325
+ response = session.post(url=f'{os.environ.get("EAS_URL")}/cogvideox_fun/infer_forward', json=datas, timeout=300)
1326
+
1327
+ outputs = response.json()
1328
+ return outputs
1329
+
1330
+
1331
+ class CogVideoX_Fun_Controller_EAS:
1332
+ def __init__(self, model_name, savedir_sample):
1333
+ self.savedir_sample = savedir_sample
1334
+ os.makedirs(self.savedir_sample, exist_ok=True)
1335
+
1336
+ def generate(
1337
+ self,
1338
+ diffusion_transformer_dropdown,
1339
+ base_model_dropdown,
1340
+ lora_model_dropdown,
1341
+ lora_alpha_slider,
1342
+ prompt_textbox,
1343
+ negative_prompt_textbox,
1344
+ sampler_dropdown,
1345
+ sample_step_slider,
1346
+ resize_method,
1347
+ width_slider,
1348
+ height_slider,
1349
+ base_resolution,
1350
+ generation_method,
1351
+ length_slider,
1352
+ cfg_scale_slider,
1353
+ start_image,
1354
+ end_image,
1355
+ validation_video,
1356
+ validation_video_mask,
1357
+ denoise_strength,
1358
+ seed_textbox
1359
+ ):
1360
+ is_image = True if generation_method == "Image Generation" else False
1361
+
1362
+ outputs = post_eas(
1363
+ diffusion_transformer_dropdown,
1364
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1365
+ prompt_textbox, negative_prompt_textbox,
1366
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1367
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
1368
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength,
1369
+ seed_textbox
1370
+ )
1371
+ try:
1372
+ base64_encoding = outputs["base64_encoding"]
1373
+ except:
1374
+ return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
1375
+
1376
+ decoded_data = base64.b64decode(base64_encoding)
1377
+
1378
+ if not os.path.exists(self.savedir_sample):
1379
+ os.makedirs(self.savedir_sample, exist_ok=True)
1380
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
1381
+ prefix = str(index).zfill(3)
1382
+
1383
+ if is_image or length_slider == 1:
1384
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
1385
+ with open(save_sample_path, "wb") as file:
1386
+ file.write(decoded_data)
1387
+ if gradio_version_is_above_4:
1388
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
1389
+ else:
1390
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
1391
+ else:
1392
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
1393
+ with open(save_sample_path, "wb") as file:
1394
+ file.write(decoded_data)
1395
+ if gradio_version_is_above_4:
1396
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
1397
+ else:
1398
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
1399
+
1400
+
1401
+ def ui_eas(model_name, savedir_sample):
1402
+ controller = CogVideoX_Fun_Controller_EAS(model_name, savedir_sample)
1403
+
1404
+ with gr.Blocks(css=css) as demo:
1405
+ gr.Markdown(
1406
+ """
1407
+ # CogVideoX-Fun
1408
+
1409
+ A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
1410
+
1411
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
1412
+ """
1413
+ )
1414
+ with gr.Column(variant="panel"):
1415
+ gr.Markdown(
1416
+ """
1417
+ ### 1. Model checkpoints (模型路径).
1418
+ """
1419
+ )
1420
+ with gr.Row():
1421
+ diffusion_transformer_dropdown = gr.Dropdown(
1422
+ label="Pretrained Model Path",
1423
+ choices=[model_name],
1424
+ value=model_name,
1425
+ interactive=False,
1426
+ )
1427
+ with gr.Row():
1428
+ base_model_dropdown = gr.Dropdown(
1429
+ label="Select base Dreambooth model",
1430
+ choices=["none"],
1431
+ value="none",
1432
+ interactive=False,
1433
+ visible=False
1434
+ )
1435
+ with gr.Column(visible=False):
1436
+ gr.Markdown(
1437
+ """
1438
+ ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
1439
+ """
1440
+ )
1441
+ with gr.Row():
1442
+ lora_model_dropdown = gr.Dropdown(
1443
+ label="Select LoRA model",
1444
+ choices=["none"],
1445
+ value="none",
1446
+ interactive=True,
1447
+ )
1448
+
1449
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
1450
+
1451
+ with gr.Column(variant="panel"):
1452
+ gr.Markdown(
1453
+ """
1454
+ ### 2. Configs for Generation.
1455
+ """
1456
+ )
1457
+
1458
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1459
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
1460
+
1461
+ with gr.Row():
1462
+ with gr.Column():
1463
+ with gr.Row():
1464
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1465
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=50, step=1, interactive=False)
1466
+
1467
+ resize_method = gr.Radio(
1468
+ ["Generate by", "Resize according to Reference"],
1469
+ value="Generate by",
1470
+ show_label=False,
1471
+ )
1472
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
1473
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
1474
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
1475
+
1476
+ with gr.Group():
1477
+ generation_method = gr.Radio(
1478
+ ["Video Generation", "Image Generation"],
1479
+ value="Video Generation",
1480
+ show_label=False,
1481
+ visible=True,
1482
+ )
1483
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=5, maximum=49, step=4)
1484
+
1485
+ source_method = gr.Radio(
1486
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
1487
+ value="Text to Video (文本到视频)",
1488
+ show_label=False,
1489
+ )
1490
+ with gr.Column(visible = False) as image_to_video_col:
1491
+ start_image = gr.Image(label="The image at the beginning of the video", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
1492
+
1493
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1494
+ def select_template(evt: gr.SelectData):
1495
+ text = {
1496
+ "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1497
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1498
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1499
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1500
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1501
+ }[template_gallery_path[evt.index]]
1502
+ return template_gallery_path[evt.index], text
1503
+
1504
+ template_gallery = gr.Gallery(
1505
+ template_gallery_path,
1506
+ columns=5, rows=1,
1507
+ height=140,
1508
+ allow_preview=False,
1509
+ container=False,
1510
+ label="Template Examples",
1511
+ )
1512
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
1513
+
1514
+ with gr.Accordion("The image at the ending of the video (Optional)", open=False):
1515
+ end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
1516
+
1517
+ with gr.Column(visible = False) as video_to_video_col:
1518
+ with gr.Row():
1519
+ validation_video = gr.Video(
1520
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
1521
+ elem_id="v2v", sources="upload",
1522
+ )
1523
+ with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
1524
+ gr.Markdown(
1525
+ """
1526
+ - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
1527
+ - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
1528
+ """
1529
+ )
1530
+ validation_video_mask = gr.Image(
1531
+ label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
1532
+ show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
1533
+ )
1534
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
1535
+
1536
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
1537
+
1538
+ with gr.Row():
1539
+ seed_textbox = gr.Textbox(label="Seed", value=43)
1540
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
1541
+ seed_button.click(
1542
+ fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
1543
+ inputs=[],
1544
+ outputs=[seed_textbox]
1545
+ )
1546
+
1547
+ generate_button = gr.Button(value="Generate", variant='primary')
1548
+
1549
+ with gr.Column():
1550
+ result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
1551
+ result_video = gr.Video(label="Generated Animation", interactive=False)
1552
+ infer_progress = gr.Textbox(
1553
+ label="Generation Info",
1554
+ value="No task currently",
1555
+ interactive=False
1556
+ )
1557
+
1558
+ def upload_generation_method(generation_method):
1559
+ if generation_method == "Video Generation":
1560
+ return gr.update(visible=True, minimum=5, maximum=49, value=49, interactive=True)
1561
+ elif generation_method == "Image Generation":
1562
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1563
+ generation_method.change(
1564
+ upload_generation_method, generation_method, [length_slider]
1565
+ )
1566
+
1567
+ def upload_source_method(source_method):
1568
+ if source_method == "Text to Video (文本到视频)":
1569
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1570
+ elif source_method == "Image to Video (图片到视频)":
1571
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
1572
+ else:
1573
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
1574
+ source_method.change(
1575
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
1576
+ )
1577
+
1578
+ def upload_resize_method(resize_method):
1579
+ if resize_method == "Generate by":
1580
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1581
+ else:
1582
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1583
+ resize_method.change(
1584
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1585
+ )
1586
+
1587
+ generate_button.click(
1588
+ fn=controller.generate,
1589
+ inputs=[
1590
+ diffusion_transformer_dropdown,
1591
+ base_model_dropdown,
1592
+ lora_model_dropdown,
1593
+ lora_alpha_slider,
1594
+ prompt_textbox,
1595
+ negative_prompt_textbox,
1596
+ sampler_dropdown,
1597
+ sample_step_slider,
1598
+ resize_method,
1599
+ width_slider,
1600
+ height_slider,
1601
+ base_resolution,
1602
+ generation_method,
1603
+ length_slider,
1604
+ cfg_scale_slider,
1605
+ start_image,
1606
+ end_image,
1607
+ validation_video,
1608
+ validation_video_mask,
1609
+ denoise_strength,
1610
+ seed_textbox,
1611
+ ],
1612
+ outputs=[result_image, result_video, infer_progress]
1613
+ )
1614
+ return demo, controller
cogvideox/utils/__init__.py ADDED
File without changes
cogvideox/utils/lora_utils.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss
6
+
7
+ import hashlib
8
+ import math
9
+ import os
10
+ from collections import defaultdict
11
+ from io import BytesIO
12
+ from typing import List, Optional, Type, Union
13
+
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
18
+ from safetensors.torch import load_file
19
+ from transformers import T5EncoderModel
20
+
21
+
22
+ class LoRAModule(torch.nn.Module):
23
+ """
24
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ lora_name,
30
+ org_module: torch.nn.Module,
31
+ multiplier=1.0,
32
+ lora_dim=4,
33
+ alpha=1,
34
+ dropout=None,
35
+ rank_dropout=None,
36
+ module_dropout=None,
37
+ ):
38
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
39
+ super().__init__()
40
+ self.lora_name = lora_name
41
+
42
+ if org_module.__class__.__name__ == "Conv2d":
43
+ in_dim = org_module.in_channels
44
+ out_dim = org_module.out_channels
45
+ else:
46
+ in_dim = org_module.in_features
47
+ out_dim = org_module.out_features
48
+
49
+ self.lora_dim = lora_dim
50
+ if org_module.__class__.__name__ == "Conv2d":
51
+ kernel_size = org_module.kernel_size
52
+ stride = org_module.stride
53
+ padding = org_module.padding
54
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
55
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
56
+ else:
57
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
58
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
59
+
60
+ if type(alpha) == torch.Tensor:
61
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
62
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
63
+ self.scale = alpha / self.lora_dim
64
+ self.register_buffer("alpha", torch.tensor(alpha))
65
+
66
+ # same as microsoft's
67
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
68
+ torch.nn.init.zeros_(self.lora_up.weight)
69
+
70
+ self.multiplier = multiplier
71
+ self.org_module = org_module # remove in applying
72
+ self.dropout = dropout
73
+ self.rank_dropout = rank_dropout
74
+ self.module_dropout = module_dropout
75
+
76
+ def apply_to(self):
77
+ self.org_forward = self.org_module.forward
78
+ self.org_module.forward = self.forward
79
+ del self.org_module
80
+
81
+ def forward(self, x, *args, **kwargs):
82
+ weight_dtype = x.dtype
83
+ org_forwarded = self.org_forward(x)
84
+
85
+ # module dropout
86
+ if self.module_dropout is not None and self.training:
87
+ if torch.rand(1) < self.module_dropout:
88
+ return org_forwarded
89
+
90
+ lx = self.lora_down(x.to(self.lora_down.weight.dtype))
91
+
92
+ # normal dropout
93
+ if self.dropout is not None and self.training:
94
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
95
+
96
+ # rank dropout
97
+ if self.rank_dropout is not None and self.training:
98
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
99
+ if len(lx.size()) == 3:
100
+ mask = mask.unsqueeze(1) # for Text Encoder
101
+ elif len(lx.size()) == 4:
102
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
103
+ lx = lx * mask
104
+
105
+ # scaling for rank dropout: treat as if the rank is changed
106
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
107
+ else:
108
+ scale = self.scale
109
+
110
+ lx = self.lora_up(lx)
111
+
112
+ return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
113
+
114
+
115
+ def addnet_hash_legacy(b):
116
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
117
+ m = hashlib.sha256()
118
+
119
+ b.seek(0x100000)
120
+ m.update(b.read(0x10000))
121
+ return m.hexdigest()[0:8]
122
+
123
+
124
+ def addnet_hash_safetensors(b):
125
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
126
+ hash_sha256 = hashlib.sha256()
127
+ blksize = 1024 * 1024
128
+
129
+ b.seek(0)
130
+ header = b.read(8)
131
+ n = int.from_bytes(header, "little")
132
+
133
+ offset = n + 8
134
+ b.seek(offset)
135
+ for chunk in iter(lambda: b.read(blksize), b""):
136
+ hash_sha256.update(chunk)
137
+
138
+ return hash_sha256.hexdigest()
139
+
140
+
141
+ def precalculate_safetensors_hashes(tensors, metadata):
142
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
143
+ save time on indexing the model later."""
144
+
145
+ # Because writing user metadata to the file can change the result of
146
+ # sd_models.model_hash(), only retain the training metadata for purposes of
147
+ # calculating the hash, as they are meant to be immutable
148
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
149
+
150
+ bytes = safetensors.torch.save(tensors, metadata)
151
+ b = BytesIO(bytes)
152
+
153
+ model_hash = addnet_hash_safetensors(b)
154
+ legacy_hash = addnet_hash_legacy(b)
155
+ return model_hash, legacy_hash
156
+
157
+
158
+ class LoRANetwork(torch.nn.Module):
159
+ TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel"]
160
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"]
161
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
162
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
163
+ def __init__(
164
+ self,
165
+ text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
166
+ unet,
167
+ multiplier: float = 1.0,
168
+ lora_dim: int = 4,
169
+ alpha: float = 1,
170
+ dropout: Optional[float] = None,
171
+ module_class: Type[object] = LoRAModule,
172
+ add_lora_in_attn_temporal: bool = False,
173
+ varbose: Optional[bool] = False,
174
+ ) -> None:
175
+ super().__init__()
176
+ self.multiplier = multiplier
177
+
178
+ self.lora_dim = lora_dim
179
+ self.alpha = alpha
180
+ self.dropout = dropout
181
+
182
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
183
+ print(f"neuron dropout: p={self.dropout}")
184
+
185
+ # create module instances
186
+ def create_modules(
187
+ is_unet: bool,
188
+ root_module: torch.nn.Module,
189
+ target_replace_modules: List[torch.nn.Module],
190
+ ) -> List[LoRAModule]:
191
+ prefix = (
192
+ self.LORA_PREFIX_TRANSFORMER
193
+ if is_unet
194
+ else self.LORA_PREFIX_TEXT_ENCODER
195
+ )
196
+ loras = []
197
+ skipped = []
198
+ for name, module in root_module.named_modules():
199
+ if module.__class__.__name__ in target_replace_modules:
200
+ for child_name, child_module in module.named_modules():
201
+ is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
202
+ is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
203
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
204
+
205
+ if not add_lora_in_attn_temporal:
206
+ if "attn_temporal" in child_name:
207
+ continue
208
+
209
+ if is_linear or is_conv2d:
210
+ lora_name = prefix + "." + name + "." + child_name
211
+ lora_name = lora_name.replace(".", "_")
212
+
213
+ dim = None
214
+ alpha = None
215
+
216
+ if is_linear or is_conv2d_1x1:
217
+ dim = self.lora_dim
218
+ alpha = self.alpha
219
+
220
+ if dim is None or dim == 0:
221
+ if is_linear or is_conv2d_1x1:
222
+ skipped.append(lora_name)
223
+ continue
224
+
225
+ lora = module_class(
226
+ lora_name,
227
+ child_module,
228
+ self.multiplier,
229
+ dim,
230
+ alpha,
231
+ dropout=dropout,
232
+ )
233
+ loras.append(lora)
234
+ return loras, skipped
235
+
236
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
237
+
238
+ self.text_encoder_loras = []
239
+ skipped_te = []
240
+ for i, text_encoder in enumerate(text_encoders):
241
+ if text_encoder is not None:
242
+ text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
243
+ self.text_encoder_loras.extend(text_encoder_loras)
244
+ skipped_te += skipped
245
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
246
+
247
+ self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
248
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
249
+
250
+ # assertion
251
+ names = set()
252
+ for lora in self.text_encoder_loras + self.unet_loras:
253
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
254
+ names.add(lora.lora_name)
255
+
256
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
257
+ if apply_text_encoder:
258
+ print("enable LoRA for text encoder")
259
+ else:
260
+ self.text_encoder_loras = []
261
+
262
+ if apply_unet:
263
+ print("enable LoRA for U-Net")
264
+ else:
265
+ self.unet_loras = []
266
+
267
+ for lora in self.text_encoder_loras + self.unet_loras:
268
+ lora.apply_to()
269
+ self.add_module(lora.lora_name, lora)
270
+
271
+ def set_multiplier(self, multiplier):
272
+ self.multiplier = multiplier
273
+ for lora in self.text_encoder_loras + self.unet_loras:
274
+ lora.multiplier = self.multiplier
275
+
276
+ def load_weights(self, file):
277
+ if os.path.splitext(file)[1] == ".safetensors":
278
+ from safetensors.torch import load_file
279
+
280
+ weights_sd = load_file(file)
281
+ else:
282
+ weights_sd = torch.load(file, map_location="cpu")
283
+ info = self.load_state_dict(weights_sd, False)
284
+ return info
285
+
286
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
287
+ self.requires_grad_(True)
288
+ all_params = []
289
+
290
+ def enumerate_params(loras):
291
+ params = []
292
+ for lora in loras:
293
+ params.extend(lora.parameters())
294
+ return params
295
+
296
+ if self.text_encoder_loras:
297
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
298
+ if text_encoder_lr is not None:
299
+ param_data["lr"] = text_encoder_lr
300
+ all_params.append(param_data)
301
+
302
+ if self.unet_loras:
303
+ param_data = {"params": enumerate_params(self.unet_loras)}
304
+ if unet_lr is not None:
305
+ param_data["lr"] = unet_lr
306
+ all_params.append(param_data)
307
+
308
+ return all_params
309
+
310
+ def enable_gradient_checkpointing(self):
311
+ pass
312
+
313
+ def get_trainable_params(self):
314
+ return self.parameters()
315
+
316
+ def save_weights(self, file, dtype, metadata):
317
+ if metadata is not None and len(metadata) == 0:
318
+ metadata = None
319
+
320
+ state_dict = self.state_dict()
321
+
322
+ if dtype is not None:
323
+ for key in list(state_dict.keys()):
324
+ v = state_dict[key]
325
+ v = v.detach().clone().to("cpu").to(dtype)
326
+ state_dict[key] = v
327
+
328
+ if os.path.splitext(file)[1] == ".safetensors":
329
+ from safetensors.torch import save_file
330
+
331
+ # Precalculate model hashes to save time on indexing
332
+ if metadata is None:
333
+ metadata = {}
334
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
335
+ metadata["sshs_model_hash"] = model_hash
336
+ metadata["sshs_legacy_hash"] = legacy_hash
337
+
338
+ save_file(state_dict, file, metadata)
339
+ else:
340
+ torch.save(state_dict, file)
341
+
342
+ def create_network(
343
+ multiplier: float,
344
+ network_dim: Optional[int],
345
+ network_alpha: Optional[float],
346
+ text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
347
+ transformer,
348
+ neuron_dropout: Optional[float] = None,
349
+ add_lora_in_attn_temporal: bool = False,
350
+ **kwargs,
351
+ ):
352
+ if network_dim is None:
353
+ network_dim = 4 # default
354
+ if network_alpha is None:
355
+ network_alpha = 1.0
356
+
357
+ network = LoRANetwork(
358
+ text_encoder,
359
+ transformer,
360
+ multiplier=multiplier,
361
+ lora_dim=network_dim,
362
+ alpha=network_alpha,
363
+ dropout=neuron_dropout,
364
+ add_lora_in_attn_temporal=add_lora_in_attn_temporal,
365
+ varbose=True,
366
+ )
367
+ return network
368
+
369
+ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
370
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
371
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
372
+ if state_dict is None:
373
+ state_dict = load_file(lora_path, device=device)
374
+ else:
375
+ state_dict = state_dict
376
+ updates = defaultdict(dict)
377
+ for key, value in state_dict.items():
378
+ layer, elem = key.split('.', 1)
379
+ updates[layer][elem] = value
380
+
381
+ for layer, elems in updates.items():
382
+
383
+ if "lora_te" in layer:
384
+ if transformer_only:
385
+ continue
386
+ else:
387
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
388
+ curr_layer = pipeline.text_encoder
389
+ else:
390
+ layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
391
+ curr_layer = pipeline.transformer
392
+
393
+ temp_name = layer_infos.pop(0)
394
+ while len(layer_infos) > -1:
395
+ try:
396
+ curr_layer = curr_layer.__getattr__(temp_name)
397
+ if len(layer_infos) > 0:
398
+ temp_name = layer_infos.pop(0)
399
+ elif len(layer_infos) == 0:
400
+ break
401
+ except Exception:
402
+ if len(layer_infos) == 0:
403
+ print('Error loading layer')
404
+ if len(temp_name) > 0:
405
+ temp_name += "_" + layer_infos.pop(0)
406
+ else:
407
+ temp_name = layer_infos.pop(0)
408
+
409
+ weight_up = elems['lora_up.weight'].to(dtype)
410
+ weight_down = elems['lora_down.weight'].to(dtype)
411
+ if 'alpha' in elems.keys():
412
+ alpha = elems['alpha'].item() / weight_up.shape[1]
413
+ else:
414
+ alpha = 1.0
415
+
416
+ curr_layer.weight.data = curr_layer.weight.data.to(device)
417
+ if len(weight_up.shape) == 4:
418
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
419
+ weight_down.squeeze(3).squeeze(2)).unsqueeze(
420
+ 2).unsqueeze(3)
421
+ else:
422
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
423
+
424
+ return pipeline
425
+
426
+ # TODO: Refactor with merge_lora.
427
+ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
428
+ """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
429
+ LORA_PREFIX_UNET = "lora_unet"
430
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
431
+ state_dict = load_file(lora_path, device=device)
432
+
433
+ updates = defaultdict(dict)
434
+ for key, value in state_dict.items():
435
+ layer, elem = key.split('.', 1)
436
+ updates[layer][elem] = value
437
+
438
+ for layer, elems in updates.items():
439
+
440
+ if "lora_te" in layer:
441
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
442
+ curr_layer = pipeline.text_encoder
443
+ else:
444
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
445
+ curr_layer = pipeline.transformer
446
+
447
+ temp_name = layer_infos.pop(0)
448
+ while len(layer_infos) > -1:
449
+ try:
450
+ curr_layer = curr_layer.__getattr__(temp_name)
451
+ if len(layer_infos) > 0:
452
+ temp_name = layer_infos.pop(0)
453
+ elif len(layer_infos) == 0:
454
+ break
455
+ except Exception:
456
+ if len(layer_infos) == 0:
457
+ print('Error loading layer')
458
+ if len(temp_name) > 0:
459
+ temp_name += "_" + layer_infos.pop(0)
460
+ else:
461
+ temp_name = layer_infos.pop(0)
462
+
463
+ weight_up = elems['lora_up.weight'].to(dtype)
464
+ weight_down = elems['lora_down.weight'].to(dtype)
465
+ if 'alpha' in elems.keys():
466
+ alpha = elems['alpha'].item() / weight_up.shape[1]
467
+ else:
468
+ alpha = 1.0
469
+
470
+ curr_layer.weight.data = curr_layer.weight.data.to(device)
471
+ if len(weight_up.shape) == 4:
472
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
473
+ weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
474
+ else:
475
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
476
+
477
+ return pipeline
cogvideox/utils/utils.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import imageio
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ import cv2
8
+ from einops import rearrange
9
+ from PIL import Image
10
+
11
+ def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
12
+ target_pixels = int(base_resolution) * int(base_resolution)
13
+ original_width, original_height = Image.open(image).size
14
+ ratio = (target_pixels / (original_width * original_height)) ** 0.5
15
+ width_slider = round(original_width * ratio)
16
+ height_slider = round(original_height * ratio)
17
+ return height_slider, width_slider
18
+
19
+ def color_transfer(sc, dc):
20
+ """
21
+ Transfer color distribution from of sc, referred to dc.
22
+
23
+ Args:
24
+ sc (numpy.ndarray): input image to be transfered.
25
+ dc (numpy.ndarray): reference image
26
+
27
+ Returns:
28
+ numpy.ndarray: Transferred color distribution on the sc.
29
+ """
30
+
31
+ def get_mean_and_std(img):
32
+ x_mean, x_std = cv2.meanStdDev(img)
33
+ x_mean = np.hstack(np.around(x_mean, 2))
34
+ x_std = np.hstack(np.around(x_std, 2))
35
+ return x_mean, x_std
36
+
37
+ sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
38
+ s_mean, s_std = get_mean_and_std(sc)
39
+ dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
40
+ t_mean, t_std = get_mean_and_std(dc)
41
+ img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
42
+ np.putmask(img_n, img_n > 255, 255)
43
+ np.putmask(img_n, img_n < 0, 0)
44
+ dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
45
+ return dst
46
+
47
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
48
+ videos = rearrange(videos, "b c t h w -> t b c h w")
49
+ outputs = []
50
+ for x in videos:
51
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
52
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
53
+ if rescale:
54
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
55
+ x = (x * 255).numpy().astype(np.uint8)
56
+ outputs.append(Image.fromarray(x))
57
+
58
+ if color_transfer_post_process:
59
+ for i in range(1, len(outputs)):
60
+ outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
61
+
62
+ os.makedirs(os.path.dirname(path), exist_ok=True)
63
+ if imageio_backend:
64
+ if path.endswith("mp4"):
65
+ imageio.mimsave(path, outputs, fps=fps)
66
+ else:
67
+ imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
68
+ else:
69
+ if path.endswith("mp4"):
70
+ path = path.replace('.mp4', '.gif')
71
+ outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
72
+
73
+ def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
74
+ if validation_image_start is not None and validation_image_end is not None:
75
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
76
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
77
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
78
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
79
+ else:
80
+ image_start = clip_image = validation_image_start
81
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
82
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
83
+
84
+ if type(validation_image_end) is str and os.path.isfile(validation_image_end):
85
+ image_end = Image.open(validation_image_end).convert("RGB")
86
+ image_end = image_end.resize([sample_size[1], sample_size[0]])
87
+ else:
88
+ image_end = validation_image_end
89
+ image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
90
+
91
+ if type(image_start) is list:
92
+ clip_image = clip_image[0]
93
+ start_video = torch.cat(
94
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
95
+ dim=2
96
+ )
97
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
98
+ input_video[:, :, :len(image_start)] = start_video
99
+
100
+ input_video_mask = torch.zeros_like(input_video[:, :1])
101
+ input_video_mask[:, :, len(image_start):] = 255
102
+ else:
103
+ input_video = torch.tile(
104
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
105
+ [1, 1, video_length, 1, 1]
106
+ )
107
+ input_video_mask = torch.zeros_like(input_video[:, :1])
108
+ input_video_mask[:, :, 1:] = 255
109
+
110
+ if type(image_end) is list:
111
+ image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
112
+ end_video = torch.cat(
113
+ [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
114
+ dim=2
115
+ )
116
+ input_video[:, :, -len(end_video):] = end_video
117
+
118
+ input_video_mask[:, :, -len(image_end):] = 0
119
+ else:
120
+ image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
121
+ input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
122
+ input_video_mask[:, :, -1:] = 0
123
+
124
+ input_video = input_video / 255
125
+
126
+ elif validation_image_start is not None:
127
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
128
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
129
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
130
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
131
+ else:
132
+ image_start = clip_image = validation_image_start
133
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
134
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
135
+ image_end = None
136
+
137
+ if type(image_start) is list:
138
+ clip_image = clip_image[0]
139
+ start_video = torch.cat(
140
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
141
+ dim=2
142
+ )
143
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
144
+ input_video[:, :, :len(image_start)] = start_video
145
+ input_video = input_video / 255
146
+
147
+ input_video_mask = torch.zeros_like(input_video[:, :1])
148
+ input_video_mask[:, :, len(image_start):] = 255
149
+ else:
150
+ input_video = torch.tile(
151
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
152
+ [1, 1, video_length, 1, 1]
153
+ ) / 255
154
+ input_video_mask = torch.zeros_like(input_video[:, :1])
155
+ input_video_mask[:, :, 1:, ] = 255
156
+ else:
157
+ image_start = None
158
+ image_end = None
159
+ input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
160
+ input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
161
+ clip_image = None
162
+
163
+ del image_start
164
+ del image_end
165
+ gc.collect()
166
+
167
+ return input_video, input_video_mask, clip_image
168
+
169
+ def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None):
170
+ if isinstance(input_video_path, str):
171
+ cap = cv2.VideoCapture(input_video_path)
172
+ input_video = []
173
+
174
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
175
+ frame_skip = 1 if fps is None else int(original_fps // fps)
176
+
177
+ frame_count = 0
178
+
179
+ while True:
180
+ ret, frame = cap.read()
181
+ if not ret:
182
+ break
183
+
184
+ if frame_count % frame_skip == 0:
185
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
186
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
187
+
188
+ frame_count += 1
189
+
190
+ cap.release()
191
+ else:
192
+ input_video = input_video_path
193
+
194
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
195
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
196
+
197
+ if validation_video_mask is not None:
198
+ validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
199
+ input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
200
+
201
+ input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
202
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
203
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
204
+ else:
205
+ input_video_mask = torch.zeros_like(input_video[:, :1])
206
+ input_video_mask[:, :, :] = 255
207
+
208
+ return input_video, input_video_mask, None
cogvideox/video_caption/README.md ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Video Caption
2
+ English | [简体中文](./README_zh-CN.md)
3
+
4
+ The folder contains codes for dataset preprocessing (i.e., video splitting, filtering, and recaptioning), and beautiful prompt used by CogVideoX-Fun.
5
+ The entire process supports distributed parallel processing, capable of handling large-scale datasets.
6
+
7
+ Meanwhile, we are collaborating with [Data-Juicer](https://github.com/modelscope/data-juicer/blob/main/docs/DJ_SORA.md),
8
+ allowing you to easily perform video data processing on [Aliyun PAI-DLC](https://help.aliyun.com/zh/pai/user-guide/video-preprocessing/).
9
+
10
+ # Table of Content
11
+ - [Video Caption](#video-caption)
12
+ - [Table of Content](#table-of-content)
13
+ - [Quick Start](#quick-start)
14
+ - [Setup](#setup)
15
+ - [Data Preprocessing](#data-preprocessing)
16
+ - [Data Preparation](#data-preparation)
17
+ - [Video Splitting](#video-splitting)
18
+ - [Video Filtering](#video-filtering)
19
+ - [Video Recaptioning](#video-recaptioning)
20
+ - [Beautiful Prompt (For CogVideoX-Fun Inference)](#beautiful-prompt-for-cogvideox-inference)
21
+ - [Batched Inference](#batched-inference)
22
+ - [OpenAI Server](#openai-server)
23
+
24
+ ## Quick Start
25
+
26
+ ### Setup
27
+ AliyunDSW or Docker is recommended to setup the environment, please refer to [Quick Start](../../README.md#quick-start).
28
+ You can also refer to the image build process in the [Dockerfile](../../Dockerfile.ds) to configure the conda environment and other dependencies locally.
29
+
30
+ Since the video recaptioning depends on [llm-awq](https://github.com/mit-han-lab/llm-awq) for faster and memory efficient inference,
31
+ the minimum GPU requirment should be RTX 3060 or A2 (CUDA Compute Capability >= 8.0).
32
+
33
+ ```shell
34
+ # pull image
35
+ docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
36
+
37
+ # enter image
38
+ docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
39
+
40
+ # clone code
41
+ git clone https://github.com/aigc-apps/CogVideoX-Fun.git
42
+
43
+ # enter video_caption
44
+ cd CogVideoX-Fun/cogvideox/video_caption
45
+ ```
46
+
47
+ ### Data Preprocessing
48
+ #### Data Preparation
49
+ Place the downloaded videos into a folder under [datasets](./datasets/) (preferably without nested structures, as the video names are used as unique IDs in subsequent processes).
50
+ Taking Panda-70M as an example, the entire dataset directory structure is shown as follows:
51
+ ```
52
+ 📦 datasets/
53
+ ├── 📂 panda_70m/
54
+ │ ├── 📂 videos/
55
+ │ │ ├── 📂 data/
56
+ │ │ │ └── 📄 --C66yU3LjM_2.mp4
57
+ │ │ │ └── 📄 ...
58
+ ```
59
+
60
+ #### Video Splitting
61
+ CogVideoX-Fun utilizes [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) to identify scene changes within the video
62
+ and performs video splitting via FFmpeg based on certain threshold values to ensure consistency of the video clip.
63
+ Video clips shorter than 3 seconds will be discarded, and those longer than 10 seconds will be splitted recursively.
64
+
65
+ The entire workflow of video splitting is in the [stage_1_video_splitting.sh](./scripts/stage_1_video_splitting.sh).
66
+ After running
67
+ ```shell
68
+ sh scripts/stage_1_video_splitting.sh
69
+ ```
70
+ the video clips are obtained in `cogvideox/video_caption/datasets/panda_70m/videos_clips/data/`.
71
+
72
+ #### Video Filtering
73
+ Based on the videos obtained in the previous step, CogVideoX-Fun provides a simple yet effective pipeline to filter out high-quality videos for recaptioning.
74
+ The overall process is as follows:
75
+
76
+ - Aesthetic filtering: Filter out videos with poor content (blurry, dim, etc.) by calculating the average aesthetic score of uniformly sampled 4 frames via [aesthetic-predictor-v2-5](https://github.com/discus0434/aesthetic-predictor-v2-5).
77
+ - Text filtering: Use [EasyOCR](https://github.com/JaidedAI/EasyOCR) to calculate the text area proportion of the middle frame to filter out videos with a large area of text.
78
+ - Motion filtering: Calculate interframe optical flow differences to filter out videos that move too slowly or too quickly.
79
+
80
+ The entire workflow of video filtering is in the [stage_2_video_filtering.sh](./scripts/stage_2_video_filtering.sh).
81
+ After running
82
+ ```shell
83
+ sh scripts/stage_2_video_filtering.sh
84
+ ```
85
+ the aesthetic score, text score, and motion score of videos will be saved in the corresponding meta files in the folder `cogvideox/video_caption/datasets/panda_70m/videos_clips/`.
86
+
87
+ > [!NOTE]
88
+ > The computation of the aesthetic score depends on the [google/siglip-so400m-patch14-384 model](https://huggingface.co/google/siglip-so400m-patch14-384).
89
+ Please run `HF_ENDPOINT=https://hf-mirror.com sh scripts/stage_2_video_filtering.sh` if you cannot access to huggingface.com.
90
+
91
+
92
+ #### Video Recaptioning
93
+ After obtaining the aboved high-quality filtered videos, CogVideoX-Fun utilizes [VILA1.5](https://github.com/NVlabs/VILA) to perform video recaptioning.
94
+ Subsequently, the recaptioning results are rewritten by LLMs to better meet with the requirements of video generation tasks.
95
+ Finally, an advanced VideoCLIPXL model is developed to filter out video-caption pairs with poor alignment, resulting in the final training dataset.
96
+
97
+ Please download the video caption model from [VILA1.5](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e) of the appropriate size based on the GPU memory of your machine.
98
+ For A100 with 40G VRAM, you can download [VILA1.5-40b-AWQ](https://huggingface.co/Efficient-Large-Model/VILA1.5-40b-AWQ) by running
99
+ ```shell
100
+ # Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
101
+ huggingface-cli download Efficient-Large-Model/VILA1.5-40b-AWQ --local-dir-use-symlinks False --local-dir /PATH/TO/VILA_MODEL
102
+ ```
103
+
104
+ Optionally, you can prepare local LLMs to rewrite the recaption results.
105
+ For example, you can download [Meta-Llama-3-8B-Instruct](https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct) by running
106
+ ```shell
107
+ # Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
108
+ huggingface-cli download NousResearch/Meta-Llama-3-8B-Instruct --local-dir-use-symlinks False --local-dir /PATH/TO/REWRITE_MODEL
109
+ ```
110
+
111
+ The entire workflow of video recaption is in the [stage_3_video_recaptioning.sh](./scripts/stage_3_video_recaptioning.sh).
112
+ After running
113
+ ```shell
114
+ VILA_MODEL_PATH=/PATH/TO/VILA_MODEL REWRITE_MODEL_PATH=/PATH/TO/REWRITE_MODEL sh scripts/stage_3_video_recaptioning.sh
115
+ ```
116
+ the final train file is obtained in `cogvideox/video_caption/datasets/panda_70m/videos_clips/meta_train_info.json`.
117
+
118
+
119
+ ### Beautiful Prompt (For CogVideoX-Fun Inference)
120
+ Beautiful Prompt aims to rewrite and beautify the user-uploaded prompt via LLMs, mapping it to the style of CogVideoX-Fun's training captions,
121
+ making it more suitable as the inference prompt and thus improving the quality of the generated videos.
122
+ We support batched inference with local LLMs or OpenAI compatible server based on [vLLM](https://github.com/vllm-project/vllm) for beautiful prompt.
123
+
124
+ #### Batched Inference
125
+ 1. Prepare original prompts in a jsonl file `cogvideox/video_caption/datasets/original_prompt.jsonl` with the following format:
126
+ ```json
127
+ {"prompt": "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."}
128
+ {"prompt": "An underwater world with realistic fish and other creatures of the sea."}
129
+ {"prompt": "a monarch butterfly perched on a tree trunk in the forest."}
130
+ {"prompt": "a child in a room with a bottle of wine and a lamp."}
131
+ {"prompt": "two men in suits walking down a hallway."}
132
+ ```
133
+
134
+ 2. Then you can perform beautiful prompt by running
135
+ ```shell
136
+ # Meta-Llama-3-8B-Instruct is sufficient for this task.
137
+ # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
138
+
139
+ python caption_rewrite.py \
140
+ --video_metadata_path datasets/original_prompt.jsonl \
141
+ --caption_column "prompt" \
142
+ --batch_size 1 \
143
+ --model_name /path/to/your_llm \
144
+ --prompt prompt/beautiful_prompt.txt \
145
+ --prefix '"detailed description": ' \
146
+ --saved_path datasets/beautiful_prompt.jsonl \
147
+ --saved_freq 1
148
+ ```
149
+
150
+ #### OpenAI Server
151
+ + You can request OpenAI compatible server to perform beautiful prompt by running
152
+ ```shell
153
+ OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
154
+ --model "your_model_name" \
155
+ --prompt "your_prompt"
156
+ ```
157
+
158
+ + You can also deploy the OpenAI Compatible Server locally using vLLM. For example:
159
+ ```shell
160
+ # Meta-Llama-3-8B-Instruct is sufficient for this task.
161
+ # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
162
+
163
+ # deploy the OpenAI compatible server
164
+ python -m vllm.entrypoints.openai.api_server serve /path/to/your_llm --dtype auto --api-key "your_api_key"
165
+ ```
166
+
167
+ Then you can perform beautiful prompt by running
168
+ ```shell
169
+ python -m beautiful_prompt.py \
170
+ --model /path/to/your_llm \
171
+ --prompt "your_prompt" \
172
+ --base_url "http://localhost:8000/v1" \
173
+ --api_key "your_api_key"
174
+ ```
cogvideox/video_caption/README_zh-CN.md ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 数据预处理
2
+ [English](./README.md) | 简体中文
3
+
4
+ 该文件夹包含 CogVideoX-Fun 使用的数据集预处理(即视频切分、过滤和生成描述)和提示词美化的代码。整个过程支持分布式并行处理,能够处理大规模数据集。
5
+
6
+ 此外,我们和 [Data-Juicer](https://github.com/modelscope/data-juicer/blob/main/docs/DJ_SORA.md) 合作,能让你在 [Aliyun PAI-DLC](https://help.aliyun.com/zh/pai/user-guide/video-preprocessing/) 轻松进行视频数据的处理。
7
+
8
+ # 目录
9
+ - [数据预处理](#数据预处理)
10
+ - [目录](#目录)
11
+ - [快速开始](#快速开始)
12
+ - [安装](#安装)
13
+ - [数据集预处理](#数据集预处理)
14
+ - [数据准备](#数据准备)
15
+ - [视频切分](#视频切分)
16
+ - [视频过滤](#视频过滤)
17
+ - [视频描述](#视频描述)
18
+ - [提示词美化](#提示词美化)
19
+ - [批量推理](#批量推理)
20
+ - [OpenAI 服务器](#openai-服务器)
21
+
22
+
23
+ ## 快速开始
24
+ ### 安装
25
+ 推荐使用阿里云 DSW 和 Docker 来安装环境,请参考 [快速开始](../../README_zh-CN.md#1-云使用-aliyundswdocker). 你也可以参考 [Dockerfile](../../Dockerfile.ds) 中的镜像构建流程在本地安装对应的 conda 环境和其余依赖。
26
+
27
+ 为了提高推理速度和节省推理的显存,生成视频描述依赖于 [llm-awq](https://github.com/mit-han-lab/llm-awq)。因此,需要 RTX 3060 或者 A2 及以上的显卡 (CUDA Compute Capability >= 8.0)。
28
+
29
+ ```shell
30
+ # pull image
31
+ docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
32
+
33
+ # enter image
34
+ docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
35
+
36
+ # clone code
37
+ git clone https://github.com/aigc-apps/CogVideoX-Fun.git
38
+
39
+ # enter video_caption
40
+ cd CogVideoX-Fun/cogvideox/video_caption
41
+ ```
42
+
43
+ ### 数据集预处理
44
+ #### 数据准备
45
+ 将下载的视频准备到文件夹 [datasets](./datasets/)(最好不使用嵌套结构,因为视频名称在后续处理中用作唯一 ID)。以 Panda-70M 为例,完整的数据集目录结构如下所示:
46
+ ```
47
+ 📦 datasets/
48
+ ├── 📂 panda_70m/
49
+ │ ├── 📂 videos/
50
+ │ │ ├── 📂 data/
51
+ │ │ │ └── 📄 --C66yU3LjM_2.mp4
52
+ │ │ │ └── 📄 ...
53
+ ```
54
+
55
+ #### 视频切分
56
+ CogVideoX-Fun 使用 [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) 来识别视频中的场景变化
57
+ 并根据某些阈值通过 FFmpeg 执行视频分割,以确保视频片段的一致性。
58
+ 短于 3 秒的视频片段将被丢弃,长于 10 秒的视频片段将被递归切分。
59
+
60
+ 视频切分的完整流程在 [stage_1_video_splitting.sh](./scripts/stage_1_video_splitting.sh)。执行
61
+ ```shell
62
+ sh scripts/stage_1_video_splitting.sh
63
+ ```
64
+ 后,切分后的视频位于 `cogvideox/video_caption/datasets/panda_70m/videos_clips/data/`。
65
+
66
+ #### 视频过滤
67
+ 基于上一步获得的视频,CogVideoX-Fun 提供了一个简单而有效的流程来过滤出高质量的视频。总体流程如下:
68
+
69
+ - 美学过滤:通过 [aesthetic-predictor-v2-5](https://github.com/discus0434/aesthetic-predictor-v2-5) 计算均匀采样的 4 帧视频的平均美学分数,从而筛选出内容不佳(模糊、昏暗等)的视频。
70
+ - 文本过滤:使用 [EasyOCR](https://github.com/JaidedAI/EasyOCR) 计算中间帧的文本区域比例,过滤掉含有大面积文本的视频。
71
+ - 运动过滤:计算帧间光流差,过滤掉移动太慢或太快的视频。
72
+
73
+ 视频过滤的完整流程在 [stage_2_video_filtering.sh](./scripts/stage_2_video_filtering.sh)。执行
74
+ ```shell
75
+ sh scripts/stage_2_video_filtering.sh
76
+ ```
77
+ 后,视频的美学得分、文本得分和运动得分对应的元文件保存在 `cogvideox/video_caption/datasets/panda_70m/videos_clips/`。
78
+
79
+ > [!NOTE]
80
+ > 美学得分的计算依赖于 [google/siglip-so400m-patch14-384 model](https://huggingface.co/google/siglip-so400m-patch14-384).
81
+ 请执行 `HF_ENDPOINT=https://hf-mirror.com sh scripts/stage_2_video_filtering.sh` 如果你无法访问 huggingface.com.
82
+
83
+ #### 视频描述
84
+ 在获得上述高质量的过滤视频后,CogVideoX-Fun 利用 [VILA1.5](https://github.com/NVlabs/VILA) 来生成视频描述。随后,使用 LLMs 对生成的视频描述进行重写,以更好地满足视频生成任务的要求。最后,使用自研的 VideoCLIPXL 模型来过滤掉描述和视频内容不一致的数据,从而得到最终的训练数据集。
85
+
86
+ 请根据机器的显存从 [VILA1.5](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e) 下载合适大小的模型。对于 A100 40G,你可以执行下面的命令来下载 [VILA1.5-40b-AWQ](https://huggingface.co/Efficient-Large-Model/VILA1.5-40b-AWQ)
87
+ ```shell
88
+ # Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
89
+ huggingface-cli download Efficient-Large-Model/VILA1.5-40b-AWQ --local-dir-use-symlinks False --local-dir /PATH/TO/VILA_MODEL
90
+ ```
91
+
92
+ 你可以选择性地准备 LLMs 来改写上述视频描述的结果。例如,你执行下面的命令来下载 [Meta-Llama-3-8B-Instruct](https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct)
93
+ ```shell
94
+ # Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
95
+ huggingface-cli download NousResearch/Meta-Llama-3-8B-Instruct --local-dir-use-symlinks False --local-dir /PATH/TO/REWRITE_MODEL
96
+ ```
97
+
98
+ 视频描述的完整流程在 [stage_3_video_recaptioning.sh](./scripts/stage_3_video_recaptioning.sh).
99
+ 执行
100
+ ```shell
101
+ VILA_MODEL_PATH=/PATH/TO/VILA_MODEL REWRITE_MODEL_PATH=/PATH/TO/REWRITE_MODEL sh scripts/stage_3_video_recaptioning.sh
102
+ ```
103
+ 后,最后的训练文件会保存在 `cogvideox/video_caption/datasets/panda_70m/videos_clips/meta_train_info.json`。
104
+
105
+ ### 提示词美化
106
+ 提示词美化旨在通过 LLMs 重写和美化用户上传的提示,将其映射为 CogVideoX-Fun 训练所使用的视频描述风格、
107
+ 使其更适合用作推理提示词,从而提高生成视频的质量。
108
+
109
+ 基于 [vLLM](https://github.com/vllm-project/vllm),我们支持使用本地 LLM 进行批量推理或请求 OpenAI 服务器的方式,以进行提示词美化。
110
+
111
+ #### 批量推理
112
+ 1. 将原始的提示词以下面的格式准备在文件 `cogvideox/video_caption/datasets/original_prompt.jsonl` 中:
113
+ ```json
114
+ {"prompt": "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."}
115
+ {"prompt": "An underwater world with realistic fish and other creatures of the sea."}
116
+ {"prompt": "a monarch butterfly perched on a tree trunk in the forest."}
117
+ {"prompt": "a child in a room with a bottle of wine and a lamp."}
118
+ {"prompt": "two men in suits walking down a hallway."}
119
+ ```
120
+
121
+ 2. 随后你可以通过执行以下的命令进行提示词美化
122
+ ```shell
123
+ # Meta-Llama-3-8B-Instruct is sufficient for this task.
124
+ # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
125
+
126
+ python caption_rewrite.py \
127
+ --video_metadata_path datasets/original_prompt.jsonl \
128
+ --caption_column "prompt" \
129
+ --batch_size 1 \
130
+ --model_name /path/to/your_llm \
131
+ --prompt prompt/beautiful_prompt.txt \
132
+ --prefix '"detailed description": ' \
133
+ --saved_path datasets/beautiful_prompt.jsonl \
134
+ --saved_freq 1
135
+ ```
136
+
137
+ #### OpenAI 服务器
138
+ + 你可以通过请求 OpenAI 服务器的方式来进行提示词美化
139
+ ```shell
140
+ OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
141
+ --model "your_model_name" \
142
+ --prompt "your_prompt"
143
+ ```
144
+
145
+ + 你也可以执行以下命令,通过 vLLM 将本地 LLMs 部署成兼容 OpenAI 的服务器
146
+ ```shell
147
+ OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
148
+ --model "your_model_name" \
149
+ --prompt "your_prompt"
150
+ ```
151
+
152
+ 然后再执行下面的命令来进行提示词美化
153
+ ```shell
154
+ python -m beautiful_prompt.py \
155
+ --model /path/to/your_llm \
156
+ --prompt "your_prompt" \
157
+ --base_url "http://localhost:8000/v1" \
158
+ --api_key "your_api_key"
159
+ ```
cogvideox/video_caption/beautiful_prompt.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script (optional) can rewrite and beautify the user-uploaded prompt via LLMs, mapping it to the style of cogvideox's training captions,
3
+ making it more suitable as the inference prompt and thus improving the quality of the generated videos.
4
+
5
+ Usage:
6
+ + You can request OpenAI compatible server to perform beautiful prompt by running
7
+ ```shell
8
+ export OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
9
+ --model "your_model_name" \
10
+ --prompt "your_prompt"
11
+ ```
12
+ + You can also deploy the OpenAI Compatible Server locally using vLLM. For example:
13
+ ```shell
14
+ # Meta-Llama-3-8B-Instruct is sufficient for this task.
15
+ # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
16
+
17
+ # deploy the OpenAI compatible server
18
+ python -m vllm.entrypoints.openai.api_server serve /path/to/your_llm --dtype auto --api-key "your_api_key"
19
+ ```
20
+
21
+ Then you can perform beautiful prompt by running
22
+ ```shell
23
+ python -m beautiful_prompt.py \
24
+ --model /path/to/your_llm \
25
+ --prompt "your_prompt" \
26
+ --base_url "http://localhost:8000/v1" \
27
+ --api_key "your_api_key"
28
+ ```
29
+ """
30
+ import argparse
31
+ import os
32
+
33
+ from openai import OpenAI
34
+
35
+ from cogvideox.video_caption.caption_rewrite import extract_output
36
+
37
+
38
+ def parse_args():
39
+ parser = argparse.ArgumentParser(description="Beautiful prompt.")
40
+ parser.add_argument("--model", type=str, required=True, help="The OpenAI model or the path to your local LLM.")
41
+ parser.add_argument("--prompt", type=str, required=True, help="The user-uploaded prompt.")
42
+ parser.add_argument(
43
+ "--template",
44
+ type=str,
45
+ default="cogvideox/video_caption/prompt/beautiful_prompt.txt",
46
+ help="A string or a txt file contains the template for beautiful prompt."
47
+ )
48
+ parser.add_argument(
49
+ "--max_retry_nums",
50
+ type=int,
51
+ default=5,
52
+ help="Maximum number of retries to obtain an output that meets the JSON format."
53
+ )
54
+ parser.add_argument(
55
+ "--base_url",
56
+ type=str,
57
+ default=None,
58
+ help="OpenAI API server url. If it is None, the OPENAI_BASE_URL from the environment variables will be used.",
59
+ )
60
+ parser.add_argument(
61
+ "--api_key",
62
+ type=str,
63
+ default=None,
64
+ help="OpenAI API key. If it is None, the OPENAI_API_KEY from the environment variables will be used.",
65
+ )
66
+
67
+ args = parser.parse_args()
68
+ return args
69
+
70
+
71
+ def main():
72
+ args = parse_args()
73
+
74
+ client = OpenAI(
75
+ base_url=os.getenv("OPENAI_BASE_URL", args.base_url),
76
+ api_key=os.environ.get("OPENAI_API_KEY", args.api_key),
77
+ )
78
+ if args.template.endswith(".txt") and os.path.exists(args.template):
79
+ with open(args.template, "r") as f:
80
+ args.template = "".join(f.readlines())
81
+ # print(f"Beautiful prompt template: {args.template}")
82
+
83
+ for _ in range(args.max_retry_nums):
84
+ completion = client.chat.completions.create(
85
+ model=args.model,
86
+ messages=[
87
+ # {"role": "system", "content": "You are a helpful assistant."},
88
+ {"role": "user", "content": args.template + "\n" + str(args.prompt)}
89
+ ],
90
+ temperature=0.7,
91
+ top_p=1,
92
+ max_tokens=1024,
93
+ )
94
+
95
+ output = completion.choices[0].message.content
96
+ output = extract_output(output, prefix='"detailed description": ')
97
+ if output is not None:
98
+ break
99
+ print(f"Beautiful prompt: {output}")
100
+
101
+
102
+ if __name__ == "__main__":
103
+ main()
cogvideox/video_caption/caption_rewrite.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import re
3
+ import os
4
+ from tqdm import tqdm
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from natsort import index_natsorted
9
+ from vllm import LLM, SamplingParams
10
+ from transformers import AutoTokenizer
11
+
12
+ from utils.logger import logger
13
+
14
+
15
+ def extract_output(s, prefix='"rewritten description": '):
16
+ """Customize the function according to the prompt."""
17
+ # Since some LLMs struggles to output strictly formatted JSON strings as specified by the prompt,
18
+ # thus manually parse the output string `{"rewritten description": "your rewritten description here"}`.
19
+ match = re.search(r"{(.+?)}", s, re.DOTALL)
20
+ if not match:
21
+ logger.warning(f"{s} is not in the json format. Return None.")
22
+ return None
23
+ output = match.group(1).strip()
24
+ if output.startswith(prefix):
25
+ output = output[len(prefix) :]
26
+ if output[0] == '"' and output[-1] == '"':
27
+ return output[1:-1]
28
+ else:
29
+ logger.warning(f"{output} does not start and end with the double quote. Return None.")
30
+ return None
31
+ else:
32
+ logger.warning(f"{output} does not start with {prefix}. Return None.")
33
+ return None
34
+
35
+
36
+ def parse_args():
37
+ parser = argparse.ArgumentParser(description="Rewrite the video caption by LLMs.")
38
+ parser.add_argument(
39
+ "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
40
+ )
41
+ parser.add_argument(
42
+ "--video_path_column",
43
+ type=str,
44
+ default=None,
45
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
46
+ )
47
+ parser.add_argument(
48
+ "--caption_column",
49
+ type=str,
50
+ default="caption",
51
+ help="The column contains the video caption.",
52
+ )
53
+ parser.add_argument(
54
+ "--batch_size",
55
+ type=int,
56
+ default=128,
57
+ required=False,
58
+ help="The batch size for vllm inference. Adjust according to the number of GPUs to maximize inference throughput.",
59
+ )
60
+ parser.add_argument(
61
+ "--model_name",
62
+ type=str,
63
+ default="NousResearch/Meta-Llama-3-8B-Instruct",
64
+ )
65
+ parser.add_argument(
66
+ "--prompt",
67
+ type=str,
68
+ required=True,
69
+ help="A string or a txt file contains the prompt.",
70
+ )
71
+ parser.add_argument(
72
+ "--prefix",
73
+ type=str,
74
+ required=True,
75
+ help="The prefix to extract the output from LLMs.",
76
+ )
77
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
78
+ parser.add_argument("--saved_freq", type=int, default=1, help="The frequency to save the output results.")
79
+
80
+ args = parser.parse_args()
81
+ return args
82
+
83
+
84
+ def main():
85
+ args = parse_args()
86
+
87
+ if args.video_metadata_path.endswith(".csv"):
88
+ video_metadata_df = pd.read_csv(args.video_metadata_path)
89
+ elif args.video_metadata_path.endswith(".jsonl"):
90
+ video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
91
+ elif args.video_metadata_path.endswith(".json"):
92
+ video_metadata_df = pd.read_json(args.video_metadata_path)
93
+ else:
94
+ raise ValueError(f"The {args.video_metadata_path} must end with .csv, .jsonl or .json.")
95
+
96
+ saved_suffix = os.path.splitext(args.saved_path)[1]
97
+ if saved_suffix not in set([".csv", ".jsonl", ".json"]):
98
+ raise ValueError(f"The saved_path must end with .csv, .jsonl or .json.")
99
+
100
+ if os.path.exists(args.saved_path) and args.video_path_column is not None:
101
+ if args.saved_path.endswith(".csv"):
102
+ saved_metadata_df = pd.read_csv(args.saved_path)
103
+ elif args.saved_path.endswith(".jsonl"):
104
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
105
+
106
+ # Filter out the unprocessed video-caption pairs by setting the indicator=True.
107
+ merged_df = video_metadata_df.merge(saved_metadata_df, on=args.video_path_column, how="outer", indicator=True)
108
+ video_metadata_df = merged_df[merged_df["_merge"] == "left_only"]
109
+ # Sorting to guarantee the same result for each process.
110
+ video_metadata_df = video_metadata_df.iloc[index_natsorted(video_metadata_df[args.video_path_column])].reset_index(
111
+ drop=True
112
+ )
113
+ logger.info(
114
+ f"Resume from {args.saved_path}: {len(saved_metadata_df)} processed and {len(video_metadata_df)} to be processed."
115
+ )
116
+
117
+ if args.prompt.endswith(".txt") and os.path.exists(args.prompt):
118
+ with open(args.prompt, "r") as f:
119
+ args.prompt = "".join(f.readlines())
120
+ logger.info(f"Prompt: {args.prompt}")
121
+
122
+ if args.video_path_column is not None:
123
+ video_path_list = video_metadata_df[args.video_path_column].tolist()
124
+ if args.caption_column in video_metadata_df.columns:
125
+ sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()
126
+ else:
127
+ # When two columns with the same name, the dataframe merge operation on will distinguish them by adding 'x' and 'y'.
128
+ sampled_frame_caption_list = video_metadata_df[args.caption_column + "_x"].tolist()
129
+
130
+ CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES", None)
131
+ tensor_parallel_size = torch.cuda.device_count() if CUDA_VISIBLE_DEVICES is None else len(CUDA_VISIBLE_DEVICES.split(","))
132
+ logger.info(f"Automatically set tensor_parallel_size={tensor_parallel_size} based on the available devices.")
133
+
134
+ llm = LLM(model=args.model_name, trust_remote_code=True, tensor_parallel_size=tensor_parallel_size)
135
+ if "Meta-Llama-3" in args.model_name:
136
+ if "Meta-Llama-3-70B" in args.model_name:
137
+ # Llama-3-70B should use the tokenizer from Llama-3-8B
138
+ # https://github.com/vllm-project/vllm/issues/4180#issuecomment-2068292942
139
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
140
+ else:
141
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
142
+ stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
143
+ sampling_params = SamplingParams(temperature=0.7, top_p=1, max_tokens=1024, stop_token_ids=stop_token_ids)
144
+ else:
145
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
146
+ sampling_params = SamplingParams(temperature=0.7, top_p=1, max_tokens=1024)
147
+
148
+ result_dict = {args.caption_column: []}
149
+ if args.video_path_column is not None:
150
+ result_dict = {args.video_path_column: [], args.caption_column: []}
151
+
152
+ for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)):
153
+ if args.video_path_column is not None:
154
+ batch_video_path = video_path_list[i : i + args.batch_size]
155
+ batch_caption = sampled_frame_caption_list[i : i + args.batch_size]
156
+ batch_prompt = []
157
+ for caption in batch_caption:
158
+ # batch_prompt.append("user:" + args.prompt + str(caption) + "\n assistant:")
159
+ messages = [
160
+ {"role": "system", "content": "You are a helpful assistant."},
161
+ {"role": "user", "content": args.prompt + "\n" + str(caption)},
162
+ ]
163
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
164
+ batch_prompt.append(text)
165
+
166
+ batch_output = llm.generate(batch_prompt, sampling_params)
167
+ batch_output = [output.outputs[0].text.rstrip() for output in batch_output]
168
+ batch_output = [extract_output(output, prefix=args.prefix) for output in batch_output]
169
+
170
+ # Filter out data that does not meet the output format.
171
+ batch_result = []
172
+ if args.video_path_column is not None:
173
+ for video_path, output in zip(batch_video_path, batch_output):
174
+ if output is not None:
175
+ batch_result.append((video_path, output))
176
+ batch_video_path, batch_output = zip(*batch_result)
177
+
178
+ result_dict[args.video_path_column].extend(batch_video_path)
179
+ else:
180
+ for output in batch_output:
181
+ if output is not None:
182
+ batch_result.append(output)
183
+
184
+ result_dict[args.caption_column].extend(batch_result)
185
+
186
+ # Save the metadata every args.saved_freq.
187
+ if i != 0 and ((i // args.batch_size) % args.saved_freq) == 0:
188
+ if len(result_dict[args.caption_column]) > 0:
189
+ result_df = pd.DataFrame(result_dict)
190
+ if args.saved_path.endswith(".csv"):
191
+ header = True if not os.path.exists(args.saved_path) else False
192
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
193
+ elif args.saved_path.endswith(".jsonl"):
194
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
195
+ elif args.saved_path.endswith(".json"):
196
+ # Append is not supported.
197
+ if os.path.exists(args.saved_path):
198
+ saved_df = pd.read_json(args.saved_path, orient="records")
199
+ result_df = pd.concat([saved_df, result_df], ignore_index=True)
200
+ result_df.to_json(args.saved_path, orient="records", indent=4, force_ascii=False)
201
+ logger.info(f"Save result to {args.saved_path}.")
202
+
203
+ result_dict = {args.caption_column: []}
204
+ if args.video_path_column is not None:
205
+ result_dict = {args.video_path_column: [], args.caption_column: []}
206
+
207
+ if len(result_dict[args.caption_column]) > 0:
208
+ result_df = pd.DataFrame(result_dict)
209
+ if args.saved_path.endswith(".csv"):
210
+ header = True if not os.path.exists(args.saved_path) else False
211
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
212
+ elif args.saved_path.endswith(".jsonl"):
213
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
214
+ elif args.saved_path.endswith(".json"):
215
+ # Append is not supported.
216
+ if os.path.exists(args.saved_path):
217
+ saved_df = pd.read_json(args.saved_path, orient="records")
218
+ result_df = pd.concat([saved_df, result_df], ignore_index=True)
219
+ result_df.to_json(args.saved_path, orient="records", indent=4, force_ascii=False)
220
+ logger.info(f"Save the final result to {args.saved_path}.")
221
+
222
+
223
+ if __name__ == "__main__":
224
+ main()
cogvideox/video_caption/compute_motion_score.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import argparse
3
+ import gc
4
+ import os
5
+ from contextlib import contextmanager
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import pandas as pd
11
+ from joblib import Parallel, delayed
12
+ from natsort import natsorted
13
+ from tqdm import tqdm
14
+
15
+ from utils.logger import logger
16
+ from utils.filter import filter
17
+
18
+
19
+ @contextmanager
20
+ def VideoCapture(video_path):
21
+ cap = cv2.VideoCapture(video_path)
22
+ try:
23
+ yield cap
24
+ finally:
25
+ cap.release()
26
+ del cap
27
+ gc.collect()
28
+
29
+
30
+ def compute_motion_score(video_path):
31
+ video_motion_scores = []
32
+ sampling_fps = 2
33
+
34
+ try:
35
+ with VideoCapture(video_path) as cap:
36
+ fps = cap.get(cv2.CAP_PROP_FPS)
37
+ valid_fps = min(max(sampling_fps, 1), fps)
38
+ frame_interval = int(fps / valid_fps)
39
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
40
+
41
+ # if cannot get the second frame, use the last one
42
+ frame_interval = min(frame_interval, total_frames - 1)
43
+
44
+ prev_frame = None
45
+ frame_count = -1
46
+ while cap.isOpened():
47
+ ret, frame = cap.read()
48
+ frame_count += 1
49
+
50
+ if not ret:
51
+ break
52
+
53
+ # skip middle frames
54
+ if frame_count % frame_interval != 0:
55
+ continue
56
+
57
+ gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
58
+ if prev_frame is None:
59
+ prev_frame = gray_frame
60
+ continue
61
+
62
+ flow = cv2.calcOpticalFlowFarneback(
63
+ prev_frame,
64
+ gray_frame,
65
+ None,
66
+ pyr_scale=0.5,
67
+ levels=3,
68
+ winsize=15,
69
+ iterations=3,
70
+ poly_n=5,
71
+ poly_sigma=1.2,
72
+ flags=0,
73
+ )
74
+ mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
75
+ frame_motion_score = np.mean(mag)
76
+ video_motion_scores.append(frame_motion_score)
77
+ prev_frame = gray_frame
78
+
79
+ video_meta_info = {
80
+ "video_path": Path(video_path).name,
81
+ "motion_score": round(float(np.mean(video_motion_scores)), 5),
82
+ }
83
+ return video_meta_info
84
+
85
+ except Exception as e:
86
+ print(f"Compute motion score for video {video_path} with error: {e}.")
87
+
88
+
89
+ def parse_args():
90
+ parser = argparse.ArgumentParser(description="Compute the motion score of the videos.")
91
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
92
+ parser.add_argument(
93
+ "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
94
+ )
95
+ parser.add_argument(
96
+ "--video_path_column",
97
+ type=str,
98
+ default="video_path",
99
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
100
+ )
101
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
102
+ parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.")
103
+ parser.add_argument("--n_jobs", type=int, default=1, help="The number of concurrent processes.")
104
+
105
+ parser.add_argument(
106
+ "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
107
+ )
108
+ parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.")
109
+ parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
110
+ parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
111
+ parser.add_argument(
112
+ "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
113
+ )
114
+ parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The asethetic score threshold.")
115
+ parser.add_argument(
116
+ "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
117
+ )
118
+ parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The asethetic score (SigLIP) threshold.")
119
+ parser.add_argument(
120
+ "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
121
+ )
122
+ parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.")
123
+
124
+ args = parser.parse_args()
125
+ return args
126
+
127
+
128
+ def main():
129
+ args = parse_args()
130
+
131
+ if args.video_metadata_path.endswith(".csv"):
132
+ video_metadata_df = pd.read_csv(args.video_metadata_path)
133
+ elif args.video_metadata_path.endswith(".jsonl"):
134
+ video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
135
+ else:
136
+ raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
137
+ video_path_list = video_metadata_df[args.video_path_column].tolist()
138
+
139
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
140
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
141
+
142
+ if os.path.exists(args.saved_path):
143
+ if args.saved_path.endswith(".csv"):
144
+ saved_metadata_df = pd.read_csv(args.saved_path)
145
+ elif args.saved_path.endswith(".jsonl"):
146
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
147
+ saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
148
+ video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
149
+ logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
150
+
151
+ video_path_list = filter(
152
+ video_path_list,
153
+ basic_metadata_path=args.basic_metadata_path,
154
+ min_resolution=args.min_resolution,
155
+ min_duration=args.min_duration,
156
+ max_duration=args.max_duration,
157
+ asethetic_score_metadata_path=args.asethetic_score_metadata_path,
158
+ min_asethetic_score=args.min_asethetic_score,
159
+ asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
160
+ min_asethetic_score_siglip=args.min_asethetic_score_siglip,
161
+ text_score_metadata_path=args.text_score_metadata_path,
162
+ min_text_score=args.min_text_score,
163
+ )
164
+ video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
165
+ # Sorting to guarantee the same result for each process.
166
+ video_path_list = natsorted(video_path_list)
167
+
168
+ for i in tqdm(range(0, len(video_path_list), args.saved_freq)):
169
+ result_list = Parallel(n_jobs=args.n_jobs)(
170
+ delayed(compute_motion_score)(video_path) for video_path in tqdm(video_path_list[i: i + args.saved_freq])
171
+ )
172
+ result_list = [result for result in result_list if result is not None]
173
+ if len(result_list) == 0:
174
+ continue
175
+
176
+ result_df = pd.DataFrame(result_list)
177
+ if args.saved_path.endswith(".csv"):
178
+ header = False if os.path.exists(args.saved_path) else True
179
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
180
+ elif args.saved_path.endswith(".jsonl"):
181
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
182
+ logger.info(f"Save result to {args.saved_path}.")
183
+
184
+
185
+ if __name__ == "__main__":
186
+ main()
cogvideox/video_caption/compute_text_score.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import easyocr
6
+ import numpy as np
7
+ import pandas as pd
8
+ from accelerate import PartialState
9
+ from accelerate.utils import gather_object
10
+ from natsort import natsorted
11
+ from tqdm import tqdm
12
+ from torchvision.datasets.utils import download_url
13
+
14
+ from utils.logger import logger
15
+ from utils.video_utils import extract_frames
16
+ from utils.filter import filter
17
+
18
+
19
+ def init_ocr_reader(root: str = "~/.cache/easyocr", device: str = "gpu"):
20
+ root = os.path.expanduser(root)
21
+ if not os.path.exists(root):
22
+ os.makedirs(root)
23
+ download_url(
24
+ "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/easyocr/craft_mlt_25k.pth",
25
+ root,
26
+ filename="craft_mlt_25k.pth",
27
+ md5="2f8227d2def4037cdb3b34389dcf9ec1",
28
+ )
29
+ ocr_reader = easyocr.Reader(
30
+ lang_list=["en", "ch_sim"],
31
+ gpu=device,
32
+ recognizer=False,
33
+ verbose=False,
34
+ model_storage_directory=root,
35
+ )
36
+
37
+ return ocr_reader
38
+
39
+
40
+ def triangle_area(p1, p2, p3):
41
+ """Compute the triangle area according to its coordinates.
42
+ """
43
+ x1, y1 = p1
44
+ x2, y2 = p2
45
+ x3, y3 = p3
46
+ tri_area = 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3)
47
+ return tri_area
48
+
49
+
50
+ def compute_text_score(video_path, ocr_reader):
51
+ _, images = extract_frames(video_path, sample_method="mid")
52
+ images = [np.array(image) for image in images]
53
+
54
+ frame_ocr_area_ratios = []
55
+ for image in images:
56
+ # horizontal detected results and free-form detected
57
+ horizontal_list, free_list = ocr_reader.detect(np.asarray(image))
58
+ width, height = image.shape[0], image.shape[1]
59
+
60
+ total_area = width * height
61
+ # rectangles
62
+ rect_area = 0
63
+ for xmin, xmax, ymin, ymax in horizontal_list[0]:
64
+ if xmax < xmin or ymax < ymin:
65
+ continue
66
+ rect_area += (xmax - xmin) * (ymax - ymin)
67
+ # free-form
68
+ quad_area = 0
69
+ try:
70
+ for points in free_list[0]:
71
+ triangle1 = points[:3]
72
+ quad_area += triangle_area(*triangle1)
73
+ triangle2 = points[3:] + [points[0]]
74
+ quad_area += triangle_area(*triangle2)
75
+ except:
76
+ quad_area = 0
77
+ text_area = rect_area + quad_area
78
+
79
+ frame_ocr_area_ratios.append(text_area / total_area)
80
+
81
+ video_meta_info = {
82
+ "video_path": Path(video_path).name,
83
+ "text_score": round(np.mean(frame_ocr_area_ratios), 5),
84
+ }
85
+
86
+ return video_meta_info
87
+
88
+
89
+ def parse_args():
90
+ parser = argparse.ArgumentParser(description="Compute the text score of the middle frame in the videos.")
91
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
92
+ parser.add_argument(
93
+ "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
94
+ )
95
+ parser.add_argument(
96
+ "--video_path_column",
97
+ type=str,
98
+ default="video_path",
99
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
100
+ )
101
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
102
+ parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.")
103
+
104
+ parser.add_argument(
105
+ "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
106
+ )
107
+ parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.")
108
+ parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
109
+ parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
110
+ parser.add_argument(
111
+ "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
112
+ )
113
+ parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The asethetic score threshold.")
114
+ parser.add_argument(
115
+ "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
116
+ )
117
+ parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The asethetic score (SigLIP) threshold.")
118
+ parser.add_argument(
119
+ "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)."
120
+ )
121
+ parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.")
122
+
123
+ args = parser.parse_args()
124
+ return args
125
+
126
+
127
+ def main():
128
+ args = parse_args()
129
+
130
+ if args.video_metadata_path.endswith(".csv"):
131
+ video_metadata_df = pd.read_csv(args.video_metadata_path)
132
+ elif args.video_metadata_path.endswith(".jsonl"):
133
+ video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
134
+ else:
135
+ raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
136
+ video_path_list = video_metadata_df[args.video_path_column].tolist()
137
+
138
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
139
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
140
+
141
+ if os.path.exists(args.saved_path):
142
+ if args.saved_path.endswith(".csv"):
143
+ saved_metadata_df = pd.read_csv(args.saved_path)
144
+ elif args.saved_path.endswith(".jsonl"):
145
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
146
+ saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
147
+ video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
148
+ logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
149
+
150
+ video_path_list = filter(
151
+ video_path_list,
152
+ basic_metadata_path=args.basic_metadata_path,
153
+ min_resolution=args.min_resolution,
154
+ min_duration=args.min_duration,
155
+ max_duration=args.max_duration,
156
+ asethetic_score_metadata_path=args.asethetic_score_metadata_path,
157
+ min_asethetic_score=args.min_asethetic_score,
158
+ asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
159
+ min_asethetic_score_siglip=args.min_asethetic_score_siglip,
160
+ motion_score_metadata_path=args.motion_score_metadata_path,
161
+ min_motion_score=args.min_motion_score,
162
+ )
163
+ video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
164
+ # Sorting to guarantee the same result for each process.
165
+ video_path_list = natsorted(video_path_list)
166
+
167
+ state = PartialState()
168
+ if state.is_main_process:
169
+ # Check if the model is downloaded in the main process.
170
+ ocr_reader = init_ocr_reader(device="cpu")
171
+ state.wait_for_everyone()
172
+ ocr_reader = init_ocr_reader(device=state.device)
173
+
174
+ index = len(video_path_list) - len(video_path_list) % state.num_processes
175
+ # Avoid the NCCL timeout in the final gather operation.
176
+ logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to ensure each process handles the same number of videos.")
177
+ video_path_list = video_path_list[:index]
178
+ logger.info(f"{len(video_path_list)} videos are to be processed.")
179
+
180
+ result_list = []
181
+ with state.split_between_processes(video_path_list) as splitted_video_path_list:
182
+ for i, video_path in enumerate(tqdm(splitted_video_path_list)):
183
+ try:
184
+ video_meta_info = compute_text_score(video_path, ocr_reader)
185
+ result_list.append(video_meta_info)
186
+ except Exception as e:
187
+ logger.warning(f"Compute text score for video {video_path} with error: {e}.")
188
+ if i != 0 and i % args.saved_freq == 0:
189
+ state.wait_for_everyone()
190
+ gathered_result_list = gather_object(result_list)
191
+ if state.is_main_process and len(gathered_result_list) != 0:
192
+ result_df = pd.DataFrame(gathered_result_list)
193
+ if args.saved_path.endswith(".csv"):
194
+ header = False if os.path.exists(args.saved_path) else True
195
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
196
+ elif args.saved_path.endswith(".jsonl"):
197
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
198
+ logger.info(f"Save result to {args.saved_path}.")
199
+ result_list = []
200
+
201
+ state.wait_for_everyone()
202
+ gathered_result_list = gather_object(result_list)
203
+ if state.is_main_process and len(gathered_result_list) != 0:
204
+ result_df = pd.DataFrame(gathered_result_list)
205
+ if args.saved_path.endswith(".csv"):
206
+ header = False if os.path.exists(args.saved_path) else True
207
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
208
+ elif args.saved_path.endswith(".jsonl"):
209
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
210
+ logger.info(f"Save the final result to {args.saved_path}.")
211
+
212
+
213
+ if __name__ == "__main__":
214
+ main()
cogvideox/video_caption/compute_video_quality.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import pandas as pd
5
+ from accelerate import PartialState
6
+ from accelerate.utils import gather_object
7
+ from natsort import index_natsorted
8
+ from tqdm import tqdm
9
+ from torch.utils.data import DataLoader
10
+
11
+ import utils.image_evaluator as image_evaluator
12
+ import utils.video_evaluator as video_evaluator
13
+ from utils.logger import logger
14
+ from utils.video_dataset import VideoDataset, collate_fn
15
+
16
+
17
+ def parse_args():
18
+ parser = argparse.ArgumentParser(description="Compute scores of uniform sampled frames from videos.")
19
+ parser.add_argument(
20
+ "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
21
+ )
22
+ parser.add_argument(
23
+ "--video_path_column",
24
+ type=str,
25
+ default="video_path",
26
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
27
+ )
28
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
29
+ parser.add_argument(
30
+ "--caption_column",
31
+ type=str,
32
+ default=None,
33
+ help="The column contains the caption.",
34
+ )
35
+ parser.add_argument(
36
+ "--frame_sample_method",
37
+ type=str,
38
+ choices=["mid", "uniform", "image"],
39
+ default="uniform",
40
+ )
41
+ parser.add_argument(
42
+ "--num_sampled_frames",
43
+ type=int,
44
+ default=8,
45
+ help="num_sampled_frames",
46
+ )
47
+ parser.add_argument("--metrics", nargs="+", type=str, required=True, help="The evaluation metric(s) for generated images.")
48
+ parser.add_argument(
49
+ "--batch_size",
50
+ type=int,
51
+ default=10,
52
+ required=False,
53
+ help="The batch size for the video dataset.",
54
+ )
55
+ parser.add_argument(
56
+ "--num_workers",
57
+ type=int,
58
+ default=4,
59
+ required=False,
60
+ help="The number of workers for the video dataset.",
61
+ )
62
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
63
+ parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency to save the output results.")
64
+
65
+ args = parser.parse_args()
66
+ return args
67
+
68
+
69
+ def main():
70
+ args = parse_args()
71
+
72
+ if args.video_metadata_path.endswith(".csv"):
73
+ video_metadata_df = pd.read_csv(args.video_metadata_path)
74
+ elif args.video_metadata_path.endswith(".jsonl"):
75
+ video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
76
+ else:
77
+ raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
78
+
79
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
80
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
81
+
82
+ if os.path.exists(args.saved_path):
83
+ if args.saved_path.endswith(".csv"):
84
+ saved_metadata_df = pd.read_csv(args.saved_path)
85
+ elif args.saved_path.endswith(".jsonl"):
86
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
87
+
88
+ # Filter out the unprocessed video-caption pairs by setting the indicator=True.
89
+ merged_df = video_metadata_df.merge(saved_metadata_df, on="video_path", how="outer", indicator=True)
90
+ video_metadata_df = merged_df[merged_df["_merge"] == "left_only"]
91
+ # Sorting to guarantee the same result for each process.
92
+ video_metadata_df = video_metadata_df.iloc[index_natsorted(video_metadata_df["video_path"])].reset_index(drop=True)
93
+ if args.caption_column is None:
94
+ video_metadata_df = video_metadata_df[[args.video_path_column]]
95
+ else:
96
+ video_metadata_df = video_metadata_df[[args.video_path_column, args.caption_column + "_x"]]
97
+ video_metadata_df.rename(columns={args.caption_column + "_x": args.caption_column}, inplace=True)
98
+ logger.info(f"Resume from {args.saved_path}: {len(saved_metadata_df)} processed and {len(video_metadata_df)} to be processed.")
99
+
100
+ state = PartialState()
101
+ metric_fns = []
102
+ for metric in args.metrics:
103
+ if hasattr(image_evaluator, metric): # frame-wise
104
+ if state.is_main_process:
105
+ logger.info("Initializing frame-wise evaluator metrics...")
106
+ # Check if the model is downloaded in the main process.
107
+ getattr(image_evaluator, metric)(device="cpu")
108
+ state.wait_for_everyone()
109
+ metric_fns.append(getattr(image_evaluator, metric)(device=state.device))
110
+ else: # video-wise
111
+ if state.is_main_process:
112
+ logger.info("Initializing video-wise evaluator metrics...")
113
+ # Check if the model is downloaded in the main process.
114
+ getattr(video_evaluator, metric)(device="cpu")
115
+ state.wait_for_everyone()
116
+ metric_fns.append(getattr(video_evaluator, metric)(device=state.device))
117
+
118
+ result_dict = {args.video_path_column: [], "sample_frame_idx": []}
119
+ for metric in metric_fns:
120
+ result_dict[str(metric)] = []
121
+ if args.caption_column is not None:
122
+ result_dict[args.caption_column] = []
123
+
124
+ if args.frame_sample_method == "image":
125
+ logger.warning("Set args.num_sampled_frames to 1 since args.frame_sample_method is image.")
126
+ args.num_sampled_frames = 1
127
+
128
+ index = len(video_metadata_df) - len(video_metadata_df) % state.num_processes
129
+ # Avoid the NCCL timeout in the final gather operation.
130
+ logger.info(f"Drop {len(video_metadata_df) % state.num_processes} videos to ensure each process handles the same number of videos.")
131
+ video_metadata_df = video_metadata_df.iloc[:index]
132
+ logger.info(f"{len(video_metadata_df)} videos are to be processed.")
133
+
134
+ video_metadata_list = video_metadata_df.to_dict(orient='list')
135
+ with state.split_between_processes(video_metadata_list) as splitted_video_metadata:
136
+ video_dataset = VideoDataset(
137
+ dataset_inputs=splitted_video_metadata,
138
+ video_folder=args.video_folder,
139
+ text_column=args.caption_column,
140
+ sample_method=args.frame_sample_method,
141
+ num_sampled_frames=args.num_sampled_frames
142
+ )
143
+ video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
144
+
145
+ for idx, batch in enumerate(tqdm(video_loader)):
146
+ if len(batch) > 0:
147
+ batch_video_path = batch["path"]
148
+ result_dict["sample_frame_idx"].extend(batch["sampled_frame_idx"])
149
+ batch_frame = batch["sampled_frame"] # [batch_size, num_sampled_frames, H, W, C]
150
+ batch_caption = None
151
+ if args.caption_column is not None:
152
+ batch_caption = batch["text"]
153
+ result_dict["caption"].extend(batch_caption)
154
+ # Compute the quality.
155
+ for i, metric in enumerate(args.metrics):
156
+ quality_scores = metric_fns[i](batch_frame, batch_caption)
157
+ if isinstance(quality_scores[0], list): # frame-wise
158
+ quality_scores = [
159
+ [round(score, 5) for score in inner_list]
160
+ for inner_list in quality_scores
161
+ ]
162
+ else: # video-wise
163
+ quality_scores = [round(score, 5) for score in quality_scores]
164
+ result_dict[str(metric_fns[i])].extend(quality_scores)
165
+
166
+ if args.video_folder == "":
167
+ saved_video_path_list = batch_video_path
168
+ else:
169
+ saved_video_path_list = [os.path.relpath(video_path, args.video_folder) for video_path in batch_video_path]
170
+ result_dict[args.video_path_column].extend(saved_video_path_list)
171
+
172
+ # Save the metadata in the main process every saved_freq.
173
+ if (idx != 0) and (idx % args.saved_freq == 0):
174
+ state.wait_for_everyone()
175
+ gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
176
+ if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
177
+ result_df = pd.DataFrame(gathered_result_dict)
178
+ if args.saved_path.endswith(".csv"):
179
+ header = False if os.path.exists(args.saved_path) else True
180
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
181
+ elif args.saved_path.endswith(".jsonl"):
182
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
183
+ logger.info(f"Save result to {args.saved_path}.")
184
+ for k in result_dict.keys():
185
+ result_dict[k] = []
186
+
187
+ # Wait for all processes to finish and gather the final result.
188
+ state.wait_for_everyone()
189
+ gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
190
+ # Save the metadata in the main process.
191
+ if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
192
+ result_df = pd.DataFrame(gathered_result_dict)
193
+ if args.saved_path.endswith(".csv"):
194
+ header = False if os.path.exists(args.saved_path) else True
195
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
196
+ elif args.saved_path.endswith(".jsonl"):
197
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
198
+ logger.info(f"Save the final result to {args.saved_path}.")
199
+
200
+ if __name__ == "__main__":
201
+ main()
cogvideox/video_caption/cutscene_detect.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+ from multiprocessing import Pool
6
+
7
+ import pandas as pd
8
+ from scenedetect import open_video, SceneManager
9
+ from scenedetect.detectors import ContentDetector
10
+ from tqdm import tqdm
11
+
12
+ from utils.logger import logger
13
+
14
+
15
+ def cutscene_detection_star(args):
16
+ return cutscene_detection(*args)
17
+
18
+
19
+ def cutscene_detection(video_path, saved_path, cutscene_threshold=27, min_scene_len=15):
20
+ try:
21
+ if os.path.exists(saved_path):
22
+ logger.info(f"{video_path} has been processed.")
23
+ return
24
+ # Use PyAV as the backend to avoid (to some exent) containing the last frame of the previous scene.
25
+ # https://github.com/Breakthrough/PySceneDetect/issues/279#issuecomment-2152596761.
26
+ video = open_video(video_path, backend="pyav")
27
+ frame_rate, frame_size = video.frame_rate, video.frame_size
28
+ duration = deepcopy(video.duration)
29
+
30
+ frame_points, frame_timecode = [], {}
31
+ scene_manager = SceneManager()
32
+ scene_manager.add_detector(
33
+ # [ContentDetector, ThresholdDetector, AdaptiveDetector]
34
+ ContentDetector(threshold=cutscene_threshold, min_scene_len=min_scene_len)
35
+ )
36
+ scene_manager.detect_scenes(video, show_progress=False)
37
+ scene_list = scene_manager.get_scene_list()
38
+ for scene in scene_list:
39
+ for frame_time_code in scene:
40
+ frame_index = frame_time_code.get_frames()
41
+ if frame_index not in frame_points:
42
+ frame_points.append(frame_index)
43
+ frame_timecode[frame_index] = frame_time_code
44
+
45
+ del video, scene_manager
46
+
47
+ frame_points = sorted(frame_points)
48
+ output_scene_list = []
49
+ for idx in range(len(frame_points) - 1):
50
+ output_scene_list.append((frame_timecode[frame_points[idx]], frame_timecode[frame_points[idx+1]]))
51
+
52
+ timecode_list = [(frame_timecode_tuple[0].get_timecode(), frame_timecode_tuple[1].get_timecode()) for frame_timecode_tuple in output_scene_list]
53
+ meta_scene = [{
54
+ "video_path": Path(video_path).name,
55
+ "timecode_list": timecode_list,
56
+ "fram_rate": frame_rate,
57
+ "frame_size": frame_size,
58
+ "duration": str(duration) # __repr__
59
+ }]
60
+ pd.DataFrame(meta_scene).to_json(saved_path, orient="records", lines=True)
61
+ except Exception as e:
62
+ logger.warning(f"Cutscene detection with {video_path} failed. Error is: {e}.")
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser(description="Cutscene Detection")
67
+ parser.add_argument(
68
+ "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
69
+ )
70
+ parser.add_argument(
71
+ "--video_path_column",
72
+ type=str,
73
+ default="video_path",
74
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
75
+ )
76
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
77
+ parser.add_argument("--saved_folder", type=str, required=True, help="The save path to the output results (csv/jsonl).")
78
+ parser.add_argument("--n_jobs", type=int, default=1, help="The number of processes.")
79
+
80
+ args = parser.parse_args()
81
+
82
+ metadata_df = pd.read_json(args.video_metadata_path, lines=True)
83
+ video_path_list = metadata_df[args.video_path_column].tolist()
84
+ video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
85
+
86
+ if not os.path.exists(args.saved_folder):
87
+ os.makedirs(args.saved_folder, exist_ok=True)
88
+ # The glob can be slow when there are many small jsonl files.
89
+ saved_path_list = [os.path.join(args.saved_folder, Path(video_path).stem + ".jsonl") for video_path in video_path_list]
90
+ args_list = [
91
+ (video_path, saved_path)
92
+ for video_path, saved_path in zip(video_path_list, saved_path_list)
93
+ ]
94
+ # Since the length of the video is not uniform, the gather operation is not performed.
95
+ # We need to run easyanimate/video_caption/utils/gather_jsonl.py after the program finised.
96
+ with Pool(args.n_jobs) as pool:
97
+ results = list(tqdm(pool.imap(cutscene_detection_star, args_list), total=len(video_path_list)))
cogvideox/video_caption/filter_meta_train.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import pandas as pd
5
+ from natsort import natsorted
6
+
7
+ from utils.logger import logger
8
+ from utils.filter import filter
9
+
10
+
11
+ def parse_args():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument(
14
+ "--caption_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
15
+ )
16
+ parser.add_argument(
17
+ "--video_path_column",
18
+ type=str,
19
+ default="video_path",
20
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
21
+ )
22
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
23
+ parser.add_argument(
24
+ "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
25
+ )
26
+ parser.add_argument("--min_resolution", type=float, default=720*1280, help="The resolution threshold.")
27
+ parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
28
+ parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
29
+ parser.add_argument(
30
+ "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
31
+ )
32
+ parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The asethetic score threshold.")
33
+ parser.add_argument(
34
+ "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality (SigLIP) metadata (csv/jsonl)."
35
+ )
36
+ parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The asethetic score (SigLIP) threshold.")
37
+ parser.add_argument(
38
+ "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
39
+ )
40
+ parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.")
41
+ parser.add_argument(
42
+ "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)."
43
+ )
44
+ parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.")
45
+ parser.add_argument(
46
+ "--videoclipxl_score_metadata_path", type=str, default=None, help="The path to the video-caption VideoCLIPXL score metadata (csv/jsonl)."
47
+ )
48
+ parser.add_argument("--min_videoclipxl_score", type=float, default=0.20, help="The VideoCLIPXL score threshold.")
49
+ parser.add_argument("--saved_path", type=str, required=True)
50
+
51
+ args = parser.parse_args()
52
+ return args
53
+
54
+
55
+ def main():
56
+ args = parse_args()
57
+
58
+ raw_caption_df = pd.read_json(args.caption_metadata_path, lines=True)
59
+ video_path_list = raw_caption_df[args.video_path_column].to_list()
60
+ filtered_video_path_list = filter(
61
+ video_path_list,
62
+ basic_metadata_path=args.basic_metadata_path,
63
+ min_resolution=args.min_resolution,
64
+ min_duration=args.min_duration,
65
+ max_duration=args.max_duration,
66
+ asethetic_score_metadata_path=args.asethetic_score_metadata_path,
67
+ min_asethetic_score=args.min_asethetic_score,
68
+ asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
69
+ min_asethetic_score_siglip=args.min_asethetic_score_siglip,
70
+ text_score_metadata_path=args.text_score_metadata_path,
71
+ min_text_score=args.min_text_score,
72
+ motion_score_metadata_path=args.motion_score_metadata_path,
73
+ min_motion_score=args.min_motion_score,
74
+ videoclipxl_score_metadata_path=args.videoclipxl_score_metadata_path,
75
+ min_videoclipxl_score=args.min_videoclipxl_score,
76
+ video_path_column=args.video_path_column
77
+ )
78
+ filtered_video_path_list = natsorted(filtered_video_path_list)
79
+ filtered_caption_df = raw_caption_df[raw_caption_df[args.video_path_column].isin(filtered_video_path_list)]
80
+ train_df = filtered_caption_df.rename(columns={"video_path": "file_path", "caption": "text"})
81
+ train_df["file_path"] = train_df["file_path"].map(lambda x: os.path.join(args.video_folder, x))
82
+ train_df["type"] = "video"
83
+ train_df.to_json(args.saved_path, orient="records", force_ascii=False, indent=2)
84
+ logger.info(f"The final train file with {len(train_df)} videos are saved to {args.saved_path}.")
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()
cogvideox/video_caption/package_patches/easyocr_detection_patched.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.
2
+ 1. Disable DataParallel.
3
+ """
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ from torch.autograd import Variable
7
+ from PIL import Image
8
+ from collections import OrderedDict
9
+
10
+ import cv2
11
+ import numpy as np
12
+ from .craft_utils import getDetBoxes, adjustResultCoordinates
13
+ from .imgproc import resize_aspect_ratio, normalizeMeanVariance
14
+ from .craft import CRAFT
15
+
16
+ def copyStateDict(state_dict):
17
+ if list(state_dict.keys())[0].startswith("module"):
18
+ start_idx = 1
19
+ else:
20
+ start_idx = 0
21
+ new_state_dict = OrderedDict()
22
+ for k, v in state_dict.items():
23
+ name = ".".join(k.split(".")[start_idx:])
24
+ new_state_dict[name] = v
25
+ return new_state_dict
26
+
27
+ def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
28
+ if isinstance(image, np.ndarray) and len(image.shape) == 4: # image is batch of np arrays
29
+ image_arrs = image
30
+ else: # image is single numpy array
31
+ image_arrs = [image]
32
+
33
+ img_resized_list = []
34
+ # resize
35
+ for img in image_arrs:
36
+ img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size,
37
+ interpolation=cv2.INTER_LINEAR,
38
+ mag_ratio=mag_ratio)
39
+ img_resized_list.append(img_resized)
40
+ ratio_h = ratio_w = 1 / target_ratio
41
+ # preprocessing
42
+ x = [np.transpose(normalizeMeanVariance(n_img), (2, 0, 1))
43
+ for n_img in img_resized_list]
44
+ x = torch.from_numpy(np.array(x))
45
+ x = x.to(device)
46
+
47
+ # forward pass
48
+ with torch.no_grad():
49
+ y, feature = net(x)
50
+
51
+ boxes_list, polys_list = [], []
52
+ for out in y:
53
+ # make score and link map
54
+ score_text = out[:, :, 0].cpu().data.numpy()
55
+ score_link = out[:, :, 1].cpu().data.numpy()
56
+
57
+ # Post-processing
58
+ boxes, polys, mapper = getDetBoxes(
59
+ score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
60
+
61
+ # coordinate adjustment
62
+ boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
63
+ polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
64
+ if estimate_num_chars:
65
+ boxes = list(boxes)
66
+ polys = list(polys)
67
+ for k in range(len(polys)):
68
+ if estimate_num_chars:
69
+ boxes[k] = (boxes[k], mapper[k])
70
+ if polys[k] is None:
71
+ polys[k] = boxes[k]
72
+ boxes_list.append(boxes)
73
+ polys_list.append(polys)
74
+
75
+ return boxes_list, polys_list
76
+
77
+ def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
78
+ net = CRAFT()
79
+
80
+ if device == 'cpu':
81
+ net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
82
+ if quantize:
83
+ try:
84
+ torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
85
+ except:
86
+ pass
87
+ else:
88
+ net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
89
+ # net = torch.nn.DataParallel(net).to(device)
90
+ net = net.to(device)
91
+ cudnn.benchmark = cudnn_benchmark
92
+
93
+ net.eval()
94
+ return net
95
+
96
+ def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
97
+ result = []
98
+ estimate_num_chars = optimal_num_chars is not None
99
+ bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
100
+ image, text_threshold,
101
+ link_threshold, low_text, poly,
102
+ device, estimate_num_chars)
103
+ if estimate_num_chars:
104
+ polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
105
+ for polys in polys_list]
106
+
107
+ for polys in polys_list:
108
+ single_img_result = []
109
+ for i, box in enumerate(polys):
110
+ poly = np.array(box).astype(np.int32).reshape((-1))
111
+ single_img_result.append(poly)
112
+ result.append(single_img_result)
113
+
114
+ return result
cogvideox/video_caption/package_patches/vila_siglip_encoder_patched.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/NVlabs/VILA/blob/1c88211/llava/model/multimodal_encoder/siglip_encoder.py
2
+ # 1. Support transformers >= 4.36.2.
3
+ import torch
4
+ import transformers
5
+ from packaging import version
6
+ from transformers import AutoConfig, AutoModel, PretrainedConfig
7
+
8
+ from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2
9
+
10
+ if version.parse(transformers.__version__) > version.parse("4.36.2"):
11
+ from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
12
+ else:
13
+ from .siglip import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
14
+
15
+
16
+ class SiglipVisionTower(VisionTower):
17
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig, state_dict=None):
18
+ super().__init__(model_name_or_path, config)
19
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
20
+ self.vision_tower = SiglipVisionModel.from_pretrained(
21
+ # TODO(ligeng): why pass config here leading to errors?
22
+ model_name_or_path, torch_dtype=eval(config.model_dtype), state_dict=state_dict
23
+ )
24
+ self.is_loaded = True
25
+
26
+
27
+ class SiglipVisionTowerS2(VisionTowerS2):
28
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
29
+ super().__init__(model_name_or_path, config)
30
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
31
+ self.vision_tower = SiglipVisionModel.from_pretrained(
32
+ model_name_or_path, torch_dtype=eval(config.model_dtype)
33
+ )
34
+
35
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
36
+ self.image_processor.size['height'] = self.image_processor.size['width'] = self.scales[-1]
37
+
38
+ self.is_loaded = True
39
+
40
+ if version.parse(transformers.__version__) <= version.parse("4.36.2"):
41
+ AutoConfig.register("siglip_vision_model", SiglipVisionConfig)
42
+ AutoModel.register(SiglipVisionConfig, SiglipVisionModel)
cogvideox/video_caption/prompt/beautiful_prompt.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ I will upload some brief prompt words to be used for AI-generated videos. Please expand these brief prompt words into a more detailed description to enhance the quality of the generated videos. The detailed description should include the main subject (person, object, animal, or none) actions and their attributes or status sequence, the background (the objects, location, weather, and time), the view shot and camera movement.
2
+ The final detailed description must not exceed 200 words. Output with the following json format:
3
+ {"detailed description": "your detailed description here"}
4
+
5
+ Here is an example:
6
+ brief prompt words: "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."
7
+ {"detailed description": "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."}
8
+
9
+ Here are the brief prompt words:
cogvideox/video_caption/prompt/rewrite.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Please rewrite the video description to be useful for AI to re-generate the video, according to the following requirements
2
+ 1. Do not start with something similar to 'The video/scene/frame shows' or "In this video/scene/frame".
3
+ 2. Remove the subjective content deviates from describing the visual content of the video. For instance, a sentence like "It gives a feeling of ease and tranquility and makes people feel comfortable" is considered subjective.
4
+ 3. Remove the non-existent description that does not in the visual content of the video, For instance, a sentence like "There is no visible detail that could be used to identify the individual beyond what is shown." is considered as the non-existent description.
5
+ 4. Here are some examples of good descriptions: 1) A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2) A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.
6
+ 5. Output with the following json format:
7
+ {"rewritten description": "your rewritten description here"}
8
+
9
+ Here is the video description:
cogvideox/video_caption/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ pandas>=2.0.0
2
+ easyocr==1.7.1
3
+ git+https://github.com/openai/CLIP.git
4
+ natsort
5
+ joblib
6
+ scenedetect
7
+ av
8
+ # https://github.com/NVlabs/VILA/issues/78#issuecomment-2195568292
9
+ numpy<2.0.0
cogvideox/video_caption/scripts/stage_1_video_splitting.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ VIDEO_FOLDER="datasets/panda_70m/videos/data/"
2
+ META_FILE_PATH="datasets/panda_70m/videos/meta_file_info.jsonl"
3
+ SCENE_FOLDER="datasets/panda_70m/videos/meta_scene_info/"
4
+ SCENE_SAVED_PATH="datasets/panda_70m/videos/meta_scene_info.jsonl"
5
+ OUTPUT_FOLDER="datasets/panda_70m/videos_clips/data/"
6
+ RESOLUTION_THRESHOLD=$((512*512))
7
+
8
+ # Set the duration range of video clips.
9
+ export MIN_SECONDS=3
10
+ export MAX_SECONDS=10
11
+
12
+ # Save all video names in a video folder as a meta file.
13
+ python -m utils.get_meta_file \
14
+ --video_folder $VIDEO_FOLDER \
15
+ --saved_path $META_FILE_PATH
16
+
17
+ # Perform scene detection on the video dataset.
18
+ # Adjust the n_jobs parameter based on the actual number of CPU cores in the machine.
19
+ python cutscene_detect.py \
20
+ --video_metadata_path $META_FILE_PATH \
21
+ --video_folder $VIDEO_FOLDER \
22
+ --saved_folder $SCENE_FOLDER \
23
+ --n_jobs 32
24
+
25
+ # Gather all scene jsonl files to a single scene jsonl file.
26
+ # Adjust the n_jobs parameter based on the actual I/O speed in the machine.
27
+ python -m utils.gather_jsonl \
28
+ --meta_folder $SCENE_FOLDER \
29
+ --meta_file_path $SCENE_SAVED_PATH \
30
+ --n_jobs 64
31
+
32
+ # Perform video splitting filtered by the RESOLUTION_THRESHOLD.
33
+ # It consumes more CPU computing resources compared to the above operations.
34
+ python video_splitting.py \
35
+ --video_metadata_path $SCENE_SAVED_PATH \
36
+ --video_folder $VIDEO_FOLDER \
37
+ --output_folder $OUTPUT_FOLDER \
38
+ --n_jobs 16 \
39
+ --resolution_threshold $RESOLUTION_THRESHOLD
cogvideox/video_caption/scripts/stage_2_video_filtering.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ META_FILE_PATH="datasets/panda_70m/videos_clips/data/meta_file_info.jsonl"
2
+ VIDEO_FOLDER="datasets/panda_70m/videos_clips/data/"
3
+ VIDEO_QUALITY_SAVED_PATH="datasets/panda_70m/videos_clips/meta_quality_info_siglip.jsonl"
4
+ MIN_ASETHETIC_SCORE_SIGLIP=4.0
5
+ TEXT_SAVED_PATH="datasets/panda_70m/videos_clips/meta_text_info.jsonl"
6
+ MIN_TEXT_SCORE=0.02
7
+ MOTION_SAVED_PATH="datasets/panda_70m/videos_clips/meta_motion_info.jsonl"
8
+
9
+ python -m utils.get_meta_file \
10
+ --video_folder $VIDEO_FOLDER \
11
+ --saved_path $META_FILE_PATH
12
+
13
+ # Get the asethetic score (SigLIP) of all videos
14
+ accelerate launch compute_video_quality.py \
15
+ --video_metadata_path $META_FILE_PATH \
16
+ --video_folder $VIDEO_FOLDER \
17
+ --metrics "AestheticScoreSigLIP" \
18
+ --frame_sample_method uniform \
19
+ --num_sampled_frames 4 \
20
+ --saved_freq 10 \
21
+ --saved_path $VIDEO_QUALITY_SAVED_PATH \
22
+ --batch_size 4
23
+
24
+ # Get the text score of all videos filtered by the video quality score.
25
+ accelerate launch compute_text_score.py \
26
+ --video_metadata_path $META_FILE_PATH \
27
+ --video_folder $VIDEO_FOLDER \
28
+ --saved_freq 10 \
29
+ --saved_path $TEXT_SAVED_PATH \
30
+ --asethetic_score_siglip_metadata_path $VIDEO_QUALITY_SAVED_PATH \
31
+ --min_asethetic_score_siglip $MIN_ASETHETIC_SCORE_SIGLIP
32
+
33
+ # Get the motion score of all videos filtered by the video quality score and text score.
34
+ python compute_motion_score.py \
35
+ --video_metadata_path $META_FILE_PATH \
36
+ --video_folder $VIDEO_FOLDER \
37
+ --saved_freq 10 \
38
+ --saved_path $MOTION_SAVED_PATH \
39
+ --n_jobs 8 \
40
+ --text_score_metadata_path $TEXT_SAVED_PATH \
41
+ --min_text_score $MIN_TEXT_SCORE
cogvideox/video_caption/scripts/stage_3_video_recaptioning.sh ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ META_FILE_PATH="datasets/panda_70m/videos_clips/data/meta_file_info.jsonl"
2
+ VIDEO_FOLDER="datasets/panda_70m/videos_clips/data/"
3
+ MOTION_SAVED_PATH="datasets/panda_70m/videos_clips/meta_motion_info.jsonl"
4
+ MIN_MOTION_SCORE=2
5
+ VIDEO_CAPTION_SAVED_PATH="datasets/panda_70m/meta_caption_info_vila_8b.jsonl"
6
+ REWRITTEN_VIDEO_CAPTION_SAVED_PATH="datasets/panda_70m/meta_caption_info_vila_8b_rewritten.jsonl"
7
+ VIDEOCLIPXL_SCORE_SAVED_PATH="datasets/panda_70m/meta_caption_info_vila_8b_rewritten_videoclipxl.jsonl"
8
+ MIN_VIDEOCLIPXL_SCORE=0.20
9
+ TRAIN_SAVED_PATH="datasets/panda_70m/train_panda_70m.json"
10
+ # Manually download Efficient-Large-Model/Llama-3-VILA1.5-8b-AWQ to VILA_MODEL_PATH.
11
+ # Manually download meta-llama/Meta-Llama-3-8B-Instruct to REWRITE_MODEL_PATH.
12
+
13
+ # Use VILA1.5-AWQ to perform recaptioning.
14
+ accelerate launch vila_video_recaptioning.py \
15
+ --video_metadata_path ${META_FILE_PATH} \
16
+ --video_folder ${VIDEO_FOLDER} \
17
+ --model_path ${VILA_MODEL_PATH} \
18
+ --precision "W4A16" \
19
+ --saved_path $VIDEO_CAPTION_SAVED_PATH \
20
+ --saved_freq 1 \
21
+ --motion_score_metadata_path $MOTION_SAVED_PATH \
22
+ --min_motion_score $MIN_MOTION_SCORE
23
+
24
+ # Rewrite video captions (optional).
25
+ python caption_rewrite.py \
26
+ --video_metadata_path $VIDEO_CAPTION_SAVED_PATH \
27
+ --batch_size 4096 \
28
+ --model_name $REWRITE_MODEL_PATH \
29
+ --prompt prompt/rewrite.txt \
30
+ --prefix '"rewritten description": ' \
31
+ --saved_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \
32
+ --saved_freq 1
33
+
34
+ # Compute caption-video alignment (optional).
35
+ accelerate launch compute_video_quality.py \
36
+ --video_metadata_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \
37
+ --caption_column caption \
38
+ --video_folder $VIDEO_FOLDER \
39
+ --frame_sample_method uniform \
40
+ --num_sampled_frames 8 \
41
+ --metrics VideoCLIPXLScore \
42
+ --batch_size 4 \
43
+ --saved_path $VIDEOCLIPXL_SCORE_SAVED_PATH \
44
+ --saved_freq 10
45
+
46
+ # Get the final train file.
47
+ python filter_meta_train.py \
48
+ --caption_metadata_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \
49
+ --video_folder=$VIDEO_FOLDER \
50
+ --videoclipxl_score_metadata_path $VIDEOCLIPXL_SCORE_SAVED_PATH \
51
+ --min_videoclipxl_score $MIN_VIDEOCLIPXL_SCORE \
52
+ --saved_path=$TRAIN_SAVED_PATH
cogvideox/video_caption/utils/filter.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import os
3
+
4
+ import pandas as pd
5
+
6
+ from .logger import logger
7
+
8
+
9
+ def filter(
10
+ video_path_list,
11
+ basic_metadata_path=None,
12
+ min_resolution=0,
13
+ min_duration=-1,
14
+ max_duration=-1,
15
+ asethetic_score_metadata_path=None,
16
+ min_asethetic_score=4,
17
+ asethetic_score_siglip_metadata_path=None,
18
+ min_asethetic_score_siglip=4,
19
+ text_score_metadata_path=None,
20
+ min_text_score=0.02,
21
+ motion_score_metadata_path=None,
22
+ min_motion_score=2,
23
+ videoclipxl_score_metadata_path=None,
24
+ min_videoclipxl_score=0.20,
25
+ video_path_column="video_path",
26
+ ):
27
+ video_path_list = [os.path.basename(video_path) for video_path in video_path_list]
28
+
29
+ if basic_metadata_path is not None:
30
+ if basic_metadata_path.endswith(".csv"):
31
+ basic_df = pd.read_csv(basic_metadata_path)
32
+ elif basic_metadata_path.endswith(".jsonl"):
33
+ basic_df = pd.read_json(basic_metadata_path, lines=True)
34
+
35
+ basic_df["resolution"] = basic_df["frame_size"].apply(lambda x: x[0] * x[1])
36
+ filtered_basic_df = basic_df[basic_df["resolution"] < min_resolution]
37
+ filtered_video_path_list = filtered_basic_df[video_path_column].tolist()
38
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
39
+
40
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
41
+ logger.info(
42
+ f"Load {basic_metadata_path} ({len(basic_df)}) and filter {len(filtered_video_path_list)} videos "
43
+ f"with resolution less than {min_resolution}."
44
+ )
45
+
46
+ if min_duration != -1:
47
+ filtered_basic_df = basic_df[basic_df["duration"] < min_duration]
48
+ filtered_video_path_list = filtered_basic_df[video_path_column].tolist()
49
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
50
+
51
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
52
+ logger.info(
53
+ f"Load {basic_metadata_path} and filter {len(filtered_video_path_list)} videos "
54
+ f"with duration less than {min_duration}."
55
+ )
56
+
57
+ if max_duration != -1:
58
+ filtered_basic_df = basic_df[basic_df["duration"] > max_duration]
59
+ filtered_video_path_list = filtered_basic_df[video_path_column].tolist()
60
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
61
+
62
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
63
+ logger.info(
64
+ f"Load {basic_metadata_path} and filter {len(filtered_video_path_list)} videos "
65
+ f"with duration greater than {max_duration}."
66
+ )
67
+
68
+ if asethetic_score_metadata_path is not None:
69
+ if asethetic_score_metadata_path.endswith(".csv"):
70
+ asethetic_score_df = pd.read_csv(asethetic_score_metadata_path)
71
+ elif asethetic_score_metadata_path.endswith(".jsonl"):
72
+ asethetic_score_df = pd.read_json(asethetic_score_metadata_path, lines=True)
73
+
74
+ # In pandas, csv will save lists as strings, whereas jsonl will not.
75
+ asethetic_score_df["aesthetic_score"] = asethetic_score_df["aesthetic_score"].apply(
76
+ lambda x: ast.literal_eval(x) if isinstance(x, str) else x
77
+ )
78
+ asethetic_score_df["aesthetic_score_mean"] = asethetic_score_df["aesthetic_score"].apply(lambda x: sum(x) / len(x))
79
+ filtered_asethetic_score_df = asethetic_score_df[asethetic_score_df["aesthetic_score_mean"] < min_asethetic_score]
80
+ filtered_video_path_list = filtered_asethetic_score_df[video_path_column].tolist()
81
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
82
+
83
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
84
+ logger.info(
85
+ f"Load {asethetic_score_metadata_path} ({len(asethetic_score_df)}) and filter {len(filtered_video_path_list)} videos "
86
+ f"with aesthetic score less than {min_asethetic_score}."
87
+ )
88
+
89
+ if asethetic_score_siglip_metadata_path is not None:
90
+ if asethetic_score_siglip_metadata_path.endswith(".csv"):
91
+ asethetic_score_siglip_df = pd.read_csv(asethetic_score_siglip_metadata_path)
92
+ elif asethetic_score_siglip_metadata_path.endswith(".jsonl"):
93
+ asethetic_score_siglip_df = pd.read_json(asethetic_score_siglip_metadata_path, lines=True)
94
+
95
+ # In pandas, csv will save lists as strings, whereas jsonl will not.
96
+ asethetic_score_siglip_df["aesthetic_score_siglip"] = asethetic_score_siglip_df["aesthetic_score_siglip"].apply(
97
+ lambda x: ast.literal_eval(x) if isinstance(x, str) else x
98
+ )
99
+ asethetic_score_siglip_df["aesthetic_score_siglip_mean"] = asethetic_score_siglip_df["aesthetic_score_siglip"].apply(
100
+ lambda x: sum(x) / len(x)
101
+ )
102
+ filtered_asethetic_score_siglip_df = asethetic_score_siglip_df[
103
+ asethetic_score_siglip_df["aesthetic_score_siglip_mean"] < min_asethetic_score_siglip
104
+ ]
105
+ filtered_video_path_list = filtered_asethetic_score_siglip_df[video_path_column].tolist()
106
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
107
+
108
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
109
+ logger.info(
110
+ f"Load {asethetic_score_siglip_metadata_path} ({len(asethetic_score_siglip_df)}) and filter {len(filtered_video_path_list)} videos "
111
+ f"with aesthetic score (SigLIP) less than {min_asethetic_score_siglip}."
112
+ )
113
+
114
+ if text_score_metadata_path is not None:
115
+ if text_score_metadata_path.endswith(".csv"):
116
+ text_score_df = pd.read_csv(text_score_metadata_path)
117
+ elif text_score_metadata_path.endswith(".jsonl"):
118
+ text_score_df = pd.read_json(text_score_metadata_path, lines=True)
119
+
120
+ filtered_text_score_df = text_score_df[text_score_df["text_score"] > min_text_score]
121
+ filtered_video_path_list = filtered_text_score_df[video_path_column].tolist()
122
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
123
+
124
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
125
+ logger.info(
126
+ f"Load {text_score_metadata_path} ({len(text_score_df)}) and filter {len(filtered_video_path_list)} videos "
127
+ f"with text score greater than {min_text_score}."
128
+ )
129
+
130
+ if motion_score_metadata_path is not None:
131
+ if motion_score_metadata_path.endswith(".csv"):
132
+ motion_score_df = pd.read_csv(motion_score_metadata_path)
133
+ elif motion_score_metadata_path.endswith(".jsonl"):
134
+ motion_score_df = pd.read_json(motion_score_metadata_path, lines=True)
135
+
136
+ filtered_motion_score_df = motion_score_df[motion_score_df["motion_score"] < min_motion_score]
137
+ filtered_video_path_list = filtered_motion_score_df[video_path_column].tolist()
138
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
139
+
140
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
141
+ logger.info(
142
+ f"Load {motion_score_metadata_path} ({len(motion_score_df)}) and filter {len(filtered_video_path_list)} videos "
143
+ f"with motion score smaller than {min_motion_score}."
144
+ )
145
+
146
+ if videoclipxl_score_metadata_path is not None:
147
+ if videoclipxl_score_metadata_path.endswith(".csv"):
148
+ videoclipxl_score_df = pd.read_csv(videoclipxl_score_metadata_path)
149
+ elif videoclipxl_score_metadata_path.endswith(".jsonl"):
150
+ videoclipxl_score_df = pd.read_json(videoclipxl_score_metadata_path, lines=True)
151
+
152
+ filtered_videoclipxl_score_df = videoclipxl_score_df[videoclipxl_score_df["videoclipxl_score"] < min_videoclipxl_score]
153
+ filtered_video_path_list = filtered_videoclipxl_score_df[video_path_column].tolist()
154
+ filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
155
+
156
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
157
+ logger.info(
158
+ f"Load {videoclipxl_score_metadata_path} ({len(videoclipxl_score_df)}) and "
159
+ f"filter {len(filtered_video_path_list)} videos with mixclip score smaller than {min_videoclipxl_score}."
160
+ )
161
+
162
+ return video_path_list
cogvideox/video_caption/utils/gather_jsonl.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import glob
4
+ import json
5
+ from multiprocessing import Pool, Manager
6
+
7
+ import pandas as pd
8
+ from natsort import index_natsorted
9
+
10
+ from .logger import logger
11
+
12
+
13
+ def process_file(file_path, shared_list):
14
+ with open(file_path, "r") as f:
15
+ for line in f:
16
+ data = json.loads(line)
17
+ shared_list.append(data)
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser(description="Gather all jsonl files in a folder (meta_folder) to a single jsonl file (meta_file_path).")
22
+ parser.add_argument("--meta_folder", type=str, required=True)
23
+ parser.add_argument("--meta_file_path", type=str, required=True)
24
+ parser.add_argument("--video_path_column", type=str, default="video_path")
25
+ parser.add_argument("--n_jobs", type=int, default=1)
26
+
27
+ args = parser.parse_args()
28
+ return args
29
+
30
+
31
+ def main():
32
+ args = parse_args()
33
+
34
+ jsonl_files = glob.glob(os.path.join(args.meta_folder, "*.jsonl"))
35
+
36
+ with Manager() as manager:
37
+ shared_list = manager.list()
38
+ with Pool(processes=args.n_jobs) as pool:
39
+ for file_path in jsonl_files:
40
+ pool.apply_async(process_file, args=(file_path, shared_list))
41
+ pool.close()
42
+ pool.join()
43
+
44
+ with open(args.meta_file_path, "w") as f:
45
+ for item in shared_list:
46
+ f.write(json.dumps(item) + '\n')
47
+
48
+ df = pd.read_json(args.meta_file_path, lines=True)
49
+ df = df.iloc[index_natsorted(df[args.video_path_column])].reset_index(drop=True)
50
+ logger.info(f"Save the gathered single jsonl file to {args.meta_file_path}.")
51
+ df.to_json(args.meta_file_path, orient="records", lines=True, force_ascii=False)
52
+
53
+
54
+ if __name__ == '__main__':
55
+ main()
cogvideox/video_caption/utils/get_meta_file.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+ from natsort import natsorted
6
+ from tqdm import tqdm
7
+
8
+ from .logger import logger
9
+
10
+
11
+ ALL_VIDEO_EXT = set(["mp4", "webm", "mkv", "avi", "flv", "mov"])
12
+ ALL_IMGAE_EXT = set(["png", "webp", "jpg", "jpeg", "bmp", "gif"])
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(description="Compute scores of uniform sampled frames from videos.")
17
+ parser.add_argument(
18
+ "--image_path_column",
19
+ type=str,
20
+ default="image_path",
21
+ help="The column contains the image path (an absolute path or a relative path w.r.t the image_folder).",
22
+ )
23
+ parser.add_argument("--image_folder", type=str, default=None, help="The video folder.")
24
+ parser.add_argument(
25
+ "--video_path_column",
26
+ type=str,
27
+ default="video_path",
28
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
29
+ )
30
+ parser.add_argument("--video_folder", type=str, default=None, help="The video folder.")
31
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
32
+ parser.add_argument("--recursive", action="store_true", help="Whether to search sub-folders recursively.")
33
+
34
+ args = parser.parse_args()
35
+ return args
36
+
37
+
38
+ def main():
39
+ args = parse_args()
40
+
41
+ if args.video_folder is None and args.image_folder is None:
42
+ raise ValueError("Either video_folder or image_folder should be specified in the arguments.")
43
+ if args.video_folder is not None and args.image_folder is not None:
44
+ raise ValueError("Both video_folder and image_folder can not be specified in the arguments at the same time.")
45
+
46
+ # Use the path name instead of the file name as video_path/image_path (unique ID).
47
+ if args.video_folder is not None:
48
+ video_path_list = []
49
+ video_folder = Path(args.video_folder)
50
+ for ext in tqdm(list(ALL_VIDEO_EXT)):
51
+ if args.recursive:
52
+ video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.rglob(f"*.{ext}")]
53
+ else:
54
+ video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.glob(f"*.{ext}")]
55
+ video_path_list = natsorted(video_path_list)
56
+ meta_file_df = pd.DataFrame({args.video_path_column: video_path_list})
57
+
58
+ if args.image_folder is not None:
59
+ image_path_list = []
60
+ image_folder = Path(args.image_folder)
61
+ for ext in tqdm(list(ALL_IMGAE_EXT)):
62
+ if args.recursive:
63
+ image_path_list += [str(file.relative_to(image_folder)) for file in image_folder.rglob(f"*.{ext}")]
64
+ else:
65
+ image_path_list += [str(file.relative_to(image_folder)) for file in image_folder.glob(f"*.{ext}")]
66
+ image_path_list = natsorted(image_path_list)
67
+ meta_file_df = pd.DataFrame({args.image_path_column: image_path_list})
68
+
69
+ logger.info(f"{len(meta_file_df)} files in total. Save the result to {args.saved_path}.")
70
+ meta_file_df.to_json(args.saved_path, orient="records", lines=True)
71
+
72
+
73
+ if __name__ == "__main__":
74
+ main()
cogvideox/video_caption/utils/image_evaluator.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union
3
+
4
+ import clip
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+ from torchvision.datasets.utils import download_url
10
+ from transformers import AutoModel, AutoProcessor
11
+
12
+ from .siglip_v2_5 import convert_v2_5_from_siglip
13
+
14
+ # All metrics.
15
+ __all__ = ["AestheticScore", "AestheticScoreSigLIP", "CLIPScore"]
16
+
17
+ _MODELS = {
18
+ "CLIP_ViT-L/14": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViT-L-14.pt",
19
+ "Aesthetics_V2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/sac%2Blogos%2Bava1-l14-linearMSE.pth",
20
+ "aesthetic_predictor_v2_5": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/aesthetic_predictor_v2_5.pth",
21
+ }
22
+ _MD5 = {
23
+ "CLIP_ViT-L/14": "096db1af569b284eb76b3881534822d9",
24
+ "Aesthetics_V2": "b1047fd767a00134b8fd6529bf19521a",
25
+ "aesthetic_predictor_v2_5": "c46eb8c29f714c9231dc630b8226842a",
26
+ }
27
+
28
+
29
+ def get_list_depth(lst):
30
+ if isinstance(lst, list):
31
+ return 1 + max(get_list_depth(item) for item in lst)
32
+ else:
33
+ return 0
34
+
35
+
36
+ def reshape_images(images: Union[list[list[Image.Image]], list[Image.Image]]):
37
+ # Check the input sanity.
38
+ depth = get_list_depth(images)
39
+ if depth == 1: # batch image input
40
+ if not isinstance(images[0], Image.Image):
41
+ raise ValueError("The item in 1D images should be Image.Image.")
42
+ num_sampled_frames = None
43
+ elif depth == 2: # batch video input
44
+ if not isinstance(images[0][0], Image.Image):
45
+ raise ValueError("The item in 2D images (videos) should be Image.Image.")
46
+ num_sampled_frames = len(images[0])
47
+ if not all(len(video_frames) == num_sampled_frames for video_frames in images):
48
+ raise ValueError("All item in 2D images should be with the same length.")
49
+ # [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C].
50
+ reshaped_images = []
51
+ for video_frames in images:
52
+ reshaped_images.extend([frame for frame in video_frames])
53
+ images = reshaped_images
54
+ else:
55
+ raise ValueError("The input images should be in 1/2D list.")
56
+
57
+ return images, num_sampled_frames
58
+
59
+
60
+ def reshape_scores(scores: list[float], num_sampled_frames: int) -> list[float]:
61
+ if isinstance(scores, list):
62
+ if num_sampled_frames is not None: # Batch video input
63
+ batch_size = len(scores) // num_sampled_frames
64
+ scores = [
65
+ scores[i * num_sampled_frames:(i + 1) * num_sampled_frames]
66
+ for i in range(batch_size)
67
+ ]
68
+ return scores
69
+ else:
70
+ return [scores]
71
+
72
+
73
+ # if you changed the MLP architecture during training, change it also here:
74
+ class _MLP(nn.Module):
75
+ def __init__(self, input_size):
76
+ super().__init__()
77
+ self.input_size = input_size
78
+ self.layers = nn.Sequential(
79
+ nn.Linear(self.input_size, 1024),
80
+ # nn.ReLU(),
81
+ nn.Dropout(0.2),
82
+ nn.Linear(1024, 128),
83
+ # nn.ReLU(),
84
+ nn.Dropout(0.2),
85
+ nn.Linear(128, 64),
86
+ # nn.ReLU(),
87
+ nn.Dropout(0.1),
88
+ nn.Linear(64, 16),
89
+ # nn.ReLU(),
90
+ nn.Linear(16, 1),
91
+ )
92
+
93
+ def forward(self, x):
94
+ return self.layers(x)
95
+
96
+
97
+ class AestheticScore:
98
+ """Compute LAION Aesthetics Score V2 based on openai/clip. Note that the default
99
+ inference dtype with GPUs is fp16 in openai/clip.
100
+
101
+ Ref:
102
+ 1. https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py.
103
+ 2. https://github.com/openai/CLIP/issues/30.
104
+ """
105
+
106
+ def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
107
+ # The CLIP model is loaded in the evaluation mode.
108
+ self.root = os.path.expanduser(root)
109
+ if not os.path.exists(self.root):
110
+ os.makedirs(self.root)
111
+ filename = "ViT-L-14.pt"
112
+ download_url(_MODELS["CLIP_ViT-L/14"], self.root, filename=filename, md5=_MD5["CLIP_ViT-L/14"])
113
+ self.clip_model, self.preprocess = clip.load(os.path.join(self.root, filename), device=device)
114
+ self.device = device
115
+ self._load_mlp()
116
+
117
+ def _load_mlp(self):
118
+ filename = "sac+logos+ava1-l14-linearMSE.pth"
119
+ download_url(_MODELS["Aesthetics_V2"], self.root, filename=filename, md5=_MD5["Aesthetics_V2"])
120
+ state_dict = torch.load(os.path.join(self.root, filename))
121
+ self.mlp = _MLP(768)
122
+ self.mlp.load_state_dict(state_dict)
123
+ self.mlp.to(self.device)
124
+ self.mlp.eval()
125
+
126
+ def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts=None) -> list[float]:
127
+ images, num_sampled_frames = reshape_images(images)
128
+
129
+ with torch.no_grad():
130
+ images = torch.stack([self.preprocess(image) for image in images]).to(self.device)
131
+ image_embs = F.normalize(self.clip_model.encode_image(images))
132
+ scores = self.mlp(image_embs.float()) # torch.float16 -> torch.float32, [N, 1]
133
+
134
+ scores = scores.squeeze().tolist() # scalar or list
135
+ return reshape_scores(scores, num_sampled_frames)
136
+
137
+ def __repr__(self) -> str:
138
+ return "aesthetic_score"
139
+
140
+
141
+ class AestheticScoreSigLIP:
142
+ """Compute Aesthetics Score V2.5 based on google/siglip-so400m-patch14-384.
143
+
144
+ Ref:
145
+ 1. https://github.com/discus0434/aesthetic-predictor-v2-5.
146
+ 2. https://github.com/discus0434/aesthetic-predictor-v2-5/issues/2.
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ root: str = "~/.cache/clip",
152
+ device: str = "cpu",
153
+ torch_dtype=torch.float16
154
+ ):
155
+ self.root = os.path.expanduser(root)
156
+ if not os.path.exists(self.root):
157
+ os.makedirs(self.root)
158
+ filename = "aesthetic_predictor_v2_5.pth"
159
+ download_url(_MODELS["aesthetic_predictor_v2_5"], self.root, filename=filename, md5=_MD5["aesthetic_predictor_v2_5"])
160
+ self.model, self.preprocessor = convert_v2_5_from_siglip(
161
+ predictor_name_or_path=os.path.join(self.root, filename),
162
+ low_cpu_mem_usage=True,
163
+ trust_remote_code=True,
164
+ )
165
+ self.model = self.model.to(device=device, dtype=torch_dtype)
166
+ self.device = device
167
+ self.torch_dtype = torch_dtype
168
+
169
+ def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts=None) -> list[float]:
170
+ images, num_sampled_frames = reshape_images(images)
171
+
172
+ pixel_values = self.preprocessor(images, return_tensors="pt").pixel_values
173
+ pixel_values = pixel_values.to(self.device, self.torch_dtype)
174
+ with torch.no_grad():
175
+ scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
176
+
177
+ scores = scores.squeeze().tolist() # scalar or list
178
+ return reshape_scores(scores, num_sampled_frames)
179
+
180
+ def __repr__(self) -> str:
181
+ return "aesthetic_score_siglip"
182
+
183
+
184
+ class CLIPScore:
185
+ """Compute CLIP scores for image-text pairs based on huggingface/transformers."""
186
+
187
+ def __init__(
188
+ self,
189
+ model_name_or_path: str = "openai/clip-vit-large-patch14",
190
+ torch_dtype=torch.float16,
191
+ device: str = "cpu",
192
+ ):
193
+ self.model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).eval().to(device)
194
+ self.processor = AutoProcessor.from_pretrained(model_name_or_path)
195
+ self.torch_dtype = torch_dtype
196
+ self.device = device
197
+
198
+ def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts: list[str]) -> list[float]:
199
+ assert len(images) == len(texts)
200
+ images, num_sampled_frames = reshape_images(images)
201
+ # Expand texts in the batch video input case.
202
+ if num_sampled_frames is not None:
203
+ texts = [[text] * num_sampled_frames for text in texts]
204
+ texts = [item for sublist in texts for item in sublist]
205
+
206
+ image_inputs = self.processor(images=images, return_tensors="pt") # {"pixel_values": }
207
+ if self.torch_dtype == torch.float16:
208
+ image_inputs["pixel_values"] = image_inputs["pixel_values"].half()
209
+ text_inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True) # {"inputs_id": }
210
+ image_inputs, text_inputs = image_inputs.to(self.device), text_inputs.to(self.device)
211
+ with torch.no_grad():
212
+ image_embs = F.normalize(self.model.get_image_features(**image_inputs))
213
+ text_embs = F.normalize(self.model.get_text_features(**text_inputs))
214
+ scores = text_embs @ image_embs.T # [N, N]
215
+
216
+ scores = scores.squeeze().tolist() # scalar or list
217
+ return reshape_scores(scores, num_sampled_frames)
218
+
219
+ def __repr__(self) -> str:
220
+ return "clip_score"
221
+
222
+
223
+ if __name__ == "__main__":
224
+ from torch.utils.data import DataLoader
225
+ from tqdm import tqdm
226
+ from .video_dataset import VideoDataset, collate_fn
227
+
228
+ aesthetic_score = AestheticScore(device="cuda")
229
+ aesthetic_score_siglip = AestheticScoreSigLIP(device="cuda")
230
+ # clip_score = CLIPScore(device="cuda")
231
+
232
+ paths = ["your_image_path"] * 3
233
+ # texts = ["a joker", "a woman", "a man"]
234
+ images = [Image.open(p).convert("RGB") for p in paths]
235
+
236
+ print(aesthetic_score(images))
237
+ # print(clip_score(images, texts))
238
+
239
+ test_dataset = VideoDataset(
240
+ dataset_inputs={"video_path": ["your_video_path"] * 3},
241
+ sample_method="mid",
242
+ num_sampled_frames=2
243
+ )
244
+ test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, collate_fn=collate_fn)
245
+
246
+ for idx, batch in enumerate(tqdm(test_loader)):
247
+ batch_frame = batch["sampled_frame"]
248
+ print(aesthetic_score_siglip(batch_frame))
cogvideox/video_caption/utils/logger.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Borrowed from sd-webui-controlnet/scripts/logging.py
2
+ import copy
3
+ import logging
4
+ import sys
5
+
6
+
7
+ class ColoredFormatter(logging.Formatter):
8
+ COLORS = {
9
+ "DEBUG": "\033[0;36m", # CYAN
10
+ "INFO": "\033[0;32m", # GREEN
11
+ "WARNING": "\033[0;33m", # YELLOW
12
+ "ERROR": "\033[0;31m", # RED
13
+ "CRITICAL": "\033[0;37;41m", # WHITE ON RED
14
+ "RESET": "\033[0m", # RESET COLOR
15
+ }
16
+
17
+ def format(self, record):
18
+ colored_record = copy.copy(record)
19
+ levelname = colored_record.levelname
20
+ seq = self.COLORS.get(levelname, self.COLORS["RESET"])
21
+ colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
22
+ return super().format(colored_record)
23
+
24
+
25
+ # Create a new logger
26
+ logger = logging.getLogger("VideoCaption")
27
+ logger.propagate = False
28
+
29
+ # Add handler if we don't have one.
30
+ if not logger.handlers:
31
+ handler = logging.StreamHandler(sys.stdout)
32
+ handler.setFormatter(ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
33
+ logger.addHandler(handler)
34
+
35
+ # Configure logger
36
+ logger.setLevel("INFO")
cogvideox/video_caption/utils/longclip/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Long-CLIP
2
+ Codes in this directory are borrowed from https://github.com/beichenzbc/Long-CLIP/tree/4e6f5da/model.
3
+
4
+ We only modify the following code in [model_longclip.py](model_longclip.py) from
5
+ ```python
6
+ @property
7
+ def dtype(self):
8
+ return self.visual.conv1.weight.dtype
9
+ ```
10
+ to
11
+ ```python
12
+ @property
13
+ def dtype(self):
14
+ # Fix: the VideoCLIP-XL inference.
15
+ if hasattr(self, "visual"):
16
+ return self.visual.conv1.weight.dtype
17
+ else:
18
+ return self.token_embedding.weight.dtype
19
+ ```
cogvideox/video_caption/utils/longclip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .longclip import *
cogvideox/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
cogvideox/video_caption/utils/longclip/longclip.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+ from torch import nn
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model_longclip import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+
31
+ def _convert_image_to_rgb(image):
32
+ return image.convert("RGB")
33
+
34
+
35
+ def _transform(n_px):
36
+ return Compose([
37
+ Resize(n_px, interpolation=BICUBIC),
38
+ CenterCrop(n_px),
39
+ _convert_image_to_rgb,
40
+ ToTensor(),
41
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
42
+ ])
43
+
44
+
45
+
46
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", download_root: str = None):
47
+ """Load a long CLIP model
48
+
49
+ Parameters
50
+ ----------
51
+ name : str
52
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
53
+
54
+ device : Union[str, torch.device]
55
+ The device to put the loaded model
56
+
57
+ Returns
58
+ -------
59
+ model : torch.nn.Module
60
+ The CLIP model
61
+
62
+ preprocess : Callable[[PIL.Image], torch.Tensor]
63
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
64
+ """
65
+
66
+ model_path = name
67
+
68
+ state_dict = torch.load(model_path, map_location="cpu")
69
+
70
+ model = build_model(state_dict or model.state_dict(), load_from_clip = False).to(device)
71
+
72
+ if str(device) == "cpu":
73
+ model.float()
74
+
75
+ return model, _transform(model.visual.input_resolution)
76
+
77
+
78
+
79
+ def _node_get(node: torch._C.Node, key: str):
80
+ """Gets attributes of a node which is polymorphic over return type.
81
+
82
+ From https://github.com/pytorch/pytorch/pull/82628
83
+ """
84
+ sel = node.kindOf(key)
85
+ return getattr(node, sel)(key)
86
+
87
+ def patch_device(module):
88
+ try:
89
+ graphs = [module.graph] if hasattr(module, "graph") else []
90
+ except RuntimeError:
91
+ graphs = []
92
+
93
+ if hasattr(module, "forward1"):
94
+ graphs.append(module.forward1.graph)
95
+
96
+ for graph in graphs:
97
+ for node in graph.findAllNodes("prim::Constant"):
98
+ if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
99
+ node.copyAttributes(device_node)
100
+
101
+ model.apply(patch_device)
102
+ patch_device(model.encode_image)
103
+ patch_device(model.encode_text)
104
+
105
+ # patch dtype to float32 on CPU
106
+ if str(device) == "cpu":
107
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
108
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
109
+ float_node = float_input.node()
110
+
111
+ def patch_float(module):
112
+ try:
113
+ graphs = [module.graph] if hasattr(module, "graph") else []
114
+ except RuntimeError:
115
+ graphs = []
116
+
117
+ if hasattr(module, "forward1"):
118
+ graphs.append(module.forward1.graph)
119
+
120
+ for graph in graphs:
121
+ for node in graph.findAllNodes("aten::to"):
122
+ inputs = list(node.inputs())
123
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
124
+ if _node_get(inputs[i].node(), "value") == 5:
125
+ inputs[i].node().copyAttributes(float_node)
126
+
127
+ model.apply(patch_float)
128
+ patch_float(model.encode_image)
129
+ patch_float(model.encode_text)
130
+
131
+ model.float()
132
+
133
+ return model, _transform(model.input_resolution.item())
134
+
135
+
136
+ def load_from_clip(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
137
+ """Load from CLIP model for fine-tuning
138
+
139
+ Parameters
140
+ ----------
141
+ name : str
142
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
143
+
144
+ device : Union[str, torch.device]
145
+ The device to put the loaded model
146
+
147
+ jit : bool
148
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
149
+
150
+ download_root: str
151
+ path to download the model files; by default, it uses "~/.cache/clip"
152
+
153
+ Returns
154
+ -------
155
+ model : torch.nn.Module
156
+ The CLIP model
157
+
158
+ preprocess : Callable[[PIL.Image], torch.Tensor]
159
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
160
+ """
161
+
162
+ _MODELS = {
163
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
164
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
165
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
166
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
167
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
168
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
169
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
170
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
171
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
172
+ }
173
+
174
+ def available_models() -> List[str]:
175
+ """Returns the names of available CLIP models"""
176
+ return list(_MODELS.keys())
177
+
178
+ def _download(url: str, root: str):
179
+ os.makedirs(root, exist_ok=True)
180
+ filename = os.path.basename(url)
181
+
182
+ expected_sha256 = url.split("/")[-2]
183
+ download_target = os.path.join(root, filename)
184
+
185
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
186
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
187
+
188
+ if os.path.isfile(download_target):
189
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
190
+ return download_target
191
+ else:
192
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
193
+
194
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
195
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
196
+ while True:
197
+ buffer = source.read(8192)
198
+ if not buffer:
199
+ break
200
+
201
+ output.write(buffer)
202
+ loop.update(len(buffer))
203
+
204
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
205
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
206
+
207
+ return download_target
208
+
209
+ if name in _MODELS:
210
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
211
+ elif os.path.isfile(name):
212
+ model_path = name
213
+ else:
214
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
215
+
216
+ with open(model_path, 'rb') as opened_file:
217
+ try:
218
+ # loading JIT archive
219
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
220
+ state_dict = None
221
+ except RuntimeError:
222
+ # loading saved state dict
223
+ if jit:
224
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
225
+ jit = False
226
+ state_dict = torch.load(opened_file, map_location="cpu")
227
+
228
+ model = build_model(state_dict or model.state_dict(), load_from_clip = True).to(device)
229
+
230
+ positional_embedding_pre = model.positional_embedding.type(model.dtype)
231
+
232
+ length, dim = positional_embedding_pre.shape
233
+ keep_len = 20
234
+ posisitonal_embedding_new = torch.zeros([4*length-3*keep_len, dim], dtype=model.dtype)
235
+ for i in range(keep_len):
236
+ posisitonal_embedding_new[i] = positional_embedding_pre[i]
237
+ for i in range(length-1-keep_len):
238
+ posisitonal_embedding_new[4*i + keep_len] = positional_embedding_pre[i + keep_len]
239
+ posisitonal_embedding_new[4*i + 1 + keep_len] = 3*positional_embedding_pre[i + keep_len]/4 + 1*positional_embedding_pre[i+1+keep_len]/4
240
+ posisitonal_embedding_new[4*i + 2+keep_len] = 2*positional_embedding_pre[i+keep_len]/4 + 2*positional_embedding_pre[i+1+keep_len]/4
241
+ posisitonal_embedding_new[4*i + 3+keep_len] = 1*positional_embedding_pre[i+keep_len]/4 + 3*positional_embedding_pre[i+1+keep_len]/4
242
+
243
+ posisitonal_embedding_new[4*length -3*keep_len - 4] = positional_embedding_pre[length-1] + 0*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
244
+ posisitonal_embedding_new[4*length -3*keep_len - 3] = positional_embedding_pre[length-1] + 1*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
245
+ posisitonal_embedding_new[4*length -3*keep_len - 2] = positional_embedding_pre[length-1] + 2*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
246
+ posisitonal_embedding_new[4*length -3*keep_len - 1] = positional_embedding_pre[length-1] + 3*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
247
+
248
+ positional_embedding_res = posisitonal_embedding_new.clone()
249
+
250
+ model.positional_embedding = nn.Parameter(posisitonal_embedding_new, requires_grad=False)
251
+ model.positional_embedding_res = nn.Parameter(positional_embedding_res, requires_grad=True)
252
+
253
+ if str(device) == "cpu":
254
+ model.float()
255
+ return model, _transform(model.visual.input_resolution)
256
+
257
+ def _node_get(node: torch._C.Node, key: str):
258
+ """Gets attributes of a node which is polymorphic over return type.
259
+
260
+ From https://github.com/pytorch/pytorch/pull/82628
261
+ """
262
+ sel = node.kindOf(key)
263
+ return getattr(node, sel)(key)
264
+
265
+ def patch_device(module):
266
+ try:
267
+ graphs = [module.graph] if hasattr(module, "graph") else []
268
+ except RuntimeError:
269
+ graphs = []
270
+
271
+ if hasattr(module, "forward1"):
272
+ graphs.append(module.forward1.graph)
273
+
274
+ for graph in graphs:
275
+ for node in graph.findAllNodes("prim::Constant"):
276
+ if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
277
+ node.copyAttributes(device_node)
278
+
279
+ model.apply(patch_device)
280
+ patch_device(model.encode_image)
281
+ patch_device(model.encode_text)
282
+
283
+ # patch dtype to float32 on CPU
284
+ if str(device) == "cpu":
285
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
286
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
287
+ float_node = float_input.node()
288
+
289
+ def patch_float(module):
290
+ try:
291
+ graphs = [module.graph] if hasattr(module, "graph") else []
292
+ except RuntimeError:
293
+ graphs = []
294
+
295
+ if hasattr(module, "forward1"):
296
+ graphs.append(module.forward1.graph)
297
+
298
+ for graph in graphs:
299
+ for node in graph.findAllNodes("aten::to"):
300
+ inputs = list(node.inputs())
301
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
302
+ if _node_get(inputs[i].node(), "value") == 5:
303
+ inputs[i].node().copyAttributes(float_node)
304
+
305
+ model.apply(patch_float)
306
+ patch_float(model.encode_image)
307
+ patch_float(model.encode_text)
308
+
309
+ model.float()
310
+
311
+ return model, _transform(model.input_resolution.item())
312
+
313
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
314
+ """
315
+ Returns the tokenized representation of given input string(s)
316
+
317
+ Parameters
318
+ ----------
319
+ texts : Union[str, List[str]]
320
+ An input string or a list of input strings to tokenize
321
+
322
+ context_length : int
323
+ The context length to use; all CLIP models use 77 as the context length
324
+
325
+ truncate: bool
326
+ Whether to truncate the text in case its encoding is longer than the context length
327
+
328
+ Returns
329
+ -------
330
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
331
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
332
+ """
333
+ if isinstance(texts, str):
334
+ texts = [texts]
335
+
336
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
337
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
338
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
339
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
340
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
341
+ else:
342
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
343
+
344
+ for i, tokens in enumerate(all_tokens):
345
+ if len(tokens) > context_length:
346
+ if truncate:
347
+ tokens = tokens[:context_length]
348
+ tokens[-1] = eot_token
349
+ else:
350
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
351
+ result[i, :len(tokens)] = torch.tensor(tokens)
352
+
353
+ return result
cogvideox/video_caption/utils/longclip/model_longclip.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x[:1], key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+ return x.squeeze(0)
92
+
93
+
94
+ class ModifiedResNet(nn.Module):
95
+ """
96
+ A ResNet class that is similar to torchvision's but contains the following changes:
97
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
+ - The final pooling layer is a QKV attention instead of an average pool
100
+ """
101
+
102
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
+ super().__init__()
104
+ self.output_dim = output_dim
105
+ self.input_resolution = input_resolution
106
+
107
+ # the 3-layer stem
108
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(width // 2)
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.relu2 = nn.ReLU(inplace=True)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.relu3 = nn.ReLU(inplace=True)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
+
126
+ embed_dim = width * 32 # the ResNet feature dimension
127
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
+
129
+ def _make_layer(self, planes, blocks, stride=1):
130
+ layers = [Bottleneck(self._inplanes, planes, stride)]
131
+
132
+ self._inplanes = planes * Bottleneck.expansion
133
+ for _ in range(1, blocks):
134
+ layers.append(Bottleneck(self._inplanes, planes))
135
+
136
+ return nn.Sequential(*layers)
137
+
138
+ def forward(self, x):
139
+ def stem(x):
140
+ x = self.relu1(self.bn1(self.conv1(x)))
141
+ x = self.relu2(self.bn2(self.conv2(x)))
142
+ x = self.relu3(self.bn3(self.conv3(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisionTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ x = self.ln_post(x[:, 0, :])
236
+
237
+ if self.proj is not None:
238
+ x = x @ self.proj
239
+
240
+ return x
241
+
242
+
243
+ class CLIP(nn.Module):
244
+ def __init__(self,
245
+ embed_dim: int,
246
+ # vision
247
+ image_resolution: int,
248
+ vision_layers: Union[Tuple[int, int, int, int], int],
249
+ vision_width: int,
250
+ vision_patch_size: int,
251
+ # text
252
+ context_length: int,
253
+ vocab_size: int,
254
+ transformer_width: int,
255
+ transformer_heads: int,
256
+ transformer_layers: int,
257
+ load_from_clip: bool
258
+ ):
259
+ super().__init__()
260
+
261
+ self.context_length = 248
262
+
263
+ if isinstance(vision_layers, (tuple, list)):
264
+ vision_heads = vision_width * 32 // 64
265
+ self.visual = ModifiedResNet(
266
+ layers=vision_layers,
267
+ output_dim=embed_dim,
268
+ heads=vision_heads,
269
+ input_resolution=image_resolution,
270
+ width=vision_width
271
+ )
272
+ else:
273
+ vision_heads = vision_width // 64
274
+ self.visual = VisionTransformer(
275
+ input_resolution=image_resolution,
276
+ patch_size=vision_patch_size,
277
+ width=vision_width,
278
+ layers=vision_layers,
279
+ heads=vision_heads,
280
+ output_dim=embed_dim
281
+ )
282
+
283
+ self.transformer = Transformer(
284
+ width=transformer_width,
285
+ layers=transformer_layers,
286
+ heads=transformer_heads,
287
+ attn_mask=self.build_attention_mask()
288
+ )
289
+
290
+ self.vocab_size = vocab_size
291
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
292
+
293
+ if load_from_clip == False:
294
+ self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
295
+ self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width))
296
+
297
+ else:
298
+ self.positional_embedding = nn.Parameter(torch.empty(77, transformer_width))
299
+
300
+ self.ln_final = LayerNorm(transformer_width)
301
+
302
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
303
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
304
+
305
+ self.initialize_parameters()
306
+ self.mask1 = torch.zeros([248, 1])
307
+ self.mask1[:20, :] = 1
308
+ self.mask2 = torch.zeros([248, 1])
309
+ self.mask2[20:, :] = 1
310
+
311
+
312
+ def initialize_parameters(self):
313
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
314
+ nn.init.normal_(self.positional_embedding, std=0.01)
315
+
316
+ if isinstance(self.visual, ModifiedResNet):
317
+ if self.visual.attnpool is not None:
318
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
319
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
320
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
321
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
322
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
323
+
324
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
325
+ for name, param in resnet_block.named_parameters():
326
+ if name.endswith("bn3.weight"):
327
+ nn.init.zeros_(param)
328
+
329
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
330
+ attn_std = self.transformer.width ** -0.5
331
+ fc_std = (2 * self.transformer.width) ** -0.5
332
+ for block in self.transformer.resblocks:
333
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
334
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
335
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
336
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
337
+
338
+ if self.text_projection is not None:
339
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
340
+
341
+ def build_attention_mask(self):
342
+ # lazily create causal attention mask, with full attention between the vision tokens
343
+ # pytorch uses additive attention mask; fill with -inf
344
+ mask = torch.empty(self.context_length, self.context_length)
345
+ mask.fill_(float("-inf"))
346
+ mask.triu_(1) # zero out the lower diagonal
347
+ return mask
348
+
349
+ @property
350
+ def dtype(self):
351
+ # Fix: the mixclip inference.
352
+ if hasattr(self, "visual"):
353
+ return self.visual.conv1.weight.dtype
354
+ else:
355
+ return self.token_embedding.weight.dtype
356
+
357
+ def encode_image(self, image):
358
+ return self.visual(image.type(self.dtype))
359
+
360
+ def encode_text(self, text):
361
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
362
+
363
+ x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device)
364
+
365
+ x = x.permute(1, 0, 2) # NLD -> LND
366
+ x = self.transformer(x)
367
+ x = x.permute(1, 0, 2) # LND -> NLD
368
+ x = self.ln_final(x).type(self.dtype)
369
+
370
+ # x.shape = [batch_size, n_ctx, transformer.width]
371
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
372
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
373
+
374
+ return x
375
+
376
+ def encode_text_full(self, text):
377
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
378
+
379
+ x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device)
380
+
381
+ x = x.permute(1, 0, 2) # NLD -> LND
382
+ x = self.transformer(x)
383
+ x = x.permute(1, 0, 2) # LND -> NLD
384
+ x = self.ln_final(x).type(self.dtype)
385
+
386
+ # x.shape = [batch_size, n_ctx, transformer.width]
387
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
388
+ #x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
389
+
390
+ return x
391
+
392
+
393
+ def forward(self, image, text):
394
+ image_features = self.encode_image(image)
395
+ text_features = self.encode_text(text)
396
+
397
+ # normalized features
398
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
399
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
400
+
401
+ # cosine similarity as logits
402
+ logit_scale = self.logit_scale.exp()
403
+ logits_per_image = logit_scale * image_features @ text_features.t()
404
+ logits_per_text = logits_per_image.t()
405
+
406
+ # shape = [global_batch_size, global_batch_size]
407
+ return logits_per_image, logits_per_text
408
+
409
+
410
+ def convert_weights(model: nn.Module):
411
+ """Convert applicable model parameters to fp16"""
412
+
413
+ def _convert_weights_to_fp16(l):
414
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
415
+ l.weight.data = l.weight.data.half()
416
+ if l.bias is not None:
417
+ l.bias.data = l.bias.data.half()
418
+
419
+ if isinstance(l, nn.MultiheadAttention):
420
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
421
+ tensor = getattr(l, attr)
422
+ if tensor is not None:
423
+ tensor.data = tensor.data.half()
424
+
425
+ for name in ["text_projection", "proj"]:
426
+ if hasattr(l, name):
427
+ attr = getattr(l, name)
428
+ if attr is not None:
429
+ attr.data = attr.data.half()
430
+
431
+ model.apply(_convert_weights_to_fp16)
432
+
433
+
434
+ def build_model(state_dict: dict, load_from_clip: bool):
435
+ vit = "visual.proj" in state_dict
436
+
437
+ if vit:
438
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
439
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
440
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
441
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
442
+ image_resolution = vision_patch_size * grid_size
443
+ else:
444
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
445
+ vision_layers = tuple(counts)
446
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
447
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
448
+ vision_patch_size = None
449
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
450
+ image_resolution = output_width * 32
451
+
452
+ embed_dim = state_dict["text_projection"].shape[1]
453
+ context_length = state_dict["positional_embedding"].shape[0]
454
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
455
+ transformer_width = state_dict["ln_final.weight"].shape[0]
456
+ transformer_heads = transformer_width // 64
457
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
458
+
459
+ model = CLIP(
460
+ embed_dim,
461
+ image_resolution, vision_layers, vision_width, vision_patch_size,
462
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, load_from_clip
463
+ )
464
+
465
+ for key in ["input_resolution", "context_length", "vocab_size"]:
466
+ if key in state_dict:
467
+ del state_dict[key]
468
+
469
+ convert_weights(model)
470
+ model.load_state_dict(state_dict)
471
+ return model.eval()
cogvideox/video_caption/utils/longclip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
cogvideox/video_caption/utils/siglip_v2_5.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py.
2
+ import os
3
+ from collections import OrderedDict
4
+ from os import PathLike
5
+ from typing import Final
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import (
10
+ SiglipImageProcessor,
11
+ SiglipVisionConfig,
12
+ SiglipVisionModel,
13
+ logging,
14
+ )
15
+ from transformers.image_processing_utils import BatchFeature
16
+ from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
17
+
18
+ logging.set_verbosity_error()
19
+
20
+ URL: Final[str] = (
21
+ "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth"
22
+ )
23
+
24
+
25
+ class AestheticPredictorV2_5Head(nn.Module):
26
+ def __init__(self, config: SiglipVisionConfig) -> None:
27
+ super().__init__()
28
+ self.scoring_head = nn.Sequential(
29
+ nn.Linear(config.hidden_size, 1024),
30
+ nn.Dropout(0.5),
31
+ nn.Linear(1024, 128),
32
+ nn.Dropout(0.5),
33
+ nn.Linear(128, 64),
34
+ nn.Dropout(0.5),
35
+ nn.Linear(64, 16),
36
+ nn.Dropout(0.2),
37
+ nn.Linear(16, 1),
38
+ )
39
+
40
+ def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
41
+ return self.scoring_head(image_embeds)
42
+
43
+
44
+ class AestheticPredictorV2_5Model(SiglipVisionModel):
45
+ PATCH_SIZE = 14
46
+
47
+ def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None:
48
+ super().__init__(config, *args, **kwargs)
49
+ self.layers = AestheticPredictorV2_5Head(config)
50
+ self.post_init()
51
+
52
+ def forward(
53
+ self,
54
+ pixel_values: torch.FloatTensor | None = None,
55
+ labels: torch.Tensor | None = None,
56
+ return_dict: bool | None = None,
57
+ ) -> tuple | ImageClassifierOutputWithNoAttention:
58
+ return_dict = (
59
+ return_dict if return_dict is not None else self.config.use_return_dict
60
+ )
61
+
62
+ outputs = super().forward(
63
+ pixel_values=pixel_values,
64
+ return_dict=return_dict,
65
+ )
66
+ image_embeds = outputs.pooler_output
67
+ image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
68
+ prediction = self.layers(image_embeds_norm)
69
+
70
+ loss = None
71
+ if labels is not None:
72
+ loss_fct = nn.MSELoss()
73
+ loss = loss_fct()
74
+
75
+ if not return_dict:
76
+ return (loss, prediction, image_embeds)
77
+
78
+ return ImageClassifierOutputWithNoAttention(
79
+ loss=loss,
80
+ logits=prediction,
81
+ hidden_states=image_embeds,
82
+ )
83
+
84
+
85
+ class AestheticPredictorV2_5Processor(SiglipImageProcessor):
86
+ def __init__(self, *args, **kwargs) -> None:
87
+ super().__init__(*args, **kwargs)
88
+
89
+ def __call__(self, *args, **kwargs) -> BatchFeature:
90
+ return super().__call__(*args, **kwargs)
91
+
92
+ @classmethod
93
+ def from_pretrained(
94
+ self,
95
+ pretrained_model_name_or_path: str
96
+ | PathLike = "google/siglip-so400m-patch14-384",
97
+ *args,
98
+ **kwargs,
99
+ ) -> "AestheticPredictorV2_5Processor":
100
+ return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
101
+
102
+
103
+ def convert_v2_5_from_siglip(
104
+ predictor_name_or_path: str | PathLike | None = None,
105
+ encoder_model_name: str = "google/siglip-so400m-patch14-384",
106
+ *args,
107
+ **kwargs,
108
+ ) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]:
109
+ model = AestheticPredictorV2_5Model.from_pretrained(
110
+ encoder_model_name, *args, **kwargs
111
+ )
112
+
113
+ processor = AestheticPredictorV2_5Processor.from_pretrained(
114
+ encoder_model_name, *args, **kwargs
115
+ )
116
+
117
+ if predictor_name_or_path is None or not os.path.exists(predictor_name_or_path):
118
+ state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu")
119
+ else:
120
+ state_dict = torch.load(predictor_name_or_path, map_location="cpu")
121
+
122
+ assert isinstance(state_dict, OrderedDict)
123
+
124
+ model.layers.load_state_dict(state_dict)
125
+ model.eval()
126
+
127
+ return model, processor