Commit be305fb
Parent(s): 7066d20
load dataset
Files changed:
- README.md +2 -0
- image_download.py → deprecated/image_download.py +0 -0
- image_gen.py → deprecated/image_gen.py +0 -0
- download.py +209 -0
- download_dataset.py +48 -0
- preprocess_data.py +122 -0
- requirements.txt +2 -1
README.md
CHANGED
@@ -14,6 +14,8 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
 
 commands:
 
+download images: python download.py -i 1 -r 2 -o /home/user/app/image_tmp -z
+
 pip install git+https://github.com/huggingface/diffusers
 
 accelerate launch \
image_download.py → deprecated/image_download.py
RENAMED
File without changes
image_gen.py → deprecated/image_gen.py
RENAMED
File without changes
download.py
ADDED
@@ -0,0 +1,209 @@
# Author: Marco Lustri 2022 - https://github.com/TheLustriVA
# MIT License

"""A script to make downloading the DiffusionDB dataset easier."""
from urllib.error import HTTPError
from urllib.request import urlretrieve
from alive_progress import alive_bar
from os.path import exists

import shutil
import os
import time
import argparse

index = None  # initiate main arguments as None
range_max = None
output = None
unzip = None
large = None

parser = argparse.ArgumentParser(description="Download a file from a URL")

# Add the command-line arguments to the parser.
parser.add_argument(
    "-i",
    "--index",
    type=int,
    default=1,
    help="File to download or lower bound of range if -r is set",
)
parser.add_argument(
    "-r",
    "--range",
    type=int,
    default=None,
    help="Upper bound of range if -i is provided",
)
parser.add_argument(
    "-o", "--output", type=str, default="images", help="Output directory name"
)
parser.add_argument(
    "-z",
    "--unzip",
    default=False,
    help="Unzip the file after downloading",
    # store_true sets the flag to True when it is provided.
    action="store_true",
)
parser.add_argument(
    "-l",
    "--large",
    default=False,
    help="Download from DiffusionDB Large (14 million images)",
    action="store_true",
)

args = parser.parse_args()  # parse the arguments

# Copy any provided arguments into the module-level variables.
if args.index:
    index = args.index
if args.range:
    range_max = args.range
if args.output:
    output = args.output
if args.unzip:
    unzip = args.unzip
if args.large:
    large = args.large


def download(index=1, range_index=0, output="", large=False):
    """
    Download a file from a URL and save it to a local file.

    :param index: The index of the file to download, or the lower bound of
        the range if range_index is set, defaults to 1 (optional)
    :param range_index: The upper bound of the range of files to download;
        defaults to 0, which downloads only the single file at index
        (optional)
    :param output: The directory to download the files to
    :param large: If True, download from DiffusionDB Large (14 million
        images) instead of DiffusionDB 2M (2 million images)
    :return: A list of files to unzip
    """
    baseurl = "https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/"
    files_to_unzip = []

    if large:
        if index <= 10000:
            url = f"{baseurl}diffusiondb-large-part-1/part-{index:06}.zip"
        else:
            url = f"{baseurl}diffusiondb-large-part-2/part-{index:06}.zip"
    else:
        url = f"{baseurl}images/part-{index:06}.zip"

    if output != "":
        output = f"{output}/"

    if not exists(output):
        os.makedirs(output)

    if range_index == 0:
        print("Downloading file: ", url)
        file_path = f"{output}part-{index:06}.zip"
        try:
            urlretrieve(url, file_path)
        except HTTPError as e:
            print(f"Encountered an HTTPError downloading file: {url} - {e}")
        if unzip:
            # The original called unzip(file_path), which would invoke the
            # boolean flag; unzip_file is the intended function.
            unzip_file(file_path)
    else:
        # Download the files numbered from index up to range_index.
        with alive_bar(range_index - index, title="Downloading files") as bar:
            for idx in range(index, range_index):
                if large:
                    if idx <= 10000:
                        url = f"{baseurl}diffusiondb-large-part-1/part-{idx:06}.zip"
                    else:
                        url = f"{baseurl}diffusiondb-large-part-2/part-{idx:06}.zip"
                else:
                    url = f"{baseurl}images/part-{idx:06}.zip"

                loop_file_path = f"{output}part-{idx:06}.zip"
                # Try to download the file; on an HTTPError, print the error
                # and move on.
                try:
                    urlretrieve(url, loop_file_path)
                except HTTPError as e:
                    print(f"HTTPError downloading file: {url} - {e}")
                files_to_unzip.append(loop_file_path)
                # Record the url of the file in a manifest file.
                with open("manifest.txt", "a") as f:
                    f.write(url + "\n")
                time.sleep(0.1)
                bar()

    # If the user wants to unzip the files, return the list of files to
    # unzip; unzipping inline would make an already lengthy process longer.
    if unzip and len(files_to_unzip) > 0:
        return files_to_unzip


def unzip_file(file: str, extract_to: str = None):
    """
    Unpack a zip file to the specified directory.

    :param file: str - path to zip file
    :param extract_to: str - directory to extract to (default: the zip file
        name without its extension)
    :return: A confirmation message naming the extraction directory
    """
    if extract_to is None:
        extract_to = file.replace('.zip', '')

    shutil.unpack_archive(file, extract_to)
    return f"File: {file} has been unzipped to {extract_to}"


def unzip_all(files: list):
    """
    Unzip all files in a list of files.

    :param files: list of zip-file paths to unpack
    """
    with alive_bar(len(files), title="Unzipping files") as bar:
        for file in files:
            unzip_file(file, '/home/user/app/images')
            time.sleep(0.1)
            bar()


def main(index=None, range_max=None, output=None, unzip=None, large=None):
    """
    Download the requested files and, if unzip is set, unpack them. When a
    range of roughly 2,000 files or more is requested, ask the user to
    confirm they have enough disk space first.

    :param index: The index of the file to download, or the lower bound of
        the range
    :param range_max: The upper bound of the range of files to download
    :param output: The directory to download the files to
    :param unzip: Set to True to unzip the files after downloading them
    :param large: Set to True to download from DiffusionDB Large (14 million
        images) instead of DiffusionDB 2M (2 million images)
    :return: A list of files that have been downloaded
    """
    if index and range_max:
        if range_max - index >= 1999:
            confirmation = input("Do you have at least 1.7Tb free: (y/n)")
            if confirmation != "y":
                return
        files = download(index, range_max, output, large)
        if unzip:
            unzip_all(files)
    elif index:
        download(index, output=output, large=large)
    else:
        print("No index provided")


# This common Python pattern runs main when the script is executed directly,
# but not when the module is imported.
if __name__ == "__main__":
    main(index, range_max, output, unzip, large)
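For reference, two hedged invocation sketches of the flags above. The first is the command the new README line documents; the second illustrates the DiffusionDB Large flag with a made-up index and output directory:

python download.py -i 1 -r 2 -o /home/user/app/image_tmp -z
python download.py -i 12000 -o large_tmp -l  # hypothetical: one Large archive (indexes above 10000 resolve to large-part-2)
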
download_dataset.py
ADDED
@@ -0,0 +1,48 @@
import os
import json
import pandas as pd
from datasets import load_dataset
from PIL import Image
import shutil
from tqdm import tqdm

def load_and_process():
    dataset = load_dataset("poloclub/diffusiondb", split="train[:1000]")

    os.makedirs("processed/images", exist_ok=True)
    processed_data = []

    for idx, sample in enumerate(tqdm(dataset)):
        image_id = f"{idx:06d}.png"

        if sample.get('image'):
            sample['image'].save(f"processed/images/{image_id}")

        # The short keys ('p', 'se', 'c', 'st', 'sa') abbreviate prompt,
        # seed, cfg scale, steps, and sampler; if the loaded dataset exposes
        # full column names instead, these lookups fall back to the defaults.
        data_entry = {
            "id": idx,
            "image_file": image_id,
            "prompt": sample.get('p', ''),
            "seed": sample.get('se', 0),
            "cfg_scale": sample.get('c', 0.0),
            "steps": sample.get('st', 0),
            "sampler": sample.get('sa', '')
        }
        processed_data.append(data_entry)

    return processed_data

def save_data(data):
    with open("processed/data.json", "w") as f:
        json.dump(data, f)

    df = pd.DataFrame(data)
    df.to_csv("processed/data.csv", index=False)
    df.to_parquet("processed/data.parquet", index=False)

def main():
    data = load_and_process()
    save_data(data)
    print(f"Processed {len(data)} samples")

if __name__ == "__main__":
    main()
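The metadata is written in three interchangeable formats; a minimal sketch of loading it back with pandas, using the paths exactly as save_data above writes them:

import pandas as pd

# Read the parquet metadata produced by save_data and inspect a few rows.
df = pd.read_parquet("processed/data.parquet")
print(df[["id", "image_file", "prompt"]].head())
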
preprocess_data.py
ADDED
@@ -0,0 +1,122 @@
import os
import numpy as np
from datasets import load_dataset
from PIL import Image, ImageOps, ImageFilter
from tqdm import tqdm
import random

def preprocess_image(image, target_size=512, quality_threshold=0.7):
    """Preprocess an image with various enhancements."""
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Filter out low-quality (too small) images
    width, height = image.size
    if min(width, height) < target_size * quality_threshold:
        return None

    # Center-crop to a square if not already square
    if width != height:
        size = min(width, height)
        left = (width - size) // 2
        top = (height - size) // 2
        image = image.crop((left, top, left + size, top + size))

    # Resize to target size
    image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)

    # Enhance image quality: slightly sharpen
    image = image.filter(ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3))

    # Auto-adjust levels
    image = ImageOps.autocontrast(image, cutoff=1)

    return image

def clean_prompt(prompt):
    """Clean and normalize prompts."""
    if not prompt:
        # Treat empty prompts like too-short ones so they are skipped rather
        # than written out as empty captions (the original returned "").
        return None

    # Remove excessive whitespace
    prompt = ' '.join(prompt.split())

    # Remove common artifacts; the double-space collapse is effectively
    # redundant after the split/join above, but kept from the original.
    prompt = prompt.replace('  ', ' ')
    prompt = prompt.strip(' .,;:')

    # Filter out very short or very long prompts
    words = prompt.split()
    if len(words) < 3 or len(words) > 50:
        return None

    return prompt

def prepare_dreambooth_data():
    # Load dataset
    dataset = load_dataset('poloclub/diffusiondb', 'large_random_1k')
    train_data = dataset['train']

    # Create directory structure
    data_dir = "./diffusiondb_dataset"
    os.makedirs(data_dir, exist_ok=True)

    valid_samples = 0

    # Process images with preprocessing
    for idx, sample in enumerate(tqdm(train_data, desc="Processing images")):
        # Preprocess image
        image = preprocess_image(sample['image'])
        if image is None:
            continue

        # Clean prompt
        prompt = clean_prompt(sample.get('prompt', ''))
        if prompt is None:
            continue

        # Save processed image
        image_path = os.path.join(data_dir, f"image_{valid_samples:04d}.jpg")
        image.save(image_path, "JPEG", quality=95, optimize=True)

        # Save cleaned caption
        caption_path = os.path.join(data_dir, f"image_{valid_samples:04d}.txt")
        with open(caption_path, 'w', encoding='utf-8') as f:
            f.write(prompt)

        valid_samples += 1

    print(f"Processed {len(train_data)} samples, saved {valid_samples} valid images to {data_dir}")
    return data_dir

# Convert dataset
data_dir = prepare_dreambooth_data()

# Now you can use the standard accelerate command:
training_command = f"""
accelerate launch \\
  --deepspeed_config_file ds_config.json \\
  diffusers/examples/dreambooth/train_dreambooth.py \\
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
  --instance_data_dir="{data_dir}" \\
  --instance_prompt="a generated image" \\
  --output_dir="./diffusiondb-model" \\
  --resolution=512 \\
  --train_batch_size=1 \\
  --gradient_accumulation_steps=1 \\
  --gradient_checkpointing \\
  --learning_rate=5e-6 \\
  --lr_scheduler="constant" \\
  --lr_warmup_steps=0 \\
  --max_train_steps=400 \\
  --mixed_precision="fp16" \\
  --checkpointing_steps=100 \\
  --checkpoints_total_limit=1 \\
  --report_to="tensorboard" \\
  --logging_dir="./diffusiondb-model/logs"
"""

print("Run this command:")
print(training_command)
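As a sanity check, a small sketch of exercising the two helpers on synthetic inputs. It assumes preprocess_image and clean_prompt have been copied into an importable module, since the script as written kicks off the full dataset conversion at import time; the image and prompts below are made up:

from PIL import Image

# An 800x600 dummy image passes the 512 * 0.7 minimum-size check, is
# center-cropped to 600x600, then resized to 512x512.
img = Image.new("RGB", (800, 600), color=(128, 128, 128))
out = preprocess_image(img)
print(out.size)  # (512, 512)

# Whitespace is collapsed; prompts under 3 words are rejected.
print(clean_prompt("  a  scenic   mountain landscape  "))  # "a scenic mountain landscape"
print(clean_prompt("too short"))  # None
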
requirements.txt
CHANGED
@@ -13,4 +13,5 @@ faiss-cpu
 sentence-transformers
 trl[peft]
 label-studio
-datasets
+datasets
+alive_progress