hysts (HF staff) committed
Commit dda8135
Parent: 1d528a6
Files changed (6)
  1. .pre-commit-config.yaml +60 -36
  2. .style.yapf +0 -5
  3. .vscode/settings.json +30 -0
  4. app.py +34 -48
  5. model.py +41 -46
  6. style.css +1 -0
.pre-commit-config.yaml CHANGED
@@ -1,37 +1,61 @@
-exclude: ^patch.*
+exclude: ^patch
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.2.0
-  hooks:
-  - id: check-executables-have-shebangs
-  - id: check-json
-  - id: check-merge-conflict
-  - id: check-shebang-scripts-are-executable
-  - id: check-toml
-  - id: check-yaml
-  - id: double-quote-string-fixer
-  - id: end-of-file-fixer
-  - id: mixed-line-ending
-    args: ['--fix=lf']
-  - id: requirements-txt-fixer
-  - id: trailing-whitespace
-- repo: https://github.com/myint/docformatter
-  rev: v1.4
-  hooks:
-  - id: docformatter
-    args: ['--in-place']
-- repo: https://github.com/pycqa/isort
-  rev: 5.12.0
-  hooks:
-  - id: isort
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.991
-  hooks:
-  - id: mypy
-    args: ['--ignore-missing-imports']
-    additional_dependencies: ['types-python-slugify']
-- repo: https://github.com/google/yapf
-  rev: v0.32.0
-  hooks:
-  - id: yapf
-    args: ['--parallel', '--in-place']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.10.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
+  - repo: https://github.com/psf/black
+    rev: 24.4.2
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.7.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.8.5
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
-[style]
-based_on_style = pep8
-blank_line_before_nested_class_or_def = false
-spaces_before_comment = 2
-split_before_logical_operator = true
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
+{
+  "editor.formatOnSave": true,
+  "files.insertFinalNewline": false,
+  "[python]": {
+    "editor.defaultFormatter": "ms-python.black-formatter",
+    "editor.formatOnType": true,
+    "editor.codeActionsOnSave": {
+      "source.organizeImports": "explicit"
+    }
+  },
+  "[jupyter]": {
+    "files.insertFinalNewline": false
+  },
+  "black-formatter.args": [
+    "--line-length=119"
+  ],
+  "isort.args": ["--profile", "black"],
+  "flake8.args": [
+    "--max-line-length=119"
+  ],
+  "ruff.lint.args": [
+    "--line-length=119"
+  ],
+  "notebook.output.scrolling": true,
+  "notebook.formatOnCellExecution": true,
+  "notebook.formatOnSave.enabled": true,
+  "notebook.codeActionsOnSave": {
+    "source.organizeImports": "explicit"
+  }
+}
app.py CHANGED
@@ -8,14 +8,14 @@ import gradio as gr
 
 from model import Model
 
-DESCRIPTION = '''# [HairCLIP](https://github.com/wty-ustc/HairCLIP)
+DESCRIPTION = """# [HairCLIP](https://github.com/wty-ustc/HairCLIP)
 
 <center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>
-'''
+"""
 
 
 def load_hairstyle_list() -> list[str]:
-    with open('HairCLIP/mapper/hairstyle_list.txt') as f:
+    with open("HairCLIP/mapper/hairstyle_list.txt") as f:
         lines = [line.strip() for line in f.readlines()]
     lines = [line[:-10] for line in lines]
     return lines
@@ -27,78 +27,64 @@ def set_example_image(example: list) -> dict:
 
 def update_step2_components(choice: str) -> tuple[dict, dict]:
     return (
-        gr.Dropdown.update(visible=choice in ['hairstyle', 'both']),
-        gr.Textbox.update(visible=choice in ['color', 'both']),
+        gr.Dropdown.update(visible=choice in ["hairstyle", "both"]),
+        gr.Textbox.update(visible=choice in ["color", "both"]),
     )
 
 
 model = Model()
 
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     with gr.Box():
-        gr.Markdown('## Step 1')
+        gr.Markdown("## Step 1")
         with gr.Row():
             with gr.Column():
                 with gr.Row():
-                    input_image = gr.Image(label='Input Image',
-                                           type='filepath')
+                    input_image = gr.Image(label="Input Image", type="filepath")
                 with gr.Row():
-                    preprocess_button = gr.Button('Preprocess')
+                    preprocess_button = gr.Button("Preprocess")
             with gr.Column():
-                aligned_face = gr.Image(label='Aligned Face',
-                                        type='pil',
-                                        interactive=False)
+                aligned_face = gr.Image(label="Aligned Face", type="pil", interactive=False)
             with gr.Column():
-                reconstructed_face = gr.Image(label='Reconstructed Face',
-                                              type='numpy')
+                reconstructed_face = gr.Image(label="Reconstructed Face", type="numpy")
                 latent = gr.Variable()
 
         with gr.Row():
-            paths = sorted(pathlib.Path('images').glob('*.jpg'))
-            gr.Examples(examples=[[path.as_posix()] for path in paths],
-                        inputs=input_image)
+            paths = sorted(pathlib.Path("images").glob("*.jpg"))
+            gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)
 
     with gr.Box():
-        gr.Markdown('## Step 2')
+        gr.Markdown("## Step 2")
         with gr.Row():
             with gr.Column():
                 with gr.Row():
                     editing_type = gr.Radio(
-                        label='Editing Type',
-                        choices=['hairstyle', 'color', 'both'],
-                        value='both',
-                        type='value')
+                        label="Editing Type", choices=["hairstyle", "color", "both"], value="both", type="value"
+                    )
                 with gr.Row():
                     hairstyles = load_hairstyle_list()
-                    hairstyle_index = gr.Dropdown(label='Hairstyle',
-                                                  choices=hairstyles,
-                                                  value='afro',
-                                                  type='index')
+                    hairstyle_index = gr.Dropdown(label="Hairstyle", choices=hairstyles, value="afro", type="index")
                 with gr.Row():
-                    color_description = gr.Textbox(label='Color', value='red')
+                    color_description = gr.Textbox(label="Color", value="red")
                 with gr.Row():
-                    run_button = gr.Button('Run')
+                    run_button = gr.Button("Run")
 
             with gr.Column():
-                result = gr.Image(label='Result')
-
-    preprocess_button.click(fn=model.detect_and_align_face,
-                            inputs=input_image,
-                            outputs=aligned_face)
-    aligned_face.change(fn=model.reconstruct_face,
-                        inputs=aligned_face,
-                        outputs=[reconstructed_face, latent])
-    editing_type.change(fn=update_step2_components,
-                        inputs=editing_type,
-                        outputs=[hairstyle_index, color_description])
-    run_button.click(fn=model.generate,
-                     inputs=[
-                         editing_type,
-                         hairstyle_index,
-                         color_description,
-                         latent,
-                     ],
-                     outputs=result)
+                result = gr.Image(label="Result")
+
+    preprocess_button.click(fn=model.detect_and_align_face, inputs=input_image, outputs=aligned_face)
+    aligned_face.change(fn=model.reconstruct_face, inputs=aligned_face, outputs=[reconstructed_face, latent])
+    editing_type.change(fn=update_step2_components, inputs=editing_type, outputs=[hairstyle_index, color_description])
+    run_button.click(
+        fn=model.generate,
+        inputs=[
+            editing_type,
+            hairstyle_index,
+            color_description,
+            latent,
+        ],
+        outputs=result,
+    )
 
 demo.queue(max_size=10).launch()
model.py CHANGED
@@ -15,22 +15,22 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 
-if os.getenv('SYSTEM') == 'spaces' and not torch.cuda.is_available():
-    with open('patch.e4e') as f:
-        subprocess.run('patch -p1'.split(), cwd='encoder4editing', stdin=f)
-    with open('patch.hairclip') as f:
-        subprocess.run('patch -p1'.split(), cwd='HairCLIP', stdin=f)
+if os.getenv("SYSTEM") == "spaces" and not torch.cuda.is_available():
+    with open("patch.e4e") as f:
+        subprocess.run("patch -p1".split(), cwd="encoder4editing", stdin=f)
+    with open("patch.hairclip") as f:
+        subprocess.run("patch -p1".split(), cwd="HairCLIP", stdin=f)
 
 app_dir = pathlib.Path(__file__).parent
 
-e4e_dir = app_dir / 'encoder4editing'
+e4e_dir = app_dir / "encoder4editing"
 sys.path.insert(0, e4e_dir.as_posix())
 
 from models.psp import pSp
 from utils.alignment import align_face
 
-hairclip_dir = app_dir / 'HairCLIP'
-mapper_dir = hairclip_dir / 'mapper'
+hairclip_dir = app_dir / "HairCLIP"
+mapper_dir = hairclip_dir / "mapper"
 sys.path.insert(0, hairclip_dir.as_posix())
 sys.path.insert(0, mapper_dir.as_posix())
 
@@ -40,8 +40,7 @@ from mapper.hairclip_mapper import HairCLIPMapper
 
 class Model:
     def __init__(self):
-        self.device = torch.device(
-            'cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.landmark_model = self._create_dlib_landmark_model()
         self.e4e = self._load_e4e()
         self.hairclip = self._load_hairclip()
@@ -50,17 +49,16 @@ class Model:
     @staticmethod
     def _create_dlib_landmark_model():
         path = huggingface_hub.hf_hub_download(
-            'public-data/dlib_face_landmark_model',
-            'shape_predictor_68_face_landmarks.dat')
+            "public-data/dlib_face_landmark_model", "shape_predictor_68_face_landmarks.dat"
+        )
         return dlib.shape_predictor(path)
 
     def _load_e4e(self) -> nn.Module:
-        ckpt_path = huggingface_hub.hf_hub_download('public-data/e4e',
-                                                    'e4e_ffhq_encode.pt')
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-        opts = ckpt['opts']
-        opts['device'] = self.device.type
-        opts['checkpoint_path'] = ckpt_path
+        ckpt_path = huggingface_hub.hf_hub_download("public-data/e4e", "e4e_ffhq_encode.pt")
+        ckpt = torch.load(ckpt_path, map_location="cpu")
+        opts = ckpt["opts"]
+        opts["device"] = self.device.type
+        opts["checkpoint_path"] = ckpt_path
         opts = argparse.Namespace(**opts)
         model = pSp(opts)
         model.to(self.device)
@@ -68,16 +66,15 @@ class Model:
         return model
 
     def _load_hairclip(self) -> nn.Module:
-        ckpt_path = huggingface_hub.hf_hub_download('public-data/HairCLIP',
-                                                    'hairclip.pt')
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-        opts = ckpt['opts']
-        opts['device'] = self.device.type
-        opts['checkpoint_path'] = ckpt_path
-        opts['editing_type'] = 'both'
-        opts['input_type'] = 'text'
-        opts['hairstyle_description'] = 'HairCLIP/mapper/hairstyle_list.txt'
-        opts['color_description'] = 'red'
+        ckpt_path = huggingface_hub.hf_hub_download("public-data/HairCLIP", "hairclip.pt")
+        ckpt = torch.load(ckpt_path, map_location="cpu")
+        opts = ckpt["opts"]
+        opts["device"] = self.device.type
+        opts["checkpoint_path"] = ckpt_path
+        opts["editing_type"] = "both"
+        opts["input_type"] = "text"
+        opts["hairstyle_description"] = "HairCLIP/mapper/hairstyle_list.txt"
+        opts["color_description"] = "red"
         opts = argparse.Namespace(**opts)
         model = HairCLIPMapper(opts)
         model.to(self.device)
@@ -86,12 +83,14 @@
 
     @staticmethod
     def _create_transform() -> Callable:
-        transform = T.Compose([
-            T.Resize(256),
-            T.CenterCrop(256),
-            T.ToTensor(),
-            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
-        ])
+        transform = T.Compose(
+            [
+                T.Resize(256),
+                T.CenterCrop(256),
+                T.ToTensor(),
+                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+            ]
+        )
         return transform
 
     def detect_and_align_face(self, image: str) -> PIL.Image.Image:
@@ -107,35 +106,31 @@
         return tensor.cpu().numpy().transpose(1, 2, 0)
 
     @torch.inference_mode()
-    def reconstruct_face(
-            self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
+    def reconstruct_face(self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
         input_data = self.transform(image).unsqueeze(0).to(self.device)
-        reconstructed_images, latents = self.e4e(input_data,
-                                                 randomize_noise=False,
-                                                 return_latents=True)
+        reconstructed_images, latents = self.e4e(input_data, randomize_noise=False, return_latents=True)
        reconstructed = torch.clamp(reconstructed_images[0].detach(), -1, 1)
         reconstructed = self.postprocess(reconstructed)
         return reconstructed, latents[0]
 
     @torch.inference_mode()
-    def generate(self, editing_type: str, hairstyle_index: int,
-                 color_description: str, latent: torch.Tensor) -> np.ndarray:
+    def generate(
+        self, editing_type: str, hairstyle_index: int, color_description: str, latent: torch.Tensor
+    ) -> np.ndarray:
        opts = self.hairclip.opts
         opts.editing_type = editing_type
         opts.color_description = color_description
 
-        if editing_type == 'color':
+        if editing_type == "color":
             hairstyle_index = 0
 
         device = torch.device(opts.device)
 
-        dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(),
-                                          opts=opts)
+        dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(), opts=opts)
         w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]
 
         w = w.unsqueeze(0).to(device)
-        hairstyle_text_inputs = hairstyle_text_inputs_list[
-            hairstyle_index].unsqueeze(0).to(device)
+        hairstyle_text_inputs = hairstyle_text_inputs_list[hairstyle_index].unsqueeze(0).to(device)
         color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)
 
         hairstyle_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
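
The model.py changes above are formatting-only; the public interface of Model is unchanged. For orientation, a minimal sketch of driving that interface directly, mirroring the Step 1 / Step 2 flow in app.py (method names and signatures are taken from the diff; the input image path is an illustrative placeholder):

# Minimal usage sketch of the Model class shown above, outside Gradio.
# "images/example.jpg" is a hypothetical path, not a file from this repo.
from model import Model

model = Model()

# Step 1: detect and align the face, then invert it into the e4e latent space.
aligned = model.detect_and_align_face("images/example.jpg")
reconstructed, latent = model.reconstruct_face(aligned)

# Step 2: edit hairstyle and/or color from the recovered latent.
result = model.generate(
    editing_type="both",      # "hairstyle", "color", or "both"
    hairstyle_index=0,        # index into HairCLIP/mapper/hairstyle_list.txt
    color_description="red",
    latent=latent,
)
# `result` is a numpy array, per the -> np.ndarray annotation above.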
style.css CHANGED
@@ -1,5 +1,6 @@
 h1 {
   text-align: center;
+  display: block;
 }
 
 img#teaser {