Menyu commited on
Commit
0d23237
·
verified ·
1 Parent(s): 0813976

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -14
app.py CHANGED
@@ -3,9 +3,221 @@ import gradio as gr
3
  import numpy as np
4
  import spaces
5
  import torch
6
- from diffusers import AutoPipelineForText2Image, AutoencoderKL #,EulerDiscreteScheduler
7
  from compel import Compel, ReturnedEmbeddingsType
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  if not torch.cuda.is_available():
10
  DESCRIPTION += "\n<p>你现在运行在CPU上 但是此项目只支持GPU.</p>"
11
 
@@ -14,8 +226,6 @@ MAX_IMAGE_SIZE = 4096
14
 
15
  if torch.cuda.is_available():
16
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
17
- #vae = AutoencoderKL.from_pretrained("https://huggingface.co/scrapware/personal-backup/resolve/main/bakedvae/anyloraCheckpoint_bakedvaeTwinkleFp16.safetensors", torch_dtype=torch.float16)
18
-
19
  pipe = AutoPipelineForText2Image.from_pretrained(
20
  "John6666/noobai-xl-nai-xl-epsilonpred10version-sdxl",
21
  vae=vae,
@@ -23,7 +233,6 @@ if torch.cuda.is_available():
23
  use_safetensors=True,
24
  add_watermarker=False
25
  )
26
- #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
27
  pipe.to("cuda")
28
 
29
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
@@ -47,14 +256,27 @@ def infer(
47
  ):
48
  seed = int(randomize_seed_fn(seed, randomize_seed))
49
  generator = torch.Generator().manual_seed(seed)
50
- compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2] , text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
51
- conditioning, pooled = compel(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
52
 
 
53
  image = pipe(
54
- #prompt=prompt,
55
- prompt_embeds=conditioning,
56
- pooled_prompt_embeds=pooled,
57
- negative_prompt=negative_prompt,
58
  width=width,
59
  height=height,
60
  guidance_scale=guidance_scale,
@@ -76,10 +298,10 @@ footer {
76
  visibility: hidden
77
  }
78
  '''
79
-
80
  with gr.Blocks(css=css) as demo:
81
  gr.Markdown("""# 梦羽的模型生成器
82
- ### 快速生成NoobAIXL v1.0的模型图片""")
83
  with gr.Group():
84
  with gr.Row():
85
  prompt = gr.Text(
@@ -147,7 +369,7 @@ with gr.Blocks(css=css) as demo:
147
  outputs=[result, seed],
148
  fn=infer
149
  )
150
-
151
  use_negative_prompt.change(
152
  fn=lambda x: gr.update(visible=x),
153
  inputs=use_negative_prompt,
@@ -155,7 +377,7 @@ with gr.Blocks(css=css) as demo:
155
  )
156
 
157
  gr.on(
158
- triggers=[prompt.submit,run_button.click],
159
  fn=infer,
160
  inputs=[
161
  prompt,
 
3
  import numpy as np
4
  import spaces
5
  import torch
6
+ from diffusers import AutoPipelineForText2Image, AutoencoderKL
7
  from compel import Compel, ReturnedEmbeddingsType
8
 
9
+ import re
10
+
11
+ # =====================================
12
+ # Prompt weights
13
+ # =====================================
14
+ import torch
15
+ import re
16
+ def parse_prompt_attention(text):
17
+ re_attention = re.compile(r"""
18
+ \\\(|
19
+ \\\)|
20
+ \\\[|
21
+ \\]|
22
+ \\\\|
23
+ \\|
24
+ \(|
25
+ \[|
26
+ :([+-]?[.\d]+)\)|
27
+ \)|
28
+ ]|
29
+ [^\\()\[\]:]+|
30
+ :
31
+ """, re.X)
32
+
33
+ res = []
34
+ round_brackets = []
35
+ square_brackets = []
36
+
37
+ round_bracket_multiplier = 1.1
38
+ square_bracket_multiplier = 1 / 1.1
39
+
40
+ def multiply_range(start_position, multiplier):
41
+ for p in range(start_position, len(res)):
42
+ res[p][1] *= multiplier
43
+
44
+ for m in re_attention.finditer(text):
45
+ text = m.group(0)
46
+ weight = m.group(1)
47
+
48
+ if text.startswith('\\'):
49
+ res.append([text[1:], 1.0])
50
+ elif text == '(':
51
+ round_brackets.append(len(res))
52
+ elif text == '[':
53
+ square_brackets.append(len(res))
54
+ elif weight is not None and len(round_brackets) > 0:
55
+ multiply_range(round_brackets.pop(), float(weight))
56
+ elif text == ')' and len(round_brackets) > 0:
57
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
58
+ elif text == ']' and len(square_brackets) > 0:
59
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
60
+ else:
61
+ parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
62
+ for i, part in enumerate(parts):
63
+ if i > 0:
64
+ res.append(["BREAK", -1])
65
+ res.append([part, 1.0])
66
+
67
+ for pos in round_brackets:
68
+ multiply_range(pos, round_bracket_multiplier)
69
+
70
+ for pos in square_brackets:
71
+ multiply_range(pos, square_bracket_multiplier)
72
+
73
+ if len(res) == 0:
74
+ res = [["", 1.0]]
75
+
76
+ # merge runs of identical weights
77
+ i = 0
78
+ while i + 1 < len(res):
79
+ if res[i][1] == res[i + 1][1]:
80
+ res[i][0] += res[i + 1][0]
81
+ res.pop(i + 1)
82
+ else:
83
+ i += 1
84
+
85
+ return res
86
+
87
+ def prompt_attention_to_invoke_prompt(attention):
88
+ tokens = []
89
+ for text, weight in attention:
90
+ # Round weight to 2 decimal places
91
+ weight = round(weight, 2)
92
+ if weight == 1.0:
93
+ tokens.append(text)
94
+ elif weight < 1.0:
95
+ if weight < 0.8:
96
+ tokens.append(f"({text}){weight}")
97
+ else:
98
+ tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
99
+ else:
100
+ if weight < 1.3:
101
+ tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
102
+ else:
103
+ tokens.append(f"({text}){weight}")
104
+ return "".join(tokens)
105
+
106
+ def concat_tensor(t):
107
+ t_list = torch.split(t, 1, dim=0)
108
+ t = torch.cat(t_list, dim=1)
109
+ return t
110
+
111
+ def merge_embeds(prompt_chanks, compel):
112
+ num_chanks = len(prompt_chanks)
113
+ if num_chanks != 0:
114
+ power_prompt = 1/(num_chanks*(num_chanks+1)//2)
115
+ prompt_embs = compel(prompt_chanks)
116
+ t_list = list(torch.split(prompt_embs, 1, dim=0))
117
+ for i in range(num_chanks):
118
+ t_list[-(i+1)] = t_list[-(i+1)] * ((i+1)*power_prompt)
119
+ prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
120
+ else:
121
+ prompt_emb = compel('')
122
+ return prompt_emb
123
+
124
+ def detokenize(chunk, actual_prompt):
125
+ chunk[-1] = chunk[-1].replace('</w>', '')
126
+ chanked_prompt = ''.join(chunk).strip()
127
+ while '</w>' in chanked_prompt:
128
+ if actual_prompt[chanked_prompt.find('</w>')] == ' ':
129
+ chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
130
+ else:
131
+ chanked_prompt = chanked_prompt.replace('</w>', '', 1)
132
+ actual_prompt = actual_prompt.replace(chanked_prompt,'')
133
+ return chanked_prompt.strip(), actual_prompt.strip()
134
+
135
+ def tokenize_line(line, tokenizer): # split into chunks
136
+ actual_prompt = line.lower().strip()
137
+ actual_tokens = tokenizer.tokenize(actual_prompt)
138
+ max_tokens = tokenizer.model_max_length - 2
139
+ comma_token = tokenizer.tokenize(',')[0]
140
+
141
+ chunks = []
142
+ chunk = []
143
+ for item in actual_tokens:
144
+ chunk.append(item)
145
+ if len(chunk) == max_tokens:
146
+ if chunk[-1] != comma_token:
147
+ for i in range(max_tokens-1, -1, -1):
148
+ if chunk[i] == comma_token:
149
+ actual_chunk, actual_prompt = detokenize(chunk[:i+1], actual_prompt)
150
+ chunks.append(actual_chunk)
151
+ chunk = chunk[i+1:]
152
+ break
153
+ else:
154
+ actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
155
+ chunks.append(actual_chunk)
156
+ chunk = []
157
+ else:
158
+ actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
159
+ chunks.append(actual_chunk)
160
+ chunk = []
161
+ if chunk:
162
+ actual_chunk, _ = detokenize(chunk, actual_prompt)
163
+ chunks.append(actual_chunk)
164
+
165
+ return chunks
166
+
167
+ def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
168
+
169
+ if compel_process_sd:
170
+ return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)
171
+ else:
172
+ # fix bug weights conversion excessive emphasis
173
+ prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")
174
+
175
+ # Convert to Compel
176
+ attention = parse_prompt_attention(prompt)
177
+ global_attention_chanks = []
178
+
179
+ for att in attention:
180
+ for chank in att[0].split(','):
181
+ temp_prompt_chanks = tokenize_line(chank, pipeline.tokenizer)
182
+ for small_chank in temp_prompt_chanks:
183
+ temp_dict = {
184
+ "weight": round(att[1], 2),
185
+ "lenght": len(pipeline.tokenizer.tokenize(f'{small_chank},')),
186
+ "prompt": f'{small_chank},'
187
+ }
188
+ global_attention_chanks.append(temp_dict)
189
+
190
+ max_tokens = pipeline.tokenizer.model_max_length - 2
191
+ global_prompt_chanks = []
192
+ current_list = []
193
+ current_length = 0
194
+ for item in global_attention_chanks:
195
+ if current_length + item['lenght'] > max_tokens:
196
+ global_prompt_chanks.append(current_list)
197
+ current_list = [[item['prompt'], item['weight']]]
198
+ current_length = item['lenght']
199
+ else:
200
+ if not current_list:
201
+ current_list.append([item['prompt'], item['weight']])
202
+ else:
203
+ if item['weight'] != current_list[-1][1]:
204
+ current_list.append([item['prompt'], item['weight']])
205
+ else:
206
+ current_list[-1][0] += f" {item['prompt']}"
207
+ current_length += item['lenght']
208
+ if current_list:
209
+ global_prompt_chanks.append(current_list)
210
+
211
+ if only_convert_string:
212
+ return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks])
213
+
214
+ return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel)
215
+
216
+ def add_comma_after_pattern_ti(text):
217
+ pattern = re.compile(r'\b\w+_\d+\b')
218
+ modified_text = pattern.sub(lambda x: x.group() + ',', text)
219
+ return modified_text
220
+
221
  if not torch.cuda.is_available():
222
  DESCRIPTION += "\n<p>你现在运行在CPU上 但是此项目只支持GPU.</p>"
223
 
 
226
 
227
  if torch.cuda.is_available():
228
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 
 
229
  pipe = AutoPipelineForText2Image.from_pretrained(
230
  "John6666/noobai-xl-nai-xl-epsilonpred10version-sdxl",
231
  vae=vae,
 
233
  use_safetensors=True,
234
  add_watermarker=False
235
  )
 
236
  pipe.to("cuda")
237
 
238
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 
256
  ):
257
  seed = int(randomize_seed_fn(seed, randomize_seed))
258
  generator = torch.Generator().manual_seed(seed)
259
+ # 初始化 Compel 实例
260
+ compel = Compel(
261
+ tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
262
+ text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
263
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
264
+ requires_pooled=[False, True],
265
+ truncate_long_prompts=False
266
+ )
267
+ # 在 infer 函数中调用 get_embed_new
268
+ if not use_negative_prompt:
269
+ negative_prompt = ""
270
+ prompt = get_embed_new(prompt, pipe, compel, only_convert_string=True)
271
+ negative_prompt = get_embed_new(negative_prompt, pipe, compel, only_convert_string=True)
272
+ conditioning, pooled = compel([prompt, negative_prompt]) # 必须同时处理来保证长度相等
273
 
274
+ # 在调用 pipe 时,使用新的参数名称(确保参数名称正确)
275
  image = pipe(
276
+ prompt_embeds=conditioning[0:1],
277
+ pooled_prompt_embeds=pooled[0:1],
278
+ negative_prompt_embeds=conditioning[1:2],
279
+ negative_pooled_prompt_embeds=pooled[1:2],
280
  width=width,
281
  height=height,
282
  guidance_scale=guidance_scale,
 
298
  visibility: hidden
299
  }
300
  '''
301
+
302
  with gr.Blocks(css=css) as demo:
303
  gr.Markdown("""# 梦羽的模型生成器
304
+ ### 快速生成NoobAIXL v0.5的模型图片 V1.0模型在另一个项目上""")
305
  with gr.Group():
306
  with gr.Row():
307
  prompt = gr.Text(
 
369
  outputs=[result, seed],
370
  fn=infer
371
  )
372
+
373
  use_negative_prompt.change(
374
  fn=lambda x: gr.update(visible=x),
375
  inputs=use_negative_prompt,
 
377
  )
378
 
379
  gr.on(
380
+ triggers=[prompt.submit, run_button.click],
381
  fn=infer,
382
  inputs=[
383
  prompt,