readme and new requirements

- .gitignore +0 -1
- 2_gpu.json +0 -11
- 3_gpu.json +0 -11
- README.md +154 -2
- data_scripts/cal_bbox_by_seg.py +59 -0
- datasets.py +3 -1
- elite.yaml +0 -147
- inference_global.py +10 -8
- inference_global.sh +4 -3
- inference_local.py +12 -9
- inference_local.sh +3 -2
- requirements.txt +11 -0
- train_global.py +1 -3
- train_global.sh +3 -3
- train_local.py +2 -5
- train_local.sh +2 -2
.gitignore CHANGED
@@ -5,7 +5,6 @@ _sc.py
 *.ckpt
 *.bin
 
-checkpoints
 .idea
 .idea/workspace.xml
 .DS_Store
2_gpu.json DELETED
@@ -1,11 +0,0 @@
-{
-  "compute_environment": "LOCAL_MACHINE",
-  "distributed_type": "MULTI_GPU",
-  "fp16": false,
-  "machine_rank": 0,
-  "main_process_ip": null,
-  "main_process_port": null,
-  "main_training_function": "main",
-  "num_machines": 1,
-  "num_processes": 2
-}
3_gpu.json DELETED
@@ -1,11 +0,0 @@
-{
-  "compute_environment": "LOCAL_MACHINE",
-  "distributed_type": "MULTI_GPU",
-  "fp16": false,
-  "machine_rank": 0,
-  "main_process_ip": null,
-  "main_process_port": null,
-  "main_training_function": "main",
-  "num_machines": 1,
-  "num_processes": 3
-}
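Both multi-GPU accelerate configs are deleted here, while the updated training scripts below reference a `4_gpu.json` that does not appear in this diff. A minimal sketch for regenerating such a config, assuming it follows the same schema as the deleted 2- and 3-GPU files (the actual contents of the repository's `4_gpu.json` are an assumption):

```python
import json

# Assumed 4-GPU accelerate config, mirroring the deleted 2_gpu.json / 3_gpu.json.
config = {
    "compute_environment": "LOCAL_MACHINE",
    "distributed_type": "MULTI_GPU",
    "fp16": False,
    "machine_rank": 0,
    "main_process_ip": None,
    "main_process_port": None,
    "main_training_function": "main",
    "num_machines": 1,
    "num_processes": 4,  # one process per GPU
}

with open("4_gpu.json", "w") as f:
    json.dump(config, f, indent=2)
```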
README.md CHANGED
@@ -1,3 +1,155 @@
-# ELITE
+# ELITE: Encoding Visual Concepts into Textual Embeddings for Customized Text-to-Image Generation
 
+<a href="https://arxiv.org/pdf/2302.13848.pdf"><img src="https://img.shields.io/badge/arXiv-2302.13848-b31b1b.svg" height=22.5></a>
+<a href="https://huggingface.co/spaces/ELITE-library/ELITE"><img src="https://img.shields.io/static/v1?label=HuggingFace&message=gradio demo&color=darkgreen" height=22.5></a>
+
+## Getting Started
+
+----
+
+### Environment Setup
+
+```shell
+git clone https://github.com/csyxwei/ELITE.git
+cd ELITE
+conda create -n elite python=3.9
+conda activate elite
+pip install -r requirements.txt
+```
+
+### Pretrained Models
+
+We provide the pretrained checkpoints in [Google Drive](https://drive.google.com/drive/folders/1VkiVZzA_i9gbfuzvHaLH2VYh7kOTzE0x?usp=sharing). Download them and save them to the directory `checkpoints`.
+
+### Setting up Diffusers
+
+Our code is built on [diffusers](https://github.com/huggingface/diffusers/); you can follow the guideline [here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#cat-toy-example) to set it up.
+
+### Customized Generation
+
+We provide the testing dataset in [test_datasets](./test_datasets), which contains both images and object masks. For testing, you can run:
+```shell
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR='./test_datasets/'
+CUDA_VISIBLE_DEVICES=0 python inference_local.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --test_data_dir=$DATA_DIR \
+  --output_dir="./outputs/local_mapping" \
+  --suffix="object" \
+  --template="a photo of a S" \
+  --llambda="0.8" \
+  --global_mapper_path="./checkpoints/global_mapper.pt" \
+  --local_mapper_path="./checkpoints/local_mapper.pt"
+```
+or you can use the shell script:
+```shell
+bash inference_local.sh
+```
+If you want to test your own customized dataset, align each image so that the object is at the center, and provide the corresponding object mask. The object mask can be obtained with [image-matting-app](https://huggingface.co/spaces/SankarSrin/image-matting-app) or other image matting methods.
+
+## Training
+
+----
+
+### Preparing Dataset
+
+We use the **test** split of Open-Images V6 to train ELITE. You can prepare the dataset as follows:
+
+- Download the Open-Images test set from [CVDF's site](https://github.com/cvdfoundation/open-images-dataset#download-images-with-bounding-boxes-annotations) and unzip it to the directory `datasets/Open_Images/images/test`.
+- Download the attribute names file `oidv6-attributes-description.csv` of the Open-Images test set from the [Open-Images official site](https://storage.googleapis.com/openimages/web/download_v7.html#download-manually) and save it to the directory `datasets/Open_Images/annotations/`.
+- Download the bbox annotations file `test-annotations-bbox.csv` of the Open-Images test set from the [Open-Images official site](https://storage.googleapis.com/openimages/web/download_v7.html#download-manually) and save it to the directory `datasets/Open_Images/annotations/`.
+- Download the segmentation annotations of the Open-Images test set from the [Open-Images official site](https://storage.googleapis.com/openimages/web/download_v7.html#download-manually) and unzip them to the directory `datasets/Open_Images/segs/test`. Then put `test-annotations-object-segmentation.csv` into `datasets/Open_Images/annotations/`.
+- Obtain the mask bboxes by running:
+```shell
+python data_scripts/cal_bbox_by_seg.py
+```
+
+The final data structure should look like this:
+
+```
+datasets
+├── Open_Images
+│   ├── annotations
+│   │   ├── oidv6-class-descriptions.csv
+│   │   ├── test-annotations-object-segmentation.csv
+│   │   ├── test-annotations-bbox.csv
+│   ├── images
+│   │   ├── test
+│   │   │   ├── xxx.jpg
+│   │   │   ├── ...
+│   ├── segs
+│   │   ├── test
+│   │   │   ├── xxx.png
+│   │   │   ├── ...
+│   │   ├── test_bbox_dict.npy
+```
+
+### Training Global Mapping Network
+
+To train the global mapping network, run the following command:
+
+```shell
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR='./datasets/Open_Images/'
+CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25656 train_global.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATA_DIR \
+  --placeholder_token="S" \
+  --resolution=512 \
+  --train_batch_size=4 \
+  --gradient_accumulation_steps=4 \
+  --max_train_steps=200000 \
+  --learning_rate=1e-06 --scale_lr \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --output_dir="./elite_experiments/global_mapping" \
+  --save_steps 200
+```
+or you can use the shell script:
+```shell
+bash train_global.sh
+```
+
+### Training Local Mapping Network
+
+After the global mapping network is trained, you can train the local mapping network by running the following command:
+
+```shell
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR='./datasets/Open_Images/'
+CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25657 train_local.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATA_DIR \
+  --placeholder_token="S" \
+  --resolution=512 \
+  --train_batch_size=2 \
+  --gradient_accumulation_steps=4 \
+  --max_train_steps=200000 \
+  --learning_rate=1e-5 --scale_lr \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --global_mapper_path "./elite_experiments/global_mapping/mapper_070000.pt" \
+  --output_dir="./elite_experiments/local_mapping" \
+  --save_steps 200
+```
+or you can use the shell script:
+```shell
+bash train_local.sh
+```
+
+## Citation
+
+```
+@article{wei2023elite,
+  title={ELITE: Encoding Visual Concepts into Textual Embeddings for Customized Text-to-Image Generation},
+  author={Wei, Yuxiang and Zhang, Yabo and Ji, Zhilong and Bai, Jinfeng and Zhang, Lei and Zuo, Wangmeng},
+  journal={arXiv preprint arXiv:2302.13848},
+  year={2023}
+}
+```
+
+## Acknowledgements
+
+This code is built on [diffusers](https://github.com/huggingface/diffusers/). We thank the authors for sharing the code.
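The customized-dataset note in the new README leaves the expected mask files implicit; `CustomDatasetWithBG` in `datasets.py` (changed below) looks for a `<image name>_bg.png` next to each image. A rough sketch of that preparation step, assuming the matting tool outputs a grayscale alpha matte (the helper name and the exact object/background semantics of `_bg.png` are assumptions to verify against the files in `test_datasets`):

```python
import os

import numpy as np
from PIL import Image


def prepare_custom_sample(image_path, matte_path, out_dir, size=512):
    """Hypothetical helper: crop/resize an image and write a binary *_bg.png mask."""
    image = Image.open(image_path).convert("RGB")
    matte = Image.open(matte_path).convert("L")  # alpha matte from a matting tool

    # Crop to a centered square (assumes the object is roughly centered in the photo),
    # then resize both image and matte consistently.
    w, h = image.size
    s = min(w, h)
    box = ((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)
    image = image.crop(box).resize((size, size), Image.BICUBIC)
    matte = matte.crop(box).resize((size, size), Image.NEAREST)

    # Binarize the matte; datasets.py later maps any value > 0 to 1 anyway.
    mask = (np.array(matte) > 127).astype(np.uint8) * 255

    stem = os.path.splitext(os.path.basename(image_path))[0]
    image.save(os.path.join(out_dir, f"{stem}.jpg"))
    Image.fromarray(mask).save(os.path.join(out_dir, f"{stem}_bg.png"))
```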
data_scripts/cal_bbox_by_seg.py ADDED
@@ -0,0 +1,59 @@
+import cv2
+from os.path import join
+import os
+import numpy as np
+from tqdm import tqdm
+
+dir = './datasets/Open_Images/'
+mode = 'test'
+
+image_dir = join(dir, 'images', mode)
+seg_dir = join(dir, 'segs', mode)
+
+files = os.listdir(seg_dir)
+
+data_dict = {}
+
+for file in tqdm(files):
+    seg_path = join(seg_dir, file)
+    image_path = join(image_dir, file.split('_')[0] + '.jpg')
+    seg = cv2.imread(seg_path)
+    image = cv2.imread(image_path)
+    seg = cv2.resize(seg, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
+
+    seg = seg[:, :, 0]
+
+    # obtain contours point set: contours
+    contours = cv2.findContours(seg, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    contours = contours[0] if len(contours) == 2 else contours[1]
+
+    if len(contours) > 1:
+        cntr = np.vstack(contours)
+    elif len(contours) == 1:
+        cntr = contours[0]
+    else:
+        continue
+
+    if len(cntr) < 2:
+        continue
+
+    hs, he = np.min(cntr[:, :, 1]), np.max(cntr[:, :, 1])
+    ws, we = np.min(cntr[:, :, 0]), np.max(cntr[:, :, 0])
+
+    h, w = seg.shape
+
+    if (he - hs) % 2 == 1 and (he + 1) <= h:
+        he = he + 1
+    if (he - hs) % 2 == 1 and (hs - 1) >= 0:
+        hs = hs - 1
+    if (we - ws) % 2 == 1 and (we + 1) <= w:
+        we = we + 1
+    if (we - ws) % 2 == 1 and (ws - 1) >= 0:
+        ws = ws - 1
+
+    if he - hs < 2 or we - ws < 2:
+        continue
+
+    data_dict[file] = [cntr, hs, he, ws, we]
+
+np.save(join(dir, 'segs', f'{mode}_bbox_dict.npy'), data_dict)
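The dictionary written by this script maps each segmentation file name to its contour and bbox bounds. Since `np.save` pickles the dict, it has to be read back with `allow_pickle=True`; a small sketch (paths follow the script above):

```python
import numpy as np

# test_bbox_dict.npy stores a pickled Python dict, so allow_pickle is required.
bbox_dict = np.load('./datasets/Open_Images/segs/test_bbox_dict.npy', allow_pickle=True).item()

for seg_name, (cntr, hs, he, ws, we) in list(bbox_dict.items())[:3]:
    print(seg_name, 'bbox (hs, he, ws, we):', hs, he, ws, we, 'contour points:', len(cntr))
```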
datasets.py CHANGED
@@ -141,7 +141,9 @@ class CustomDatasetWithBG(Dataset):
         image = Image.open(self.image_paths[i % self.num_images])
 
         mask_path = self.image_paths[i % self.num_images].replace('.jpeg', '.png').replace('.jpg', '.png').replace('.JPEG', '.png')[:-4] + '_bg.png'
-        mask = np.array(Image.open(mask_path))
+        mask = np.array(Image.open(mask_path))
+
+        mask = np.where(mask > 0, 1, 0)
 
         if not image.mode == "RGB":
             image = image.convert("RGB")
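The added line normalizes the loaded mask to {0, 1}. For illustration, what `np.where(mask > 0, 1, 0)` does to an 8-bit mask (the pixel values below are made up):

```python
import numpy as np

mask = np.array([[0, 12, 255],
                 [0,  0, 200]], dtype=np.uint8)  # raw pixels from a *_bg.png
binary = np.where(mask > 0, 1, 0)                # every non-zero pixel becomes 1
print(binary)
# [[0 1 1]
#  [0 0 1]]
```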
elite.yaml DELETED
@@ -1,147 +0,0 @@
-name: elite
-channels:
-  - defaults
-dependencies:
-  - _libgcc_mutex=0.1=main
-  - ca-certificates=2022.10.11=h06a4308_0
-  - certifi=2022.9.24=py39h06a4308_0
-  - ld_impl_linux-64=2.38=h1181459_1
-  - libffi=3.3=he6710b0_2
-  - libgcc-ng=9.1.0=hdf63c60_0
-  - libstdcxx-ng=9.1.0=hdf63c60_0
-  - ncurses=6.3=h7f8727e_2
-  - openssl=1.1.1s=h7f8727e_0
-  - pip=22.2.2=py39h06a4308_0
-  - python=3.9.12=h12debd9_1
-  - readline=8.1.2=h7f8727e_1
-  - sqlite=3.38.5=hc218d9a_0
-  - tk=8.6.12=h1ccaba5_0
-  - wheel=0.37.1=pyhd3eb1b0_0
-  - xz=5.2.5=h7f8727e_1
-  - zlib=1.2.12=h7f8727e_2
-  - pip:
-    - absl-py==1.3.0
-    - accelerate==0.15.0
-    - aiohttp==3.8.3
-    - aiosignal==1.3.1
-    - albumentations==1.1.0
-    - altair==4.2.0
-    - antlr4-python3-runtime==4.8
-    - async-timeout==4.0.2
-    - attrs==22.1.0
-    - blinker==1.5
-    - cachetools==5.2.0
-    - charset-normalizer==2.1.1
-    - click==8.1.3
-    - commonmark==0.9.1
-    - contourpy==1.0.6
-    - cycler==0.11.0
-    - cython==0.29.33
-    - decorator==5.1.1
-    - diffusers==0.11.1
-    - einops==0.4.1
-    - emoji==2.2.0
-    - entrypoints==0.4
-    - faiss-gpu==1.7.2
-    - filelock==3.8.0
-    - fonttools==4.38.0
-    - frozenlist==1.3.3
-    - fsspec==2022.11.0
-    - ftfy==6.1.1
-    - future==0.18.2
-    - gitdb==4.0.9
-    - gitpython==3.1.29
-    - google-auth==2.14.1
-    - google-auth-oauthlib==0.4.6
-    - grpcio==1.50.0
-    - huggingface-hub==0.11.0
-    - idna==3.4
-    - imageio==2.14.1
-    - imageio-ffmpeg==0.4.7
-    - importlib-metadata==5.0.0
-    - jinja2==3.1.2
-    - joblib==1.2.0
-    - jsonschema==4.17.0
-    - kiwisolver==1.4.4
-    - kornia==0.6.0
-    - markdown==3.4.1
-    - markupsafe==2.1.1
-    - matplotlib==3.6.2
-    - multidict==6.0.2
-    - networkx==2.8.8
-    - nltk==3.7
-    - numpy==1.23.4
-    - oauthlib==3.2.2
-    - omegaconf==2.1.1
-    - opencv-python==4.6.0.66
-    - opencv-python-headless==4.6.0.66
-    - packaging==21.3
-    - pandas==1.5.1
-    - pillow==9.0.1
-    - protobuf==3.20.1
-    - psutil==5.9.4
-    - pudb==2019.2
-    - pyarrow==10.0.0
-    - pyasn1==0.4.8
-    - pyasn1-modules==0.2.8
-    - pycocotools==2.0.6
-    - pydeck==0.8.0
-    - pydensecrf==1.0rc2
-    - pydeprecate==0.3.2
-    - pygments==2.13.0
-    - pympler==1.0.1
-    - pyparsing==3.0.9
-    - pyrsistent==0.19.2
-    - python-dateutil==2.8.2
-    - python-dotenv==0.21.0
-    - pytorch-lightning==1.6.5
-    - pytz==2022.6
-    - pytz-deprecation-shim==0.1.0.post0
-    - pywavelets==1.4.1
-    - pyyaml==6.0
-    - qudida==0.0.4
-    - regex==2022.10.31
-    - requests==2.28.1
-    - requests-oauthlib==1.3.1
-    - rich==12.6.0
-    - rsa==4.9
-    - sacremoses==0.0.53
-    - scikit-image==0.19.3
-    - scikit-learn==1.1.3
-    - scipy==1.9.3
-    - semver==2.13.0
-    - setuptools==59.5.0
-    - six==1.16.0
-    - smmap==5.0.0
-    - stanza==1.4.2
-    - streamlit==1.15.0
-    - tensorboard==2.11.0
-    - tensorboard-data-server==0.6.1
-    - tensorboard-plugin-wit==1.8.1
-    - test-tube==0.7.5
-    - threadpoolctl==3.1.0
-    - tifffile==2022.10.10
-    - timm==0.6.12
-    - tokenizers==0.12.1
-    - toml==0.10.2
-    - toolz==0.12.0
-    - torch==1.12.1+cu116
-    - torch-fidelity==0.3.0
-    - torchaudio==0.12.1+cu116
-    - torchmetrics==0.6.0
-    - torchvision==0.13.1+cu116
-    - tornado==6.2
-    - tqdm==4.64.1
-    - transformers==4.25.1
-    - typing-extensions==4.4.0
-    - tzdata==2022.6
-    - tzlocal==4.2
-    - urllib3==1.26.12
-    - urwid==2.1.2
-    - validators==0.20.0
-    - watchdog==2.1.9
-    - wcwidth==0.2.5
-    - werkzeug==2.2.2
-    - yarl==1.8.1
-    - zipp==3.10.0
-
inference_global.py CHANGED
@@ -170,6 +170,12 @@ def parse_args():
         help="Data index. -1 for all.",
     )
 
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for testing.",
+    )
     args = parser.parse_args()
     return args
 
@@ -204,11 +210,7 @@ if __name__ == "__main__":
         batch["input_ids"] = batch["input_ids"].to("cuda:0")
         batch["index"] = batch["index"].to("cuda:0").long()
         print(step, batch['text'])
-
-
-
-
-        syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5,
-                                token_index=args.token_index, seed=seed)
-        concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
-        Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
+        syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5,
+                                token_index=args.token_index, seed=args.seed)
+        concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
+        Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(args.seed).zfill(5)}.jpg'))
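The `validation` helper that consumes the new `--seed` argument is not shown in this diff. In diffusers-based inference code the seed is typically turned into a `torch.Generator` for the initial latents, roughly like this (a sketch under that assumption, not the repository's actual implementation):

```python
import torch

def make_generator(seed, device="cuda:0"):
    # seed=None keeps sampling non-deterministic, matching the new default.
    if seed is None:
        return None
    return torch.Generator(device=device).manual_seed(seed)

# Reproducible initial latents for a 512x512 Stable Diffusion sample.
generator = make_generator(42)
latents = torch.randn((1, 4, 64, 64), generator=generator, device="cuda:0")
```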
inference_global.sh CHANGED
@@ -1,12 +1,13 @@
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 export DATA_DIR='./test_datasets/'
 
-CUDA_VISIBLE_DEVICES=
+CUDA_VISIBLE_DEVICES=7 python inference_global.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --test_data_dir=$DATA_DIR \
   --output_dir="./outputs/global_mapping" \
   --suffix="object" \
   --token_index="0" \
-  --template="a photo of a
-  --global_mapper_path="./checkpoints/global_mapper.pt"
+  --template="a photo of a S" \
+  --global_mapper_path="./checkpoints/global_mapper.pt" \
+  --seed 42
 
inference_local.py CHANGED
@@ -199,6 +199,13 @@ def parse_args():
         help="Lambda for fuse the global and local feature.",
     )
 
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for testing.",
+    )
+
     args = parser.parse_args()
     return args
 
@@ -236,12 +243,8 @@
         batch["input_ids"] = batch["input_ids"].to("cuda:0")
         batch["index"] = batch["index"].to("cuda:0").long()
         print(step, batch['text'])
-
-
-
-
-
-                                         batch["pixel_values_clip"].device, 5,
-                                         seed=seed, llambda=float(args.llambda))
-        concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
-        Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
+        syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae,
+                                batch["pixel_values_clip"].device, 5,
+                                seed=args.seed, llambda=float(args.llambda))
+        concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
+        Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(args.seed).zfill(5)}.jpg'))
inference_local.sh CHANGED
@@ -5,8 +5,9 @@ CUDA_VISIBLE_DEVICES=7 python inference_local.py \
   --test_data_dir=$DATA_DIR \
   --output_dir="./outputs/local_mapping" \
   --suffix="object" \
-  --template="a photo of a
+  --template="a photo of a S" \
   --llambda="0.8" \
   --global_mapper_path="./checkpoints/global_mapper.pt" \
-  --local_mapper_path="./checkpoints/local_mapper.pt"
+  --local_mapper_path="./checkpoints/local_mapper.pt" \
+  --seed 42
 
requirements.txt ADDED
@@ -0,0 +1,11 @@
+accelerate==0.16.0
+albumentations==1.3.0
+diffusers==0.11.1
+gradio==3.20.1
+huggingface-hub==0.13.0
+opencv-python-headless==4.7.0.68
+Pillow==9.4.0
+torch==1.13.1
+torchvision==0.14.1
+tqdm==4.65.0
+transformers==4.26.1
train_global.py CHANGED
@@ -11,7 +11,6 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
 
-import PIL
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -31,7 +30,7 @@ from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 
-from typing import
+from typing import Optional, Tuple, Union
 from datasets import OpenImagesDataset
 
 
@@ -362,7 +361,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     else:
         return f"{organization}/{model_id}"
 
-
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
train_global.sh CHANGED
@@ -1,6 +1,6 @@
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export DATA_DIR='
-CUDA_VISIBLE_DEVICES=
+export DATA_DIR='./datasets/Open_Images/'
+CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25656 train_global.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --train_data_dir=$DATA_DIR \
   --placeholder_token="S" \
@@ -11,5 +11,5 @@ CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file 4_gpu.json --main_p
   --learning_rate=1e-06 --scale_lr \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
-  --output_dir="./elite_experiments/
+  --output_dir="./elite_experiments/global_mapping_new" \
   --save_steps 200
train_local.py CHANGED
@@ -1,4 +1,3 @@
-
 import argparse
 import itertools
 import math
@@ -16,15 +15,13 @@ import PIL
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler,
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, LMSDiscreteScheduler
 from diffusers.optimization import get_scheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfFolder, Repository, whoami
 
-# TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 
 
 from typing import Optional
train_local.sh CHANGED
@@ -1,6 +1,6 @@
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export DATA_DIR='
-CUDA_VISIBLE_DEVICES=
+export DATA_DIR='./datasets/Open_Images/'
+CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file 4_gpu.json --main_process_port 25657 train_local.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --train_data_dir=$DATA_DIR \
   --placeholder_token="S" \