xiaozaa commited on
Commit
4ae4b3e
·
1 Parent(s): fbd9231

release a lora version

Browse files
Files changed (3) hide show
  1. README.md +22 -6
  2. test_lora.png +0 -0
  3. tryon_inference_lora.py +134 -0
README.md CHANGED
@@ -4,11 +4,13 @@ An state-of-the-art virtual try-on solution that combines the power of [CATVTON]
4
  Also inspired by [In-Context LoRA](https://arxiv.org/abs/2410.23775) for prompt engineering.
5
 
6
  ## Update
7
- [![SOTA](https://img.shields.io/badge/SOTA-FID%205.59-brightgreen)](https://drive.google.com/file/d/1T2W5R1xH_uszGVD8p6UUAtWyx43rxGmI/view?usp=sharing)
8
- [![Dataset](https://img.shields.io/badge/Dataset-VITON--HD-blue)](https://github.com/shadow2496/VITON-HD)
9
 
10
  ---
11
- **Latest Achievement** (2024/11/24):
 
 
 
 
12
  - Released FID score and gradio demo
13
  - CatVton-Flux-Alpha achieved **SOTA** performance with FID: `5.593255043029785` on VITON-HD dataset. Test configuration: scale 30, step 30. My VITON-HD test inferencing results available [here](https://drive.google.com/file/d/1T2W5R1xH_uszGVD8p6UUAtWyx43rxGmI/view?usp=sharing)
14
 
@@ -22,8 +24,8 @@ Also inspired by [In-Context LoRA](https://arxiv.org/abs/2410.23775) for prompt
22
  | ![Original](example/person/00008_00.jpg) | ![Garment](example/garment/00034_00.jpg) | ![Result](example/result/3.png) |
23
 
24
  ## Model Weights
25
- Hugging Face: 🤗 [catvton-flux-alpha](https://huggingface.co/xiaozaa/catvton-flux-alpha)
26
-
27
  The model weights are trained on the [VITON-HD](https://github.com/shadow2496/VITON-HD) dataset.
28
 
29
  ## Prerequisites
@@ -40,6 +42,19 @@ huggingface-cli login
40
  ## Usage
41
 
42
  Run the following command to try on an image:
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ```bash
44
  python tryon_inference.py \
45
  --image ./example/person/00008_00.jpg \
@@ -64,7 +79,8 @@ Gradio demo:
64
  - [x] Release the FID score
65
  - [x] Add gradio demo
66
  - [ ] Release updated weights with better performance
67
- - [ ] Train a smaller model
 
68
 
69
  ## Citation
70
 
 
4
  Also inspired by [In-Context LoRA](https://arxiv.org/abs/2410.23775) for prompt engineering.
5
 
6
  ## Update
 
 
7
 
8
  ---
9
+ **Latest Achievement**
10
+ (2024/11/25):
11
+ - Released lora weights.
12
+
13
+ (2024/11/24):
14
  - Released FID score and gradio demo
15
  - CatVton-Flux-Alpha achieved **SOTA** performance with FID: `5.593255043029785` on VITON-HD dataset. Test configuration: scale 30, step 30. My VITON-HD test inferencing results available [here](https://drive.google.com/file/d/1T2W5R1xH_uszGVD8p6UUAtWyx43rxGmI/view?usp=sharing)
16
 
 
24
  | ![Original](example/person/00008_00.jpg) | ![Garment](example/garment/00034_00.jpg) | ![Result](example/result/3.png) |
25
 
26
  ## Model Weights
27
+ LORA weights in Hugging Face: 🤗 [catvton-flux-alpha](https://huggingface.co/xiaozaa/catvton-flux-alpha)
28
+ Fine-tuning weights in Hugging Face: 🤗 [catvton-flux-lora-alpha](https://huggingface.co/xiaozaa/catvton-flux-lora-alpha)
29
  The model weights are trained on the [VITON-HD](https://github.com/shadow2496/VITON-HD) dataset.
30
 
31
  ## Prerequisites
 
42
  ## Usage
43
 
44
  Run the following command to try on an image:
45
+
46
+ LORA version:
47
+ ```bash
48
+ python tryon_inference_lora.py \
49
+ --image ./example/person/00008_00.jpg \
50
+ --mask ./example/person/00008_00_mask.png \
51
+ --garment ./example/garment/00034_00.jpg \
52
+ --seed 4096 \
53
+ --output_tryon test_lora.png \
54
+ --steps 30
55
+ ```
56
+
57
+ Fine-tuning version:
58
  ```bash
59
  python tryon_inference.py \
60
  --image ./example/person/00008_00.jpg \
 
79
  - [x] Release the FID score
80
  - [x] Add gradio demo
81
  - [ ] Release updated weights with better performance
82
+ - [x] Train a smaller model
83
+ - [ ] Support comfyui
84
 
85
  ## Citation
86
 
test_lora.png ADDED
tryon_inference_lora.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from diffusers.utils import load_image, check_min_version
4
+ from diffusers import FluxPriorReduxPipeline, FluxFillPipeline
5
+ from diffusers import FluxTransformer2DModel
6
+ import numpy as np
7
+ from torchvision import transforms
8
+
9
+ def run_inference(
10
+ image_path,
11
+ mask_path,
12
+ garment_path,
13
+ size=(576, 768),
14
+ num_steps=50,
15
+ guidance_scale=30,
16
+ seed=42,
17
+ pipe=None
18
+ ):
19
+ # Build pipeline
20
+ if pipe is None:
21
+ transformer = FluxTransformer2DModel.from_pretrained(
22
+ "xiaozaa/flux1-fill-dev-diffusers", ## The official Flux-Fill weights
23
+ torch_dtype=torch.bfloat16
24
+ )
25
+ print("Start loading LoRA weights")
26
+ state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
27
+ pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha", ## The tryon Lora weights
28
+ weight_name="pytorch_lora_weights.safetensors",
29
+ return_alphas=True
30
+ )
31
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
32
+ if not is_correct_format:
33
+ raise ValueError("Invalid LoRA checkpoint.")
34
+
35
+ FluxFillPipeline.load_lora_into_transformer(
36
+ state_dict=state_dict,
37
+ network_alphas=network_alphas,
38
+ transformer=transformer,
39
+ )
40
+
41
+ pipe = FluxFillPipeline.from_pretrained(
42
+ "black-forest-labs/FLUX.1-dev",
43
+ transformer=transformer,
44
+ torch_dtype=torch.bfloat16
45
+ ).to("cuda")
46
+ else:
47
+ pipe.to("cuda")
48
+
49
+ pipe.transformer.to(torch.bfloat16)
50
+
51
+ # Add transform
52
+ transform = transforms.Compose([
53
+ transforms.ToTensor(),
54
+ transforms.Normalize([0.5], [0.5]) # For RGB images
55
+ ])
56
+ mask_transform = transforms.Compose([
57
+ transforms.ToTensor()
58
+ ])
59
+
60
+ # Load and process images
61
+ # print("image_path", image_path)
62
+ image = load_image(image_path).convert("RGB").resize(size)
63
+ mask = load_image(mask_path).convert("RGB").resize(size)
64
+ garment = load_image(garment_path).convert("RGB").resize(size)
65
+
66
+ # Transform images using the new preprocessing
67
+ image_tensor = transform(image)
68
+ mask_tensor = mask_transform(mask)[:1] # Take only first channel
69
+ garment_tensor = transform(garment)
70
+
71
+ # Create concatenated images
72
+ inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
73
+ garment_mask = torch.zeros_like(mask_tensor)
74
+ extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
75
+
76
+ prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
77
+ f"[IMAGE1] Detailed product shot of a clothing" \
78
+ f"[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
79
+
80
+ generator = torch.Generator(device="cuda").manual_seed(seed)
81
+
82
+ result = pipe(
83
+ height=size[1],
84
+ width=size[0] * 2,
85
+ image=inpaint_image,
86
+ mask_image=extended_mask,
87
+ num_inference_steps=num_steps,
88
+ generator=generator,
89
+ max_sequence_length=512,
90
+ guidance_scale=guidance_scale,
91
+ prompt=prompt,
92
+ ).images[0]
93
+
94
+ # Split and save results
95
+ width = size[0]
96
+ garment_result = result.crop((0, 0, width, size[1]))
97
+ tryon_result = result.crop((width, 0, width * 2, size[1]))
98
+
99
+ return garment_result, tryon_result
100
+
101
+ def main():
102
+ parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
103
+ parser.add_argument('--image', required=True, help='Path to the model image')
104
+ parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
105
+ parser.add_argument('--garment', required=True, help='Path to the garment image')
106
+ parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
107
+ parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
108
+ parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
109
+ parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
110
+ parser.add_argument('--seed', type=int, default=0, help='Random seed')
111
+ parser.add_argument('--width', type=int, default=576, help='Width')
112
+ parser.add_argument('--height', type=int, default=768, help='Height')
113
+
114
+ args = parser.parse_args()
115
+
116
+ check_min_version("0.30.2")
117
+
118
+ garment_result, tryon_result = run_inference(
119
+ image_path=args.image,
120
+ mask_path=args.mask,
121
+ garment_path=args.garment,
122
+ num_steps=args.steps,
123
+ guidance_scale=args.guidance_scale,
124
+ seed=args.seed,
125
+ size=(args.width, args.height)
126
+ )
127
+ output_tryon_path=args.output_tryon
128
+
129
+ tryon_result.save(output_tryon_path)
130
+
131
+ print("Successfully saved garment and try-on images")
132
+
133
+ if __name__ == "__main__":
134
+ main()