junnyu committed on
Commit be7aa5c
1 Parent(s): 83feac1

Update pipeline.py

Files changed (1)
  1. pipeline.py +119 -19
pipeline.py CHANGED
@@ -17,6 +17,7 @@
 # Here is the AGPL-3.0 license https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt

 import inspect
+import shutil
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -32,7 +33,9 @@ from ppdiffusers.pipelines.stable_diffusion.safety_checker import (
 )
 from ppdiffusers.schedulers import KarrasDiffusionSchedulers
 from ppdiffusers.utils import (
+    PPDIFFUSERS_CACHE,
     logging,
+    ppdiffusers_url_download,
     randn_tensor,
     safetensors_load,
     smart_load,
@@ -42,6 +45,64 @@ from ppdiffusers.utils import (
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+import copy
+import os
+import os.path
+
+from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
+
+# lark omegaconf
+
+
+def get_civitai_download_url(display_url, url_prefix="https://civitai.com"):
+    if "api/download" in display_url:
+        return display_url
+    import bs4
+    import requests
+
+    headers = {
+        "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36 QIHU 360SE"
+    }
+    r = requests.get(display_url, headers=headers)
+    soup = bs4.BeautifulSoup(r.text, "lxml")
+    download_url = None
+    for a in soup.find_all("a", href=True):
+        if "Download" in str(a):
+            download_url = url_prefix + a["href"].split("?")[0]
+            break
+    return download_url
+
+
+def http_file_name(
+    url: str,
+    *,
+    proxies=None,
+    headers: Optional[Dict[str, str]] = None,
+    timeout=10.0,
+    max_retries=0,
+):
+    """
+    Get a remote file name.
+    """
+    headers = copy.deepcopy(headers) or {}
+    r = _request_wrapper(
+        method="GET",
+        url=url,
+        stream=True,
+        proxies=proxies,
+        headers=headers,
+        timeout=timeout,
+        max_retries=max_retries,
+    )
+    hf_raise_for_status(r)
+    displayed_name = url.split("/")[-1]
+    content_disposition = r.headers.get("Content-Disposition")
+    if content_disposition is not None and "filename=" in content_disposition:
+        # Means file is on CDN
+        displayed_name = content_disposition.split("filename=")[-1]
+    return displayed_name
+
+
 @paddle.no_grad()
 def load_lora(
     pipeline,
@@ -164,6 +225,8 @@ class WebUIStableDiffusionPipeline(DiffusionPipeline):
     _optional_components = ["safety_checker", "feature_extractor"]
     enable_emphasis = True
     comma_padding_backtrack = 20
+    LORA_DIR = os.path.join(PPDIFFUSERS_CACHE, "lora")
+    TI_DIR = os.path.join(PPDIFFUSERS_CACHE, "textual_inversion")

     def __init__(
         self,
@@ -227,7 +290,17 @@ class WebUIStableDiffusionPipeline(DiffusionPipeline):
         ]
         self.weights_has_changed = False

-    def add_ti_embedding_dir(self, embeddings_dir):
+        # register_state_dict_hook to fix the text_encoder when we save_pretrained the text model.
+        def map_to(state_dict, *args, **kwargs):
+            if "text_model.token_embedding.wrapped.weight" in state_dict:
+                state_dict["text_model.token_embedding.weight"] = state_dict.pop(
+                    "text_model.token_embedding.wrapped.weight"
+                )
+            return state_dict
+
+        self.text_encoder.register_state_dict_hook(map_to)
+
+    def add_ti_embedding_dir(self, embeddings_dir=None):
         self.sj.embedding_db.add_embedding_dir(embeddings_dir)
         self.sj.embedding_db.load_textual_inversion_embeddings()

@@ -235,6 +308,30 @@ class WebUIStableDiffusionPipeline(DiffusionPipeline):
         self.sj.embedding_db.clear_embedding_dirs()
         self.sj.embedding_db.load_textual_inversion_embeddings(True)

+    def download_civitai_lora_file(self, url):
+        if os.path.isfile(url):
+            dst = os.path.join(self.LORA_DIR, os.path.basename(url))
+            shutil.copyfile(url, dst)
+            return dst
+
+        download_url = get_civitai_download_url(url) or url
+        file_path = ppdiffusers_url_download(
+            download_url, cache_dir=self.LORA_DIR, filename=http_file_name(download_url).strip('"')
+        )
+        return file_path
+
+    def download_civitai_ti_file(self, url):
+        if os.path.isfile(url):
+            dst = os.path.join(self.TI_DIR, os.path.basename(url))
+            shutil.copyfile(url, dst)
+            return dst
+
+        download_url = get_civitai_download_url(url) or url
+        file_path = ppdiffusers_url_download(
+            download_url, cache_dir=self.TI_DIR, filename=http_file_name(download_url).strip('"')
+        )
+        return file_path
+
     def change_scheduler(self, scheduler_type="ddim"):
         self.switch_scheduler(scheduler_type)

@@ -408,7 +505,6 @@ class WebUIStableDiffusionPipeline(DiffusionPipeline):
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: int = 1,
-        lora_dir: str = "./loras",
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -459,10 +555,8 @@ class WebUIStableDiffusionPipeline(DiffusionPipeline):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            clip_skip (`int`, *optional*, defaults to 0):
+            clip_skip (`int`, *optional*, defaults to 1):
                 CLIP_stop_at_last_layers; if clip_skip <= 1, we use the last_hidden_state from the text_encoder.
-            lora_dir (`str`, *optional*):
-                Path to lora which we want to load.
         Examples:

         Returns:
@@ -472,6 +566,8 @@ class WebUIStableDiffusionPipeline(DiffusionPipeline):
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
         """
+        self.add_ti_embedding_dir(self.TI_DIR)
+
         try:
             # 0. Default height and width to unet
             height = height or max(self.unet.config.sample_size * self.vae_scale_factor, 512)
@@ -495,19 +591,23 @@ class WebUIStableDiffusionPipeline(DiffusionPipeline):

         prompts, extra_network_data = parse_prompts([prompt])

-        if lora_dir is not None and os.path.exists(lora_dir):
-            lora_mapping = {p.stem: p.absolute() for p in Path(lora_dir).glob("*.safetensors")}
-            for params in extra_network_data["lora"]:
-                assert len(params.items) > 0
-                name = params.items[0]
-                if name in lora_mapping:
-                    ratio = float(params.items[1]) if len(params.items) > 1 else 1.0
-                    lora_state_dict = smart_load(lora_mapping[name], map_location=paddle.get_device())
-                    self.weights_has_changed = True
-                    load_lora(self, state_dict=lora_state_dict, ratio=ratio)
-                    del lora_state_dict
-                else:
-                    print(f"We can't find lora weight: {name}! Please make sure that exists!")
+        if self.LORA_DIR is not None:
+            if os.path.exists(self.LORA_DIR):
+                lora_mapping = {p.stem: p.absolute() for p in Path(self.LORA_DIR).glob("*.safetensors")}
+                for params in extra_network_data["lora"]:
+                    assert len(params.items) > 0
+                    name = params.items[0]
+                    if name in lora_mapping:
+                        ratio = float(params.items[1]) if len(params.items) > 1 else 1.0
+                        lora_state_dict = smart_load(lora_mapping[name], map_location=paddle.get_device())
+                        self.weights_has_changed = True
+                        load_lora(self, state_dict=lora_state_dict, ratio=ratio)
+                        del lora_state_dict
+                    else:
+                        print(f"We can't find LoRA weight: {name}! Please make sure it exists!")
+            else:
+                if len(extra_network_data["lora"]) > 0:
+                    print(f"{self.LORA_DIR} does not exist, so we can't load LoRAs!")

         self.sj.clip.CLIP_stop_at_last_layers = clip_skip
         # 3. Encode input prompt
@@ -1658,7 +1758,7 @@ class EmbeddingDatabase:
         self.previously_displayed_embeddings = ()

     def add_embedding_dir(self, path):
-        if path is not None:
+        if path is not None and path not in self.embedding_dirs:
            self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

     def clear_embedding_dirs(self):
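Taken together, this commit wires Civitai downloads into the webui-style LoRA flow: files land in the new LORA_DIR cache, and __call__ resolves <lora:name:ratio> prompt tokens against that directory by file stem. Below is a minimal usage sketch of the result; the checkpoint id, the civitai.com URL, and the LoRA name are placeholders, and the sketch assumes the usual StableDiffusionPipelineOutput return value.

from pipeline import WebUIStableDiffusionPipeline  # this repo's pipeline.py

pipe = WebUIStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# A local .safetensors path is copied into LORA_DIR; a civitai.com page URL is
# resolved to its download link and fetched into the same cache.
pipe.download_civitai_lora_file("https://civitai.com/models/xxxx")

# LoRAs in LORA_DIR are keyed by file stem and activated per generation with
# the webui-style prompt token <lora:name:ratio> (ratio defaults to 1.0).
image = pipe(prompt="masterpiece, best quality, <lora:some_lora:0.8>").images[0]
image.save("demo.png")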
 
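One detail worth noting in http_file_name: for a CDN hit it returns the raw token after "filename=", quotes included, which is why the new download methods call .strip('"') on the result. A self-contained illustration of that behavior:

# Content-Disposition parsing as http_file_name does it (illustration only).
content_disposition = 'attachment; filename="some_lora.safetensors"'
name = content_disposition.split("filename=")[-1]
print(name)             # "some_lora.safetensors"  (quotes still attached)
print(name.strip('"'))  # some_lora.safetensors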
 
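The map_to hook registered in __init__ exists because the textual-inversion machinery appears to wrap the token embedding, so a plain save_pretrained would otherwise serialize its weight under a "wrapped." key; the hook renames it back to the vanilla key. A toy reproduction of the renaming, with a plain dict standing in for the real state dict:

state_dict = {"text_model.token_embedding.wrapped.weight": "W"}
if "text_model.token_embedding.wrapped.weight" in state_dict:
    state_dict["text_model.token_embedding.weight"] = state_dict.pop(
        "text_model.token_embedding.wrapped.weight"
    )
print(state_dict)  # {'text_model.token_embedding.weight': 'W'}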
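Textual-inversion files follow the same download pattern, and __call__ now registers TI_DIR on every generation, which is what the new guard in EmbeddingDatabase.add_embedding_dir protects against re-registering. Continuing the sketch above with the same pipe; the URL and embedding name are placeholders, and the trigger-token convention (the file stem appearing in the prompt) follows the upstream webui behavior rather than anything stated in this diff.

# Downloads (or copies) the file into TI_DIR; the next __call__ picks it up
# via add_ti_embedding_dir(self.TI_DIR) and load_textual_inversion_embeddings().
pipe.download_civitai_ti_file("https://civitai.com/models/yyyy")

# The embedding is triggered by its token (typically the file stem) in the prompt.
image = pipe(prompt="a portrait, some_embedding").images[0]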