goodmodeler committed on
Commit be305fb · 1 Parent(s): 7066d20

load dataset

README.md CHANGED
@@ -14,6 +14,8 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
 
  commands:
 
+ download images: python download.py -i 1 -r 2 -o /home/user/app/image_tmp -z
+
  pip install git+https://github.com/huggingface/diffusers
 
  accelerate launch \
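
Note that in download.py (added below) the -r flag is an exclusive upper bound, so the documented command fetches a single archive. A minimal sketch of the URLs it resolves to, following the part-numbering scheme in download.py:

# Archives fetched by `python download.py -i 1 -r 2` (upper bound exclusive)
baseurl = "https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/"
for idx in range(1, 2):  # a single file: part-000001.zip
    print(f"{baseurl}images/part-{idx:06}.zip")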
image_download.py → deprecated/image_download.py RENAMED
File without changes
image_gen.py → deprecated/image_gen.py RENAMED
File without changes
download.py ADDED
@@ -0,0 +1,209 @@
+ # Author: Marco Lustri 2022 - https://github.com/TheLustriVA
+ # MIT License
+
+ """A script to make downloading the DiffusionDB dataset easier."""
+ from urllib.error import HTTPError
+ from urllib.request import urlretrieve
+ from alive_progress import alive_bar
+ from os.path import exists
+
+ import shutil
+ import os
+ import time
+ import argparse
+
+ index = None  # initialise the main arguments as None
+ range_max = None
+ output = None
+ unzip = None
+ large = None
+
+ parser = argparse.ArgumentParser(description="Download a file from a URL")
+
+ # Register the command-line arguments.
+ parser.add_argument(
+     "-i",
+     "--index",
+     type=int,
+     default=1,
+     help="File to download, or lower bound of the range if -r is set",
+ )
+ parser.add_argument(
+     "-r",
+     "--range",
+     type=int,
+     default=None,
+     help="Upper bound of the range if -i is provided",
+ )
+ parser.add_argument(
+     "-o", "--output", type=str, default="images", help="Output directory name"
+ )
+ parser.add_argument(
+     "-z",
+     "--unzip",
+     default=False,
+     help="Unzip the file after downloading",
+     action="store_true",  # set to True when the flag is passed
+ )
+ parser.add_argument(
+     "-l",
+     "--large",
+     default=False,
+     help="Download from DiffusionDB Large (14 million images)",
+     action="store_true",
+ )
+
+ args = parser.parse_args()  # parse the arguments
+
+ # Copy any arguments the user provided into the module-level variables.
+ if args.index:
+     index = args.index
+ if args.range:
+     range_max = args.range
+ if args.output:
+     output = args.output
+ if args.unzip:
+     unzip = args.unzip
+ if args.large:
+     large = args.large
+
+
+ def download(index=1, range_index=0, output="", large=False):
+     """
+     Download a file from a URL and save it to a local file.
+
+     :param index: The index of the file to download, defaults to 1 (optional)
+     :param range_index: The exclusive upper bound of the range of files to
+         download; leave at 0 to download a single file (optional)
+     :param output: The directory to download the files to
+     :param large: If True, download from DiffusionDB Large (14 million images)
+         instead of DiffusionDB 2M (2 million images)
+     :return: A list of files to unzip
+     """
+     baseurl = "https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/"
+     files_to_unzip = []
+
+     if large:
+         if index <= 10000:
+             url = f"{baseurl}diffusiondb-large-part-1/part-{index:06}.zip"
+         else:
+             url = f"{baseurl}diffusiondb-large-part-2/part-{index:06}.zip"
+     else:
+         url = f"{baseurl}images/part-{index:06}.zip"
+
+     if output != "":
+         output = f"{output}/"
+         if not exists(output):
+             os.makedirs(output)
+
+     if range_index == 0:
+         print("Downloading file: ", url)
+         file_path = f"{output}part-{index:06}.zip"
+         try:
+             urlretrieve(url, file_path)
+         except HTTPError as e:
+             print(f"Encountered an HTTPError downloading file: {url} - {e}")
+         if unzip:
+             unzip_file(file_path)  # call unzip_file; `unzip` itself is the boolean flag
+     else:
+         # Download the files numbered from index to range_index - 1.
+         with alive_bar(range_index - index, title="Downloading files") as bar:
+             for idx in range(index, range_index):
+                 if large:
+                     if idx <= 10000:
+                         url = f"{baseurl}diffusiondb-large-part-1/part-{idx:06}.zip"
+                     else:
+                         url = f"{baseurl}diffusiondb-large-part-2/part-{idx:06}.zip"
+                 else:
+                     url = f"{baseurl}images/part-{idx:06}.zip"
+
+                 loop_file_path = f"{output}part-{idx:06}.zip"
+                 # Try to download the file; report any HTTPError and move on.
+                 try:
+                     urlretrieve(url, loop_file_path)
+                 except HTTPError as e:
+                     print(f"HTTPError downloading file: {url} - {e}")
+                 files_to_unzip.append(loop_file_path)
+                 # Record the URL of each downloaded file in a manifest.
+                 with open("manifest.txt", "a") as f:
+                     f.write(url + "\n")
+                 time.sleep(0.1)
+                 bar()
+
+     # If the user asked to unzip, return the list of downloaded archives;
+     # unzipping happens separately since downloading is already lengthy.
+     if unzip and len(files_to_unzip) > 0:
+         return files_to_unzip
+
+
+ def unzip_file(file: str, extract_to: str = None):
+     """
+     Unpack a zip file into the specified directory.
+
+     :param file: path to the zip file
+     :param extract_to: directory to extract to (default: the zip file's
+         path without the .zip extension)
+     :return: A confirmation message
+     """
+     if extract_to is None:
+         extract_to = file.replace('.zip', '')
+
+     shutil.unpack_archive(file, extract_to)
+     return f"File: {file} has been unzipped to {extract_to}"
+
+
+ def unzip_all(files: list):
+     """
+     Unzip every file in a list of zip archives.
+
+     :param files: list of zip file paths
+     """
+     with alive_bar(len(files), title="Unzipping files") as bar:
+         for file in files:
+             unzip_file(file, '/home/user/app/images')
+             time.sleep(0.1)
+             bar()
+
+
+ def main(index=None, range_max=None, output=None, unzip=None, large=None):
+     """
+     Download the requested files and, if requested, unzip them. Before a
+     large range download, ask the user to confirm they have enough disk space.
+
+     :param index: The index of the file you want to download
+     :param range_max: The exclusive upper bound of the range of files
+     :param output: The directory to download the files to
+     :param unzip: Set to True to unzip the files after downloading them
+     :param large: Set to True to download from DiffusionDB Large (14 million
+         images) instead of DiffusionDB 2M (2 million images)
+     """
+     if index and range_max:
+         if range_max - index >= 1999:
+             confirmation = input("Do you have at least 1.7 TB free? (y/n) ")
+             if confirmation != "y":
+                 return
+         files = download(index, range_max, output, large)
+         if unzip:
+             unzip_all(files)
+     elif index:
+         download(index, output=output, large=large)
+     else:
+         print("No index provided")
+
+
+ # Run main() only when the script is executed directly; importing the
+ # module does not start a download.
+ if __name__ == "__main__":
+     main(index, range_max, output, unzip, large)
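
For reference, a self-contained sketch equivalent to `python download.py -i 1 -o image_tmp -z` (the directory name is illustrative; note that download.py parses sys.argv at import time, so calling its functions from another script is awkward):

import os
import shutil
from urllib.request import urlretrieve

os.makedirs("image_tmp", exist_ok=True)
url = ("https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/"
       "images/part-000001.zip")
archive = "image_tmp/part-000001.zip"
urlretrieve(url, archive)                                    # download one part
shutil.unpack_archive(archive, archive.replace(".zip", ""))  # then unzip it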
download_dataset.py ADDED
@@ -0,0 +1,48 @@
+ import os
+ import json
+ import pandas as pd
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ def load_and_process():
+     # DiffusionDB requires a subset (config) name; '2m_first_1k' holds the
+     # first 1,000 images of DiffusionDB 2M.
+     dataset = load_dataset("poloclub/diffusiondb", "2m_first_1k", split="train")
+
+     os.makedirs("processed/images", exist_ok=True)
+     processed_data = []
+
+     for idx, sample in enumerate(tqdm(dataset)):
+         image_id = f"{idx:06d}.png"
+
+         if sample.get('image'):
+             sample['image'].save(f"processed/images/{image_id}")
+
+         # Metadata keys follow the DiffusionDB column names
+         # (prompt, seed, step, cfg, sampler).
+         data_entry = {
+             "id": idx,
+             "image_file": image_id,
+             "prompt": sample.get('prompt', ''),
+             "seed": sample.get('seed', 0),
+             "cfg_scale": sample.get('cfg', 0.0),
+             "steps": sample.get('step', 0),
+             "sampler": sample.get('sampler', '')
+         }
+         processed_data.append(data_entry)
+
+     return processed_data
+
+ def save_data(data):
+     with open("processed/data.json", "w") as f:
+         json.dump(data, f)
+
+     df = pd.DataFrame(data)
+     df.to_csv("processed/data.csv", index=False)
+     df.to_parquet("processed/data.parquet", index=False)
+
+ def main():
+     data = load_and_process()
+     save_data(data)
+     print(f"Processed {len(data)} samples")
+
+ if __name__ == "__main__":
+     main()
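
A quick sanity check one might run after the script finishes, assuming the processed/ layout written by save_data above:

import json
import os

with open("processed/data.json") as f:
    entries = json.load(f)

# Confirm every metadata entry has its image on disk
missing = [e["image_file"] for e in entries
           if not os.path.exists(f"processed/images/{e['image_file']}")]
print(f"{len(entries)} entries, {len(missing)} missing image files")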
preprocess_data.py ADDED
@@ -0,0 +1,122 @@
+ import os
+ from datasets import load_dataset
+ from PIL import Image, ImageOps, ImageFilter
+ from tqdm import tqdm
+
+ def preprocess_image(image, target_size=512, quality_threshold=0.7):
+     """Preprocess an image with filtering, cropping, and enhancement"""
+     # Convert to RGB if needed
+     if image.mode != 'RGB':
+         image = image.convert('RGB')
+
+     # Filter out low-resolution images
+     width, height = image.size
+     if min(width, height) < target_size * quality_threshold:
+         return None
+
+     # Center-crop to a square if not already square
+     if width != height:
+         size = min(width, height)
+         left = (width - size) // 2
+         top = (height - size) // 2
+         image = image.crop((left, top, left + size, top + size))
+
+     # Resize to the target size
+     image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
+
+     # Enhance image quality: slight sharpening, then auto-adjusted levels
+     image = image.filter(ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3))
+     image = ImageOps.autocontrast(image, cutoff=1)
+
+     return image
+
+ def clean_prompt(prompt):
+     """Clean and normalize prompts"""
+     if not prompt:
+         return None  # treat empty prompts like filtered ones
+
+     # Collapse excessive whitespace and strip stray punctuation
+     prompt = ' '.join(prompt.split())
+     prompt = prompt.strip(' .,;:')
+
+     # Filter out very short or very long prompts
+     words = prompt.split()
+     if len(words) < 3 or len(words) > 50:
+         return None
+
+     return prompt
+
+ def prepare_dreambooth_data():
+     # Load a 1k random sample of DiffusionDB Large
+     dataset = load_dataset('poloclub/diffusiondb', 'large_random_1k')
+     train_data = dataset['train']
+
+     # Create the directory structure
+     data_dir = "./diffusiondb_dataset"
+     os.makedirs(data_dir, exist_ok=True)
+
+     valid_samples = 0
+
+     # Process the images
+     for idx, sample in enumerate(tqdm(train_data, desc="Processing images")):
+         # Preprocess the image; skip samples that fail the quality filter
+         image = preprocess_image(sample['image'])
+         if image is None:
+             continue
+
+         # Clean the prompt; skip samples whose prompt is filtered out
+         prompt = clean_prompt(sample.get('prompt', ''))
+         if prompt is None:
+             continue
+
+         # Save the processed image
+         image_path = os.path.join(data_dir, f"image_{valid_samples:04d}.jpg")
+         image.save(image_path, "JPEG", quality=95, optimize=True)
+
+         # Save the cleaned caption alongside it
+         caption_path = os.path.join(data_dir, f"image_{valid_samples:04d}.txt")
+         with open(caption_path, 'w', encoding='utf-8') as f:
+             f.write(prompt)
+
+         valid_samples += 1
+
+     print(f"Processed {len(train_data)} samples, saved {valid_samples} valid images to {data_dir}")
+     return data_dir
+
+ # Convert the dataset
+ data_dir = prepare_dreambooth_data()
+
+ # Now you can use the standard accelerate command:
+ training_command = f"""
+ accelerate launch \\
+   --deepspeed_config_file ds_config.json \\
+   diffusers/examples/dreambooth/train_dreambooth.py \\
+   --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
+   --instance_data_dir="{data_dir}" \\
+   --instance_prompt="a generated image" \\
+   --output_dir="./diffusiondb-model" \\
+   --resolution=512 \\
+   --train_batch_size=1 \\
+   --gradient_accumulation_steps=1 \\
+   --gradient_checkpointing \\
+   --learning_rate=5e-6 \\
+   --lr_scheduler="constant" \\
+   --lr_warmup_steps=0 \\
+   --max_train_steps=400 \\
+   --mixed_precision="fp16" \\
+   --checkpointing_steps=100 \\
+   --checkpoints_total_limit=1 \\
+   --report_to="tensorboard" \\
+   --logging_dir="./diffusiondb-model/logs"
+ """
+
+ print("Run this command:")
+ print(training_command)
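
A small check of the image/caption pairing the script writes, assuming the image_NNNN.jpg / image_NNNN.txt layout above:

import glob
import os

pairs = 0
for image_path in sorted(glob.glob("./diffusiondb_dataset/image_*.jpg")):
    caption_path = image_path.replace(".jpg", ".txt")
    assert os.path.exists(caption_path), f"missing caption for {image_path}"
    pairs += 1
print(f"{pairs} image/caption pairs ready")

Note that the stock train_dreambooth.py trains on the fixed --instance_prompt rather than the per-image .txt captions, so the caption files are informational unless a caption-aware training script is used.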
requirements.txt CHANGED
@@ -13,4 +13,5 @@ faiss-cpu
  sentence-transformers
  trl[peft]
  label-studio
- datasets
+ datasets
+ alive_progress