layerdiffusion committed · Commit 9ab270d · Parent(s): fdeb859
- LICENSE +201 -0
- app.py +357 -8
- chat_interface.py +628 -0
- lib_omost/canvas.py +248 -0
- lib_omost/pipeline.py +435 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
app.py
CHANGED
@@ -1,14 +1,363 @@
# import gradio as gr
#
# import torch
#
# zero = torch.Tensor([0]).cuda()
# print(zero.device) # <-- 'cpu' 🤔
#
# @spaces.GPU
# def greet(n):
#     print(zero.device) # <-- 'cuda:0' 🤗
#     return f"Hello {zero + n} Tensor"
#
# demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
# demo.launch()

import os
import spaces

os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
HF_TOKEN = os.environ['hf_token'] if 'hf_token' in os.environ else None

import uuid
import torch
import numpy as np
import gradio as gr
import tempfile

gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_temp_dir, exist_ok=True)

from threading import Thread

# Phi3 Hijack
from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel

Phi3PreTrainedModel._supports_sdpa = True

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from lib_omost.pipeline import StableDiffusionXLOmostPipeline
from chat_interface import ChatInterface

import lib_omost.canvas as omost_canvas

# SDXL

sdxl_name = 'SG161222/RealVisXL_V4.0'
# sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'

tokenizer = CLIPTokenizer.from_pretrained(
    sdxl_name, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(
    sdxl_name, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(
    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
text_encoder_2 = CLIPTextModel.from_pretrained(
    sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
vae = AutoencoderKL.from_pretrained(
    sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto")  # bfloat16 vae
unet = UNet2DConditionModel.from_pretrained(
    sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")

unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())

pipeline = StableDiffusionXLOmostPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    unet=unet,
    scheduler=None,  # We completely give up the diffusers sampling system and use A1111's method
)

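# Note on the mixed precisions above: the VAE alone is loaded in bfloat16,
# likely because the stock SDXL VAE is prone to overflow in float16; bfloat16
# keeps float32's exponent range at the cost of mantissa precision.
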
# LLM

# llm_name = 'lllyasviel/omost-phi-3-mini-128k-8bits'
llm_name = 'lllyasviel/omost-llama-3-8b-4bits'
# llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b-4bits'

llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.bfloat16,  # This is the computation type, not the load/memory type. The loading quant type is baked into the config.
    token=HF_TOKEN,
    device_map="auto"
)

llm_tokenizer = AutoTokenizer.from_pretrained(
    llm_name,
    token=HF_TOKEN
)

@torch.inference_mode()
def pytorch2numpy(imgs):
    # BCHW float in [-1, 1] -> list of HWC uint8 arrays
    results = []
    for x in imgs:
        y = x.movedim(0, -1)
        y = y * 127.5 + 127.5
        y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        results.append(y)
    return results


@torch.inference_mode()
def numpy2pytorch(imgs):
    # list of HWC uint8 arrays -> BCHW float in [-1, 1]
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
    h = h.movedim(-1, 1)
    return h


def resize_without_crop(image, target_width, target_height):
    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)

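# Illustrative round trip (not part of the commit): the two converters above are
# inverses up to uint8 quantization, e.g.
#   x = torch.rand(2, 3, 8, 8) * 2 - 1          # BCHW in [-1, 1]
#   y = numpy2pytorch(pytorch2numpy(x))          # BCHW again, |y - x| <= 1/127.5
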
@spaces.GPU
@torch.inference_mode()
def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str:
    np.random.seed(int(seed))
    torch.manual_seed(int(seed))

    conversation = [{"role": "system", "content": omost_canvas.system_prompt}]

    for user, assistant in history:
        if user is None or assistant is None:
            continue
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])

    conversation.append({"role": "user", "content": message})

    input_ids = llm_tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device)

    streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    if temperature == 0:
        generate_kwargs['do_sample'] = False

    Thread(target=llm_model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # print(outputs)
        yield "".join(outputs)

    return

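# chat_fn follows the standard transformers streaming pattern: generate() runs on
# a background thread while TextIteratorStreamer yields decoded chunks (the 10 s
# timeout raises if generation stalls). Illustrative caller, cumulative text out:
#   for partial in chat_fn("a cat on a table", [], 12345, 0.6, 0.9, 256):
#       print(partial)
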
@torch.inference_mode()
def post_chat(history):
    history = [(user, assistant) for user, assistant in history if isinstance(user, str) and isinstance(assistant, str)]
    last_assistant = history[-1][1] if len(history) > 0 else None
    canvas_outputs = None

    try:
        canvas = omost_canvas.Canvas.from_bot_response(last_assistant)
        canvas_outputs = canvas.process()
    except Exception as e:
        print('Last assistant response is not a valid canvas:', e)

    return canvas_outputs, gr.update(visible=canvas_outputs is not None)

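# post_chat's contract: parse the newest assistant message into a Canvas; on
# success, expose the canvas dict and reveal the render button. Illustrative use:
#   canvas_outputs, button_update = post_chat([("draw a cat", assistant_reply)])
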
@spaces.GPU
@torch.inference_mode()
def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
                 highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt):

    use_initial_latent = False
    eps = 0.05

    image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64

    rng = torch.Generator(unet.device).manual_seed(seed)

    positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(canvas_outputs, negative_prompt)

    if use_initial_latent:
        initial_latent = torch.from_numpy(canvas_outputs['initial_latent'])[None].movedim(-1, 1) / 127.5 - 1.0
        initial_latent_blur = 40
        initial_latent = torch.nn.functional.avg_pool2d(
            torch.nn.functional.pad(initial_latent, (initial_latent_blur,) * 4, mode='reflect'),
            kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1))
        initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width))
        initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device)
        initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor
    else:
        initial_latent = torch.zeros(size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32)

    initial_latent = initial_latent.to(dtype=unet.dtype, device=unet.device)

    latents = pipeline(
        initial_latent=initial_latent,
        strength=1.0,
        num_inference_steps=int(steps),
        batch_size=num_samples,
        prompt_embeds=positive_cond,
        negative_prompt_embeds=negative_cond,
        pooled_prompt_embeds=positive_pooler,
        negative_pooled_prompt_embeds=negative_pooler,
        generator=rng,
        guidance_scale=float(cfg),
    ).images

    latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
    pixels = vae.decode(latents).sample
    B, C, H, W = pixels.shape
    pixels = pytorch2numpy(pixels)

    if highres_scale > 1.0 + eps:
        pixels = [
            resize_without_crop(
                image=p,
                target_width=int(round(W * highres_scale / 64.0) * 64),
                target_height=int(round(H * highres_scale / 64.0) * 64)
            ) for p in pixels
        ]

        pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
        latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor

        latents = latents.to(device=unet.device, dtype=unet.dtype)

        latents = pipeline(
            initial_latent=latents,
            strength=highres_denoise,
            num_inference_steps=highres_steps,
            batch_size=num_samples,
            prompt_embeds=positive_cond,
            negative_prompt_embeds=negative_cond,
            pooled_prompt_embeds=positive_pooler,
            negative_pooled_prompt_embeds=negative_pooler,
            generator=rng,
            guidance_scale=float(cfg),
        ).images

        latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
        pixels = vae.decode(latents).sample
        pixels = pytorch2numpy(pixels)

    for i in range(len(pixels)):
        unique_hex = uuid.uuid4().hex
        image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png")
        image = Image.fromarray(pixels[i])
        image.save(image_path)
        chatbot = chatbot + [(None, (image_path, 'image'))]

    return chatbot

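# diffusion_fn is a two-pass "highres fix": pass one samples at the requested
# 64-aligned resolution from an empty latent; when highres_scale > 1, the decoded
# pixels are Lanczos-upscaled, re-encoded through the VAE, and re-denoised at
# strength=highres_denoise for highres_steps steps before the final decode.
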
css = '''
code {white-space: pre-wrap !important;}
.gradio-container {max-width: none !important;}
.outer_parent {flex: 1;}
.inner_parent {flex: 1;}
footer {display: none !important; visibility: hidden !important;}
.translucent {display: none !important; visibility: hidden !important;}
'''

with gr.Blocks(fill_height=True, css=css) as demo:
    with gr.Row(elem_classes='outer_parent'):
        with gr.Column(scale=25):
            with gr.Row():
                retry_btn = gr.Button("🔄 Retry", variant="secondary", size="sm", min_width=60)
                undo_btn = gr.Button("↩️ Undo", variant="secondary", size="sm", min_width=60)
                clear_btn = gr.Button("⭐️ New Chat", variant="secondary", size="sm", min_width=60)

            seed = gr.Number(label="Random Seed", value=12345, precision=0)

            with gr.Accordion(open=True, label='Language Model'):
                with gr.Group():
                    with gr.Row():
                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            step=0.01,
                            value=0.6,
                            label="Temperature")
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.9,
                            label="Top P")
                    max_new_tokens = gr.Slider(
                        minimum=128,
                        maximum=4096,
                        step=1,
                        value=4096,
                        label="Max New Tokens")
            with gr.Accordion(open=True, label='Image Diffusion Model'):
                with gr.Group():
                    with gr.Row():
                        image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
                        image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)

                    with gr.Row():
                        num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
                        steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)

            with gr.Accordion(open=False, label='Advanced'):
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
                highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
                highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
                highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
                n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')

            render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)

            examples = gr.Dataset(
                samples=[
                    ['generate an image of the fierce battle of warriors and a dragon'],
                    ['change the dragon to a dinosaur']
                ],
                components=[gr.Textbox(visible=False)],
                label='Quick Prompts'
            )
        with gr.Column(scale=75, elem_classes='inner_parent'):
            canvas_state = gr.State(None)
            chatbot = gr.Chatbot(label='Omost', scale=1, bubble_full_width=True, render=False)
            chatInterface = ChatInterface(
                fn=chat_fn,
                post_fn=post_chat,
                post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button]),
                pre_fn=lambda: gr.update(visible=False),
                pre_fn_kwargs=dict(outputs=[render_button]),
                chatbot=chatbot,
                retry_btn=retry_btn,
                undo_btn=undo_btn,
                clear_btn=clear_btn,
                additional_inputs=[seed, temperature, top_p, max_new_tokens],
                examples=examples
            )

    render_button.click(
        fn=diffusion_fn, inputs=[
            chatInterface.chatbot, canvas_state,
            num_samples, seed, image_width, image_height, highres_scale,
            steps, cfg, highres_steps, highres_denoise, n_prompt
        ], outputs=[chatInterface.chatbot]).then(
        fn=lambda x: x, inputs=[
            chatInterface.chatbot
        ], outputs=[chatInterface.chatbot_state])

if __name__ == "__main__":
    demo.queue().launch(inbrowser=True, server_name='0.0.0.0')
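A note on the wiring above: ChatInterface (added in the next file) calls pre_fn before every turn and post_fn after every turn, routing their Gradio inputs/outputs through pre_fn_kwargs/post_fn_kwargs. A minimal sketch of hooks satisfying that contract (illustrative only; echo_fn and the hook names are invented for this example):

import gradio as gr

def echo_fn(message, history, *extra_inputs):
    # streamed reply, mirroring chat_fn's generator shape
    yield f"echo: {message}"

def hide_button():          # pre_fn: hide the render button while a turn runs
    return gr.update(visible=False)

def parse_reply(history):   # post_fn: decide whether to show it again
    ok = bool(history) and history[-1][1] is not None
    return gr.update(visible=ok)
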
chat_interface.py
ADDED
@@ -0,0 +1,628 @@
"""
This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
"""

from __future__ import annotations

import inspect
from typing import AsyncGenerator, Callable, Literal, Union, cast

import anyio
from gradio_client.documentation import document

from gradio.blocks import Blocks
from gradio.components import (
    Button,
    Chatbot,
    Component,
    Markdown,
    MultimodalTextbox,
    State,
    Textbox,
    get_component_instance,
    Dataset
)
from gradio.events import Dependency, on
from gradio.helpers import special_args
from gradio.layouts import Accordion, Group, Row
from gradio.routes import Request
from gradio.themes import ThemeClass as Theme
from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda

@document()
class ChatInterface(Blocks):
    """
    ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
    a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
    takes a function that governs the response of the chatbot based on the user input and chat history. Additional
    parameters can be used to control the appearance and behavior of the demo.

    Example:
        import gradio as gr

        def echo(message, history):
            return message

        demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
        demo.launch()
    Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
    Guides: creating-a-chatbot-fast, sharing-your-app
    """

    def __init__(
        self,
        fn: Callable,
        post_fn: Callable,
        pre_fn: Callable,
        chatbot: Chatbot,
        *,
        post_fn_kwargs: dict = None,
        pre_fn_kwargs: dict = None,
        multimodal: bool = False,
        textbox: Textbox | MultimodalTextbox | None = None,
        additional_inputs: str | Component | list[str | Component] | None = None,
        additional_inputs_accordion_name: str | None = None,
        additional_inputs_accordion: str | Accordion | None = None,
        examples: Dataset = None,
        title: str | None = None,
        description: str | None = None,
        theme: Theme | str | None = None,
        css: str | None = None,
        js: str | None = None,
        head: str | None = None,
        analytics_enabled: bool | None = None,
        submit_btn: str | None | Button = "Submit",
        stop_btn: str | None | Button = "Stop",
        retry_btn: str | None | Button = "🔄 Retry",
        undo_btn: str | None | Button = "↩️ Undo",
        clear_btn: str | None | Button = "🗑️ Clear",
        autofocus: bool = True,
        concurrency_limit: int | None | Literal["default"] = "default",
        fill_height: bool = True,
        delete_cache: tuple[int, int] | None = None,
    ):
        super().__init__(
            analytics_enabled=analytics_enabled,
            mode="chat_interface",
            css=css,
            title=title or "Gradio",
            theme=theme,
            js=js,
            head=head,
            fill_height=fill_height,
            delete_cache=delete_cache,
        )

        if post_fn_kwargs is None:
            post_fn_kwargs = {}
        if pre_fn_kwargs is None:
            pre_fn_kwargs = {}

        self.post_fn = post_fn
        self.post_fn_kwargs = post_fn_kwargs

        self.pre_fn = pre_fn
        self.pre_fn_kwargs = pre_fn_kwargs

        self.multimodal = multimodal
        self.concurrency_limit = concurrency_limit
        self.fn = fn
        self.is_async = inspect.iscoroutinefunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)
        self.is_generator = inspect.isgeneratorfunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)

        if additional_inputs:
            if not isinstance(additional_inputs, list):
                additional_inputs = [additional_inputs]
            self.additional_inputs = [
                get_component_instance(i)
                for i in additional_inputs  # type: ignore
            ]
        else:
            self.additional_inputs = []
        if additional_inputs_accordion_name is not None:
            print(
                "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
            )
            self.additional_inputs_accordion_params = {
                "label": additional_inputs_accordion_name
            }
        if additional_inputs_accordion is None:
            self.additional_inputs_accordion_params = {
                "label": "Additional Inputs",
                "open": False,
            }
        elif isinstance(additional_inputs_accordion, str):
            self.additional_inputs_accordion_params = {
                "label": additional_inputs_accordion
            }
        elif isinstance(additional_inputs_accordion, Accordion):
            self.additional_inputs_accordion_params = (
                additional_inputs_accordion.recover_kwargs(
                    additional_inputs_accordion.get_config()
                )
            )
        else:
            raise ValueError(
                f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
            )

        with self:
            if title:
                Markdown(
                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
                )
            if description:
                Markdown(description)

            self.chatbot = chatbot.render()

            self.buttons = [retry_btn, undo_btn, clear_btn]

            with Group():
                with Row():
                    if textbox:
                        if self.multimodal:
                            submit_btn = None
                        else:
                            textbox.container = False
                        textbox.show_label = False
                        textbox_ = textbox.render()
                        if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
                            raise TypeError(
                                f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
                            )
                        self.textbox = textbox_
                    elif self.multimodal:
                        submit_btn = None
                        self.textbox = MultimodalTextbox(
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )
                    else:
                        self.textbox = Textbox(
                            container=False,
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )
                    if submit_btn is not None and not multimodal:
                        if isinstance(submit_btn, Button):
                            submit_btn.render()
                        elif isinstance(submit_btn, str):
                            submit_btn = Button(
                                submit_btn,
                                variant="primary",
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
                            )
                    if stop_btn is not None:
                        if isinstance(stop_btn, Button):
                            stop_btn.visible = False
                            stop_btn.render()
                        elif isinstance(stop_btn, str):
                            stop_btn = Button(
                                stop_btn,
                                variant="stop",
                                visible=False,
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
                            )
                    self.buttons.extend([submit_btn, stop_btn])  # type: ignore

            self.fake_api_btn = Button("Fake API", visible=False)
            self.fake_response_textbox = Textbox(label="Response", visible=False)
            (
                self.retry_btn,
                self.undo_btn,
                self.clear_btn,
                self.submit_btn,
                self.stop_btn,
            ) = self.buttons

            any_unrendered_inputs = any(
                not inp.is_rendered for inp in self.additional_inputs
            )
            if self.additional_inputs and any_unrendered_inputs:
                with Accordion(**self.additional_inputs_accordion_params):  # type: ignore
                    for input_component in self.additional_inputs:
                        if not input_component.is_rendered:
                            input_component.render()

            self.saved_input = State()
            self.chatbot_state = (
                State(self.chatbot.value) if self.chatbot.value else State([])
            )

            self._setup_events()
            self._setup_api()

            if examples:
                examples.click(lambda x: x[0], inputs=[examples], outputs=self.textbox, show_progress=False, queue=False)

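    # Illustrative construction, mirroring app.py from this same commit (the
    # variable names below are that file's, not defaults of this class):
    #   chatbot = gr.Chatbot(label='Omost', render=False)
    #   ci = ChatInterface(
    #       fn=chat_fn, post_fn=post_chat,
    #       pre_fn=lambda: gr.update(visible=False),
    #       pre_fn_kwargs=dict(outputs=[render_button]),
    #       post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button]),
    #       chatbot=chatbot)
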
    def _setup_events(self) -> None:
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
        submit_triggers = (
            [self.textbox.submit, self.submit_btn.click]
            if self.submit_btn
            else [self.textbox.submit]
        )
        submit_event = (
            on(
                submit_triggers,
                self._clear_and_save_textbox,
                [self.textbox],
                [self.textbox, self.saved_input],
                show_api=False,
                queue=False,
            )
            .then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            )
            .then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                show_api=False,
                queue=False,
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                [self.chatbot, self.chatbot_state],
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            )
        )
        self._setup_stop_events(submit_triggers, submit_event)

        if self.retry_btn:
            retry_event = (
                self.retry_btn.click(
                    self._delete_prev_fn,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
                .then(
                    self.pre_fn,
                    **self.pre_fn_kwargs,
                    show_api=False,
                    queue=False,
                )
                .then(
                    self._display_input,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
                .then(
                    submit_fn,
                    [self.saved_input, self.chatbot_state] + self.additional_inputs,
                    [self.chatbot, self.chatbot_state],
                    show_api=False,
                    concurrency_limit=cast(
                        Union[int, Literal["default"], None], self.concurrency_limit
                    ),
                ).then(
                    self.post_fn,
                    **self.post_fn_kwargs,
                    show_api=False,
                    concurrency_limit=cast(
                        Union[int, Literal["default"], None], self.concurrency_limit
                    ),
                )
            )
            self._setup_stop_events([self.retry_btn.click], retry_event)

        if self.undo_btn:
            self.undo_btn.click(
                self._delete_prev_fn,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                show_api=False,
                queue=False,
            ).then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            ).then(
                async_lambda(lambda x: x),
                [self.saved_input],
                [self.textbox],
                show_api=False,
                queue=False,
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            )

        if self.clear_btn:
            self.clear_btn.click(
                async_lambda(lambda: ([], [], None)),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input],
                queue=False,
                show_api=False,
            ).then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            )

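    # Ordering note: every submit runs _clear_and_save_textbox -> pre_fn ->
    # _display_input -> fn (streamed if a generator) -> post_fn, and retry/undo/
    # clear re-run the same pre_fn/post_fn hooks, so dependent UI such as the
    # render button always reflects the latest chat state.
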
    def _setup_stop_events(
        self, event_triggers: list[Callable], event_to_cancel: Dependency
    ) -> None:
        if self.stop_btn and self.is_generator:
            if self.submit_btn:
                for event_trigger in event_triggers:
                    event_trigger(
                        async_lambda(
                            lambda: (
                                Button(visible=False),
                                Button(visible=True),
                            )
                        ),
                        None,
                        [self.submit_btn, self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    async_lambda(lambda: (Button(visible=True), Button(visible=False))),
                    None,
                    [self.submit_btn, self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            else:
                for event_trigger in event_triggers:
                    event_trigger(
                        async_lambda(lambda: Button(visible=True)),
                        None,
                        [self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    async_lambda(lambda: Button(visible=False)),
                    None,
                    [self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            self.stop_btn.click(
                None,
                None,
                None,
                cancels=event_to_cancel,
                show_api=False,
            )

    def _setup_api(self) -> None:
        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn

        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state],
            api_name="chat",
            concurrency_limit=cast(
                Union[int, Literal["default"], None], self.concurrency_limit
            ),
        )

    def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
        if self.multimodal:
            return {"text": "", "files": []}, message
        else:
            return "", message

    def _append_multimodal_history(
        self,
        message: dict[str, list],
        response: str | None,
        history: list[list[str | tuple | None]],
    ):
        for x in message["files"]:
            history.append([(x,), None])
        if message["text"] is None or not isinstance(message["text"], str):
            return
        elif message["text"] == "" and message["files"] != []:
            history.append([None, response])
        else:
            history.append([message["text"], response])

    async def _display_input(
        self, message: str | dict[str, list], history: list[list[str | tuple | None]]
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        if self.multimodal and isinstance(message, dict):
            self._append_multimodal_history(message, None, history)
        elif isinstance(message, str):
            history.append([message, None])
        return history, history

    async def _submit_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        if self.multimodal and isinstance(message, dict):
            remove_input = (
                len(message["files"]) + 1
                if message["text"] is not None
                else len(message["files"])
            )
            history = history_with_input[:-remove_input]
        else:
            history = history_with_input[:-1]
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )

        if self.multimodal and isinstance(message, dict):
            self._append_multimodal_history(message, response, history)
        elif isinstance(message, str):
            history.append([message, response])
        return history, history

    async def _stream_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> AsyncGenerator:
        if self.multimodal and isinstance(message, dict):
            remove_input = (
                len(message["files"]) + 1
                if message["text"] is not None
                else len(message["files"])
            )
            history = history_with_input[:-remove_input]
        else:
            history = history_with_input[:-1]
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)
        try:
            first_response = await async_iteration(generator)
            if self.multimodal and isinstance(message, dict):
                for x in message["files"]:
                    history.append([(x,), None])
                update = history + [[message["text"], first_response]]
                yield update, update
            else:
                update = history + [[message, first_response]]
                yield update, update
        except StopIteration:
            if self.multimodal and isinstance(message, dict):
                self._append_multimodal_history(message, None, history)
                yield history, history
            else:
                update = history + [[message, None]]
                yield update, update
        async for response in generator:
            if self.multimodal and isinstance(message, dict):
                update = history + [[message["text"], response]]
                yield update, update
            else:
                update = history + [[message, response]]
                yield update, update

    async def _api_submit_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> tuple[str, list[list[str | None]]]:
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
        history.append([message, response])
        return response, history

    async def _api_stream_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> AsyncGenerator:
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)
        try:
            first_response = await async_iteration(generator)
            yield first_response, history + [[message, first_response]]
        except StopIteration:
            yield None, history + [[message, None]]
        async for response in generator:
            yield response, history + [[message, response]]

    async def _delete_prev_fn(
        self,
        message: str | dict[str, list],
        history: list[list[str | tuple | None]],
    ) -> tuple[
        list[list[str | tuple | None]],
        str | dict[str, list],
        list[list[str | tuple | None]],
    ]:
        if self.multimodal and isinstance(message, dict):
            remove_input = (
                len(message["files"]) + 1
                if message["text"] is not None
                else len(message["files"])
            )
            history = history[:-remove_input]
        else:
            history = history[:-1]
        return history, message or "", history
lib_omost/canvas.py
ADDED
@@ -0,0 +1,248 @@
import re
import difflib
import numpy as np

system_prompt = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:

```python
class Canvas:
    def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
        pass

    def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
        assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
        assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
        assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
        assert distance_to_viewer > 0
        pass
```'''

valid_colors = {  # r, g, b
    'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
    'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
    'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
    'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
    'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
    'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
    'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
    'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
    'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
    'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
    'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
    'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
    'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
    'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
    'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
    'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
    'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
    'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
    'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
    'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
    'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
    'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
    'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
    'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
    'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
    'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
    'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
    'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
    'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
    'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
    'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
    'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
    'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
    'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
    'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
    'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
    'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
    'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
    'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
    'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
    'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
    'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
    'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
    'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
    'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
    'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
    'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
    'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
}

valid_locations = {  # x, y in 90*90
    'in the center': (45, 45),
    'on the left': (15, 45),
    'on the right': (75, 45),
    'on the top': (45, 15),
    'on the bottom': (45, 75),
    'on the top-left': (15, 15),
    'on the top-right': (75, 15),
    'on the bottom-left': (15, 75),
    'on the bottom-right': (75, 75)
}

valid_offsets = {  # x, y in 90*90
    'no offset': (0, 0),
    'slightly to the left': (-10, 0),
    'slightly to the right': (10, 0),
    'slightly to the upper': (0, -10),
    'slightly to the lower': (0, 10),
    'slightly to the upper-left': (-10, -10),
    'slightly to the upper-right': (10, -10),
    'slightly to the lower-left': (-10, 10),
    'slightly to the lower-right': (10, 10)}

valid_areas = {  # w, h in 90*90
    "a small square area": (50, 50),
    "a small vertical area": (40, 60),
    "a small horizontal area": (60, 40),
    "a medium-sized square area": (60, 60),
    "a medium-sized vertical area": (50, 80),
    "a medium-sized horizontal area": (80, 50),
    "a large square area": (70, 70),
    "a large vertical area": (60, 90),
    "a large horizontal area": (90, 60)
}


def closest_name(input_str, options):
    # Fuzzy-match a model-emitted string to the nearest valid option key.
    input_str = input_str.lower()

    closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
    assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
    result = closest_match[0]

    if result != input_str:
        print(f'Automatically corrected [{input_str}] -> [{result}].')

    return result


def safe_str(x):
    return x.strip(',. ') + '.'


def binary_nonzero_positions(n, offset=0):
    # Indices of the set bits in n, shifted by offset.
    binary_str = bin(n)[2:]
    positions = [i + offset for i, bit in enumerate(reversed(binary_str)) if bit == '1']
    return positions


class Canvas:
    @staticmethod
    def from_bot_response(response: str):
        matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
        assert matched, 'Response does not contain codes!'
        code_content = matched.group(1)
        assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
        local_vars = {'Canvas': Canvas}
        exec(code_content, {}, local_vars)
        canvas = local_vars.get('canvas', None)
        assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
        return canvas

    def __init__(self):
        self.components = []
        self.color = None
        self.record_tags = True
        self.prefixes = []
        self.suffixes = []
        return

    def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
                               HTML_web_color_name: str):
        assert isinstance(description, str), 'Global description is not valid!'
        assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
            'Global detailed_descriptions is not valid!'
        assert isinstance(tags, str), 'Global tags is not valid!'

        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
        self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)

        self.prefixes = [description]
        self.suffixes = detailed_descriptions

        if self.record_tags:
            self.suffixes = self.suffixes + [tags]

        self.prefixes = [safe_str(x) for x in self.prefixes]
        self.suffixes = [safe_str(x) for x in self.suffixes]

        return

    def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
                              detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
                              quality_meta: str, HTML_web_color_name: str):
        assert isinstance(description, str), 'Local description is wrong!'
        assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
            f'The distance_to_viewer for [{description}] is not positive float number!'
        assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
            f'The detailed_descriptions for [{description}] is not valid!'
        assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
        assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
        assert isinstance(style, str), f'The style for [{description}] is not valid!'
        assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'

        location = closest_name(location, valid_locations)
        offset = closest_name(offset, valid_offsets)
        area = closest_name(area, valid_areas)
        HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)

        xb, yb = valid_locations[location]
        xo, yo = valid_offsets[offset]
        w, h = valid_areas[area]
        rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
        rect = [max(0, min(90, i)) for i in rect]
        color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)

        prefixes = self.prefixes + [description]
        suffixes = detailed_descriptions

        if self.record_tags:
            suffixes = suffixes + [tags, atmosphere, style, quality_meta]

        prefixes = [safe_str(x) for x in prefixes]
        suffixes = [safe_str(x) for x in suffixes]

        self.components.append(dict(
            rect=rect,
            distance_to_viewer=distance_to_viewer,
            color=color,
            prefixes=prefixes,
            suffixes=suffixes
        ))

        return

    def process(self):
        # sort components (farthest first, so nearer components paint over them)
        self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)

        # compute initial latent
        initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color

        for component in self.components:
            a, b, c, d = component['rect']
            initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]

        initial_latent = initial_latent.clip(0, 255).astype(np.uint8)

        # compute conditions

        bag_of_conditions = [
            dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
        ]

        for i, component in enumerate(self.components):
            a, b, c, d = component['rect']
            m = np.zeros(shape=(90, 90), dtype=np.float32)
            m[a:b, c:d] = 1.0
            bag_of_conditions.append(dict(
                mask=m,
                prefixes=component['prefixes'],
                suffixes=component['suffixes']
            ))

        return dict(
            initial_latent=initial_latent,
            bag_of_conditions=bag_of_conditions,
        )
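
Canvas.from_bot_response extracts the fenced ```python block from the assistant's reply and executes it to obtain a `canvas` variable. As a hedged illustration (the scene content below is invented; the API is the Canvas class defined above), a canvas can also be built by hand and processed directly:

# Illustrative only: a hand-written canvas in the same shape the LLM is
# prompted to emit.
canvas = Canvas()
canvas.set_global_description(
    description='A quiet lakeside scene at dusk',
    detailed_descriptions=['calm water reflecting an orange sky', 'pine trees on the far shore'],
    tags='lake, dusk, reflection, pines',
    HTML_web_color_name='midnightblue',
)
canvas.add_local_description(
    location='on the left', offset='no offset', area='a medium-sized vertical area',
    distance_to_viewer=3.0,
    description='A wooden rowboat tied to a small dock',
    detailed_descriptions=['weathered planks', 'a coiled rope on the bow'],
    tags='rowboat, dock, wood', atmosphere='still and peaceful',
    style='soft painterly lighting', quality_meta='high detail',
    HTML_web_color_name='saddlebrown',
)
outputs = canvas.process()
# outputs['initial_latent'] is a 90x90x3 uint8 color layout;
# outputs['bag_of_conditions'] pairs one 90x90 mask with the prefix/suffix
# subprompts for the global scene and for each region.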
lib_omost/pipeline.py
ADDED
@@ -0,0 +1,435 @@
import numpy as np
import copy

from tqdm.auto import trange
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import *
from diffusers.models.transformers import Transformer2DModel


original_Transformer2DModel_forward = Transformer2DModel.forward


def hacked_Transformer2DModel_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
):
    # Record the spatial shape before the transformer flattens it, so the
    # cross-attention processor can rebuild per-pixel region masks.
    cross_attention_kwargs = cross_attention_kwargs or {}
    cross_attention_kwargs['hidden_states_original_shape'] = hidden_states.shape
    return original_Transformer2DModel_forward(
        self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs,
        attention_mask, encoder_attention_mask, return_dict)


Transformer2DModel.forward = hacked_Transformer2DModel_forward


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x


class KModel:
    def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012):
        betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
        alphas = 1. - betas
        alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)

        self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        self.log_sigmas = self.sigmas.log()
        self.sigma_data = 1.0
        self.unet = unet
        return

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)

    def get_sigmas_karras(self, n, rho=7.):
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = self.sigma_min ** (1 / rho)
        max_inv_rho = self.sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return torch.cat([sigmas, sigmas.new_zeros([1])])

    def __call__(self, x, sigma, **extra_args):
        # The UNet predicts noise (epsilon); apply classifier-free guidance
        # and convert the prediction to a denoised sample.
        x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
        t = self.timestep(sigma)
        cfg_scale = extra_args['cfg_scale']
        eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
        eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
        noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
        return x - noise_pred * sigma[:, None, None, None]


class OmostSelfAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
        batch_size, sequence_length, _ = hidden_states.shape

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class OmostCrossAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
        # encoder_hidden_states is a list of (mask, cond) pairs: all region
        # subprompts are concatenated along the key/value axis, and attention
        # is masked so each query pixel only attends to the subprompts whose
        # region covers it.
        B, C, H, W = hidden_states_original_shape

        conds = []
        masks = []

        for m, c in encoder_hidden_states:
            m = torch.nn.functional.interpolate(m[None, None, :, :], (H, W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, c.size(1))
            conds.append(c)
            masks.append(m)

        conds = torch.cat(conds, dim=1)
        masks = torch.cat(masks, dim=1)

        mask_bool = masks > 0.5
        mask_scale = (H * W) / torch.sum(masks, dim=0, keepdim=True)

        batch_size, sequence_length, _ = conds.shape

        query = attn.to_q(hidden_states)
        key = attn.to_k(conds)
        value = attn.to_v(conds)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        mask_bool = mask_bool[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)
        mask_scale = mask_scale[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)

        sim = query @ key.transpose(-2, -1) * attn.scale
        sim = sim * mask_scale.to(sim)
        sim.masked_fill_(mask_bool.logical_not(), float("-inf"))
        sim = sim.softmax(dim=-1)

        h = sim @ value
        h = h.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        h = attn.to_out[0](h)
        h = attn.to_out[1](h)
        return h


class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.k_model = KModel(unet=self.unet)

        attn_procs = {}
        for name in self.unet.attn_processors.keys():
            if name.endswith("attn2.processor"):
                attn_procs[name] = OmostCrossAttnProcessor()
            else:
                attn_procs[name] = OmostSelfAttnProcessor()

        self.unet.set_attn_processor(attn_procs)
        return

    @torch.inference_mode()
    def encode_bag_of_subprompts_greedy(self, prefixes: list[str], suffixes: list[str]):
        device = self.text_encoder.device

        @torch.inference_mode()
        def greedy_partition(items, max_sum):
            bags = []
            current_bag = []
            current_sum = 0

            for item in items:
                num = item['length']
                if current_sum + num > max_sum:
                    if current_bag:
                        bags.append(current_bag)
                    current_bag = [item]
                    current_sum = num
                else:
                    current_bag.append(item)
                    current_sum += num

            if current_bag:
                bags.append(current_bag)

            return bags

        @torch.inference_mode()
        def get_77_tokens_in_torch(subprompt_inds, tokenizer):
            # Note that all subprompt are theoretically less than 75 tokens (without bos/eos)
            result = [tokenizer.bos_token_id] + subprompt_inds[:75] + [tokenizer.eos_token_id] + [tokenizer.pad_token_id] * 75
            result = result[:77]
            result = torch.tensor([result]).to(device=device, dtype=torch.int64)
            return result

        @torch.inference_mode()
        def merge_with_prefix(bag):
            merged_ids_t1 = copy.deepcopy(prefix_ids_t1)
            merged_ids_t2 = copy.deepcopy(prefix_ids_t2)

            for item in bag:
                merged_ids_t1.extend(item['ids_t1'])
                merged_ids_t2.extend(item['ids_t2'])

            return dict(
                ids_t1=get_77_tokens_in_torch(merged_ids_t1, self.tokenizer),
                ids_t2=get_77_tokens_in_torch(merged_ids_t2, self.tokenizer_2)
            )

        @torch.inference_mode()
        def double_encode(pair_of_inds):
            inds = [pair_of_inds['ids_t1'], pair_of_inds['ids_t2']]
            text_encoders = [self.text_encoder, self.text_encoder_2]

            pooled_prompt_embeds = None
            prompt_embeds_list = []

            for text_input_ids, text_encoder in zip(inds, text_encoders):
                prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)

                # Only last pooler_output is needed
                pooled_prompt_embeds = prompt_embeds.pooler_output

                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-2]
                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
            return prompt_embeds, pooled_prompt_embeds

        # Begin with tokenizing prefixes

        prefix_length = 0
        prefix_ids_t1 = []
        prefix_ids_t2 = []

        for prefix in prefixes:
            ids_t1 = self.tokenizer(prefix, truncation=False, add_special_tokens=False).input_ids
            ids_t2 = self.tokenizer_2(prefix, truncation=False, add_special_tokens=False).input_ids
            assert len(ids_t1) == len(ids_t2)
            prefix_length += len(ids_t1)
            prefix_ids_t1 += ids_t1
            prefix_ids_t2 += ids_t2

        # Then tokenizing suffixes

        allowed_suffix_length = 75 - prefix_length
        suffix_targets = []

        for subprompt in suffixes:
            # Note that all subprompt are theoretically less than 75 tokens (without bos/eos)
            # So we can safely just crop it to 75
            ids_t1 = self.tokenizer(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
            ids_t2 = self.tokenizer_2(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
            assert len(ids_t1) == len(ids_t2)
            suffix_targets.append(dict(
                length=len(ids_t1),
                ids_t1=ids_t1,
                ids_t2=ids_t2
            ))

        # Then merge prefix and suffix tokens

        suffix_targets = greedy_partition(suffix_targets, max_sum=allowed_suffix_length)
        targets = [merge_with_prefix(b) for b in suffix_targets]

        # Encode!

        conds, poolers = [], []

        for target in targets:
            cond, pooler = double_encode(target)
            conds.append(cond)
            poolers.append(pooler)

        conds_merged = torch.concat(conds, dim=1)
        poolers_merged = poolers[0]

        return dict(cond=conds_merged, pooler=poolers_merged)

    @torch.inference_mode()
    def all_conds_from_canvas(self, canvas_outputs, negative_prompt):
        mask_all = torch.ones(size=(90, 90), dtype=torch.float32)
        negative_cond, negative_pooler = self.encode_cropped_prompt_77tokens(negative_prompt)
        negative_result = [(mask_all, negative_cond)]

        positive_result = []
        positive_pooler = None

        for item in canvas_outputs['bag_of_conditions']:
            current_mask = torch.from_numpy(item['mask']).to(torch.float32)
            current_prefixes = item['prefixes']
            current_suffixes = item['suffixes']

            current_cond = self.encode_bag_of_subprompts_greedy(prefixes=current_prefixes, suffixes=current_suffixes)

            if positive_pooler is None:
                positive_pooler = current_cond['pooler']

            positive_result.append((current_mask, current_cond['cond']))

        return positive_result, positive_pooler, negative_result, negative_pooler

    @torch.inference_mode()
    def encode_cropped_prompt_77tokens(self, prompt: str):
        device = self.text_encoder.device
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        pooled_prompt_embeds = None
        prompt_embeds_list = []

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_input_ids = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # Only last pooler_output is needed
            pooled_prompt_embeds = prompt_embeds.pooler_output

            # "2" because SDXL always indexes from the penultimate layer.
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

        return prompt_embeds, pooled_prompt_embeds

    @torch.inference_mode()
    def __call__(
            self,
            initial_latent: torch.FloatTensor = None,
            strength: float = 1.0,
            num_inference_steps: int = 25,
            guidance_scale: float = 5.0,
            batch_size: Optional[int] = 1,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[dict] = None,
    ):

        device = self.unet.device
        cross_attention_kwargs = cross_attention_kwargs or {}

        # Sigmas

        sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps / strength))
        sigmas = sigmas[-(num_inference_steps + 1):].to(device)

        # Initial latents

        _, C, H, W = initial_latent.shape
        noise = randn_tensor((batch_size, C, H, W), generator=generator, device=device, dtype=self.unet.dtype)
        latents = initial_latent.to(noise) + noise * sigmas[0].to(noise)

        # Shape

        height, width = latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor

        add_time_ids = list((height, width) + (0, 0) + (height, width))
        add_time_ids = torch.tensor([add_time_ids], dtype=self.unet.dtype)
        add_neg_time_ids = add_time_ids.clone()

        # Batch

        latents = latents.to(device)
        add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
        add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
        prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in prompt_embeds]
        negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in negative_prompt_embeds]
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(noise)

        # Feeds

        sampler_kwargs = dict(
            cfg_scale=guidance_scale,
            positive=dict(
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
                cross_attention_kwargs=cross_attention_kwargs
            ),
            negative=dict(
                encoder_hidden_states=negative_prompt_embeds,
                added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
                cross_attention_kwargs=cross_attention_kwargs
            )
        )

        # Sample

        results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False)

        return StableDiffusionXLPipelineOutput(images=results)
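
For orientation, a hedged sketch of how these pieces might be wired together end to end. The model id, the `llm_response` string, and the `encode_layout_to_latent` helper are assumptions for illustration (the repo's app.py performs the real model loading and VAE encoding, which this diff does not show); everything else uses only the functions defined above.

import torch
from lib_omost.canvas import Canvas
from lib_omost.pipeline import StableDiffusionXLOmostPipeline

# Assumption: SDXL base weights load into the subclass the same way they
# would for any StableDiffusionXLImg2ImgPipeline.
pipeline = StableDiffusionXLOmostPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16
).to('cuda')

canvas = Canvas.from_bot_response(llm_response)  # llm_response: assistant reply text
canvas_outputs = canvas.process()

positive, positive_pooler, negative, negative_pooler = pipeline.all_conds_from_canvas(
    canvas_outputs, negative_prompt='lowres, bad anatomy, bad hands'
)

# Hypothetical helper: resize the 90x90 color layout to the target resolution
# and encode it with the VAE into a (1, 4, H/8, W/8) latent.
initial_latent = encode_layout_to_latent(canvas_outputs['initial_latent'])

latents = pipeline(
    initial_latent=initial_latent,
    strength=1.0,
    num_inference_steps=25,
    guidance_scale=5.0,
    prompt_embeds=positive,
    negative_prompt_embeds=negative,
    pooled_prompt_embeds=positive_pooler,
    negative_pooled_prompt_embeds=negative_pooler,
).images  # note: this pipeline returns latents; decode with the VAE for pixels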