Commit ad6476f · 1 Parent(s): 57aa583
adaface-neurips committed

Better organize code

Files changed (2):
  1. app.py +2 -27
  2. pipline_ConsistentID.py +38 -56
app.py CHANGED

@@ -10,11 +10,6 @@ from PIL import Image
 from diffusers.utils import load_image
 from diffusers import EulerDiscreteScheduler
 from pipline_ConsistentID import ConsistentIDPipeline
-from huggingface_hub import hf_hub_download
-### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
-### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
-### Thanks for the open source of face-parsing model.
-from models.BiSeNet.model import BiSeNet
 
 # zero = torch.Tensor([0]).cuda()
 # print(zero.device) # <-- 'cpu' 🤔
@@ -26,9 +21,6 @@ script_directory = os.path.dirname(os.path.realpath(__file__))
 
 # download ConsistentID checkpoint to cache
 base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
-consistentID_path = hf_hub_download(repo_id="JackAILab/ConsistentID",
-                                    filename="ConsistentID-v1.bin", repo_type="model",
-                                    local_dir="./models")
 
 ### Load base model
 pipe = ConsistentIDPipeline.from_pretrained(
@@ -38,30 +30,13 @@ pipe = ConsistentIDPipeline.from_pretrained(
     variant="fp16"
 ).to(device)
 
-### Load other pretrained models
-## BiSenet
-bise_net_cp_path = hf_hub_download(repo_id="JackAILab/ConsistentID",
-                                   filename="face_parsing.pth", local_dir="./models")
-bise_net = BiSeNet(n_classes = 19)
-bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu")) # device fail
-bise_net.cuda()
-
 ### Load consistentID_model checkpoint
 pipe.load_ConsistentID_model(
-    os.path.dirname(consistentID_path),
-    bise_net,
-    subfolder="",
-    weight_name=os.path.basename(consistentID_path),
-    trigger_word="img",
+    consistentID_weight_path="./models/ConsistentID-v1.bin",
+    bise_net_weight_path="./models/face_parsing.pth",
 )
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 
-### Load to cuda
-pipe.to(device)
-pipe.clip_encoder.to(device)
-pipe.image_proj_model.to(device)
-pipe.FacialEncoder.to(device)
-
 
 @spaces.GPU
 def process(selected_template_images, custom_image, prompt,
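
After this commit, app.py only supplies the two weight paths; checkpoint download, BiSeNet construction, and device placement all happen inside the pipeline. A minimal usage sketch of the new setup (the torch_dtype argument is an assumption inferred from variant="fp16"; the from_pretrained arguments elided in the diff may differ):

import torch
from diffusers import EulerDiscreteScheduler
from pipline_ConsistentID import ConsistentIDPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = ConsistentIDPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE",
    torch_dtype=torch.float16,   # assumed; pairs with variant="fp16"
    variant="fp16",
).to(device)

# The ConsistentID checkpoint and the BiSeNet face-parsing weights are now
# downloaded (if absent) and loaded inside load_ConsistentID_model.
pipe.load_ConsistentID_model(
    consistentID_weight_path="./models/ConsistentID-v1.bin",
    bise_net_weight_path="./models/face_parsing.pth",
)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
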
pipline_ConsistentID.py CHANGED

@@ -17,6 +17,12 @@ from functions import insert_markers_to_prompt, masks_for_unique_values, apply_m
 from functions import ProjPlusModel, masks_for_unique_values
 from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
 from easydict import EasyDict as edict
+from huggingface_hub import hf_hub_download
+### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
+### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
+### Thanks for the open source of face-parsing model.
+from models.BiSeNet.model import BiSeNet
+import os
 
 PipelineImageInput = Union[
     PIL.Image.Image,
@@ -51,11 +57,8 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
     @validate_hf_hub_args
     def load_ConsistentID_model(
         self,
-        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
-        bise_net,
-        weight_name: str,
-        subfolder: str = '',
-        trigger_word_ID: str = '<|image|>',
+        consistentID_weight_path: str,
+        bise_net_weight_path: str,
         trigger_word_facial: str = '<|facial|>',
         # A CLIP ViT-H/14 model trained with the LAION-2B English subset of LAION-5B using OpenCLIP.
         # output dim: 1280.
@@ -73,7 +76,7 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         self.clip_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
             self.device, dtype=self.torch_dtype
         )
-        self.clip_preprocessor = CLIPImageProcessor()
+        self.clip_preprocessor = CLIPImageProcessor()
         self.id_image_processor = CLIPImageProcessor()
         self.crop_size = 512
 
@@ -81,13 +84,22 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         self.app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
         self.app.prepare(ctx_id=0, det_size=(640, 640))
 
-        ### BiSeNet
-        # self.bise_net = BiSeNet(n_classes = 19)
-        # self.bise_net.cuda() # CUDA must not be initialized in the main process on Spaces with Stateless GPU environment
-        # self.bise_net_cp=bise_net_cp_path
-        # self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
-        self.bise_net = bise_net # load from outside
-        self.bise_net.eval()
+        if not os.path.exists(consistentID_weight_path):
+            ### Download pretrained models
+            hf_hub_download(repo_id="JackAILab/ConsistentID", repo_type="model",
+                            filename=os.path.basename(consistentID_weight_path),
+                            local_dir=os.path.dirname(consistentID_weight_path))
+        if not os.path.exists(bise_net_weight_path):
+            hf_hub_download(repo_id="JackAILab/ConsistentID",
+                            filename=os.path.basename(bise_net_weight_path),
+                            local_dir=os.path.dirname(bise_net_weight_path))
+
+        bise_net = BiSeNet(n_classes = 19)
+        bise_net.load_state_dict(torch.load(bise_net_weight_path, map_location="cpu"))
+        bise_net.to(self.device, dtype=self.torch_dtype)
+        bise_net.eval()
+        self.bise_net = bise_net
+
         # Colors for all 20 parts
         self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                             [255, 0, 85], [255, 0, 170],
@@ -108,47 +120,18 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         ).to(self.device, dtype=self.torch_dtype)
         self.FacialEncoder = FacialEncoder().to(self.device, dtype=self.torch_dtype)
 
-        # Load the main state dict first.
-        cache_dir = kwargs.pop("cache_dir", None)
-        force_download = kwargs.pop("force_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", None)
-        token = kwargs.pop("token", None)
-        revision = kwargs.pop("revision", None)
-
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
-
-        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            if weight_name.endswith(".safetensors"):
-                state_dict = {"id_encoder": {}, "lora_weights": {}}
-                with safe_open(model_file, framework="pt", device="cpu") as f:
-                    ### TODO safetensors add
-                    for key in f.keys():
-                        if key.startswith("FacialEncoder."):
-                            state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
-                        elif key.startswith("image_proj."):
-                            state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
-            else:
-                state_dict = torch.load(model_file, map_location="cpu")
+        if consistentID_weight_path.endswith(".safetensors"):
+            state_dict = {"id_encoder": {}, "lora_weights": {}}
+            with safe_open(consistentID_weight_path, framework="pt", device="cpu") as f:
+                ### TODO safetensors add
+                for key in f.keys():
+                    if key.startswith("FacialEncoder."):
+                        state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
+                    elif key.startswith("image_proj."):
+                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
         else:
-            state_dict = pretrained_model_name_or_path_or_dict
-
-        self.trigger_word_ID = trigger_word_ID
+            state_dict = torch.load(consistentID_weight_path, map_location="cpu")
+
         self.trigger_word_facial = trigger_word_facial
 
         self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True)
@@ -159,7 +142,6 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
 
         # Add trigger word token
         if self.tokenizer is not None:
-            self.tokenizer.add_tokens([self.trigger_word_ID], special_tokens=True)
             self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True)
 
     def set_ip_adapter(self):
@@ -264,7 +246,7 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         image_resize_PIL = image
         img = to_tensor(image)
        img = torch.unsqueeze(img, 0)
-        img = img.float().cuda()
+        img = img.to(self.device, dtype=self.torch_dtype)
         out = self.bise_net(img)[0]
         parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)
@@ -337,7 +319,7 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         # Remove "<|facial|>" from prompt_face.
         # augmented_prompt: 'A person, police officer, half body shot Detail:
         # The person has one nose , two ears , two eyes , and a mouth , '
-        augmented_prompt = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
+        augmented_prompt = prompt_face.replace("<|facial|>", "")
         tokenizer = self.tokenizer
         facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
         image_token_id = None
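
The core of the reorganization is the download-if-missing pattern now living inside load_ConsistentID_model, which keeps app.py free of hf_hub_download calls. A standalone sketch of that pattern (the ensure_weight helper name is hypothetical; the repo_id and file names mirror the diff):

import os
from huggingface_hub import hf_hub_download

def ensure_weight(weight_path: str, repo_id: str = "JackAILab/ConsistentID") -> str:
    # Hypothetical helper illustrating the pattern used in load_ConsistentID_model:
    # fetch the file from the Hub only when it is not already on disk.
    if not os.path.exists(weight_path):
        hf_hub_download(repo_id=repo_id, repo_type="model",
                        filename=os.path.basename(weight_path),
                        local_dir=os.path.dirname(weight_path))
    return weight_path

# Usage mirroring the new defaults in app.py:
consistentID_path = ensure_weight("./models/ConsistentID-v1.bin")
bise_net_path = ensure_weight("./models/face_parsing.pth")

Constructing BiSeNet inside the pipeline also lets it inherit self.device and self.torch_dtype, which is why the hard-coded img.float().cuda() call could become img.to(self.device, dtype=self.torch_dtype).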