goodmodeler committed
Commit ee07cd3 · 1 Parent(s): 134ce68

ADD: fix train_model

Files changed (1)
  1. train_model.py +172 -56
train_model.py CHANGED
@@ -4,41 +4,75 @@ from datasets import load_dataset
  from PIL import Image, ImageOps, ImageFilter
  from tqdm import tqdm
  import random

  def preprocess_image(image, target_size=512, quality_threshold=0.7):
      """Preprocess image with various enhancements"""
-     # Convert to RGB if needed
-     if image.mode != 'RGB':
-         image = image.convert('RGB')
-
-     # Filter out low quality images
-     width, height = image.size
-     if min(width, height) < target_size * quality_threshold:
          return None
-
-     # Center crop to square if not already
-     if width != height:
-         size = min(width, height)
-         left = (width - size) // 2
-         top = (height - size) // 2
-         image = image.crop((left, top, left + size, top + size))
-
-     # Resize to target size
-     image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
-
-     # Enhance image quality
-     # Slightly sharpen
-     image = image.filter(ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3))
-
-     # Auto-adjust levels
-     image = ImageOps.autocontrast(image, cutoff=1)
-
-     return image

  def clean_prompt(prompt):
      """Clean and normalize prompts"""
      if not prompt:
-         return ""

      # Remove excessive whitespace
      prompt = ' '.join(prompt.split())
@@ -56,53 +90,131 @@ def clean_prompt(prompt):

  def prepare_dreambooth_data():
      # Load dataset
-     dataset = load_dataset("laion/laion2B-en-aesthetic")
-     train_data = dataset['train']

      # Create directory structure
-     data_dir = "./diffusiondb_dataset"
      os.makedirs(data_dir, exist_ok=True)

      valid_samples = 0

      # Process images with preprocessing
-     for idx, sample in enumerate(tqdm(train_data, desc="Processing images")):
-         # Preprocess image
-         image = preprocess_image(sample['image'])
-         if image is None:
-             continue

-         # Clean prompt
-         prompt = clean_prompt(sample.get('prompt', ''))
-         if prompt is None:
              continue

-         # Save processed image
-         image_path = os.path.join(data_dir, f"image_{valid_samples:04d}.jpg")
-         image.save(image_path, "JPEG", quality=95, optimize=True)

-         # Save cleaned caption
-         caption_path = os.path.join(data_dir, f"image_{valid_samples:04d}.txt")
          with open(caption_path, 'w', encoding='utf-8') as f:
              f.write(prompt)
-
-         valid_samples += 1

-     print(f"Processed {len(train_data)} samples, saved {valid_samples} valid images to {data_dir}")
      return data_dir

- # Convert dataset
- data_dir = prepare_dreambooth_data()
-
- # Now you can use the standard accelerate command:
- training_command = f"""
  accelerate launch \\
  --deepspeed_config_file ds_config.json \\
  diffusers/examples/dreambooth/train_dreambooth.py \\
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
  --instance_data_dir="{data_dir}" \\
- --instance_prompt="a generated image" \\
- --output_dir="./diffusiondb-model" \\
  --resolution=512 \\
  --train_batch_size=1 \\
  --gradient_accumulation_steps=1 \\
@@ -115,8 +227,12 @@ accelerate launch \\
  --checkpointing_steps=100 \\
  --checkpoints_total_limit=1 \\
  --report_to="tensorboard" \\
- --logging_dir="./diffusiondb-model/logs"
  """

- print("Run this command:")
- print(training_command)
 
  from PIL import Image, ImageOps, ImageFilter
  from tqdm import tqdm
  import random
+ import requests
+ import io
+ import time
+
+ def download_image(url, timeout=10, retries=2):
+     """Download image from URL with retry mechanism"""
+     for attempt in range(retries):
+         try:
+             headers = {
+                 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
+             }
+             response = requests.get(url, timeout=timeout, headers=headers)
+
+             if response.status_code == 200:
+                 image = Image.open(io.BytesIO(response.content))
+                 return image
+             else:
+                 return None
+
+         except Exception as e:
+             if attempt == retries - 1:  # Last attempt
+                 print(f"Failed to download {url}: {e}")
+                 return None
+             time.sleep(0.5)  # Brief pause before retry
+
+     return None

  def preprocess_image(image, target_size=512, quality_threshold=0.7):
      """Preprocess image with various enhancements"""
+     if image is None:
+         return None
+
+     try:
+         # Convert to RGB if needed
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         # Filter out low quality images
+         width, height = image.size
+         if min(width, height) < target_size * quality_threshold:
+             return None
+
+         # Center crop to square if not already
+         if width != height:
+             size = min(width, height)
+             left = (width - size) // 2
+             top = (height - size) // 2
+             image = image.crop((left, top, left + size, top + size))
+
+         # Resize to target size
+         image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
+
+         # Enhance image quality
+         # Slightly sharpen
+         image = image.filter(ImageFilter.UnsharpMask(radius=0.5, percent=120, threshold=3))
+
+         # Auto-adjust levels
+         image = ImageOps.autocontrast(image, cutoff=1)
+
+         return image
+
+     except Exception as e:
+         print(f"Error preprocessing image: {e}")
          return None

  def clean_prompt(prompt):
      """Clean and normalize prompts"""
      if not prompt:
+         return None

      # Remove excessive whitespace
      prompt = ' '.join(prompt.split())
 

  def prepare_dreambooth_data():
      # Load dataset
+     print("Loading LAION dataset...")
+     dataset = load_dataset("laion/laion2B-en-aesthetic", split="train", streaming=True)

      # Create directory structure
+     data_dir = "./laion_dataset"
      os.makedirs(data_dir, exist_ok=True)

      valid_samples = 0
+     processed_count = 0
+     max_samples = 1000  # Limit total samples to process
+
+     print(f"Starting to process up to {max_samples} samples...")

      # Process images with preprocessing
+     for idx, sample in enumerate(tqdm(dataset, desc="Processing LAION samples")):
+         if processed_count >= max_samples:
+             break
+
+         processed_count += 1
+
+         try:
+             # Get URL and text from LAION format
+             image_url = sample.get('URL', '')
+             text_prompt = sample.get('TEXT', '')
+
+             if not image_url or not text_prompt:
+                 continue
+
+             # Clean prompt first
+             prompt = clean_prompt(text_prompt)
+             if prompt is None:
+                 continue
+
+             # Download image from URL
+             print(f"Downloading image {valid_samples + 1}: {image_url[:50]}...")
+             image = download_image(image_url)
+             if image is None:
+                 continue
+
+             # Preprocess downloaded image
+             processed_image = preprocess_image(image)
+             if processed_image is None:
+                 continue
+
+             # Save processed image
+             image_path = os.path.join(data_dir, f"image_{valid_samples:04d}.jpg")
+             processed_image.save(image_path, "JPEG", quality=95, optimize=True)
+
+             # Save cleaned caption
+             caption_path = os.path.join(data_dir, f"image_{valid_samples:04d}.txt")
+             with open(caption_path, 'w', encoding='utf-8') as f:
+                 f.write(prompt)
+
+             valid_samples += 1
+
+             # Optional: Add metadata file
+             metadata_path = os.path.join(data_dir, f"image_{valid_samples-1:04d}_meta.txt")
+             with open(metadata_path, 'w', encoding='utf-8') as f:
+                 f.write(f"URL: {image_url}\n")
+                 f.write(f"Aesthetic: {sample.get('aesthetic', 'N/A')}\n")
+                 f.write(f"Width: {sample.get('WIDTH', 'N/A')}\n")
+                 f.write(f"Height: {sample.get('HEIGHT', 'N/A')}\n")
+
+             # Stop if we have enough samples
+             if valid_samples >= 100:  # Adjust this number as needed
+                 break
+
+         except Exception as e:
+             print(f"Error processing sample {idx}: {e}")
              continue
+
+     print(f"Processed {processed_count} samples, saved {valid_samples} valid images to {data_dir}")
+     return data_dir
+
+ def create_demo_dataset():
+     """Create demo dataset as last resort"""
+     print("Creating demo dataset...")
+
+     data_dir = "./demo_dataset"
+     os.makedirs(data_dir, exist_ok=True)
+
+     demo_prompts = [
+         "a beautiful landscape with mountains",
+         "portrait of a person with detailed features",
+         "abstract colorful digital artwork",
+         "modern architecture building design",
+         "natural forest scene with trees",
+         "urban cityscape at sunset",
+         "artistic oil painting style",
+         "vintage photography aesthetic",
+         "minimalist geometric composition",
+         "vibrant surreal art piece"
+     ]
+
+     for idx, prompt in enumerate(demo_prompts):
+         # Create a simple vertical gradient background from color1 to color2
+         color1 = (random.randint(50, 200), random.randint(50, 200), random.randint(50, 200))
+         color2 = (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))

+         gradient = Image.new('RGB', (1, 2))
+         gradient.putpixel((0, 0), color1)
+         gradient.putpixel((0, 1), color2)
+         image = gradient.resize((512, 512), Image.Resampling.BILINEAR)

+         # Save files
+         image_path = os.path.join(data_dir, f"image_{idx:04d}.jpg")
+         image.save(image_path, "JPEG", quality=95)
+
+         caption_path = os.path.join(data_dir, f"image_{idx:04d}.txt")
          with open(caption_path, 'w', encoding='utf-8') as f:
              f.write(prompt)

+     print(f"Created {len(demo_prompts)} demo samples")
      return data_dir

+ # Main execution with fallback
+ def main():
+     data_dir = prepare_dreambooth_data()
+
+     # Fall back to the demo dataset if no usable images were downloaded
+     if not any(f.endswith(".jpg") for f in os.listdir(data_dir)):
+         data_dir = create_demo_dataset()
+
+     # Generate training command
+     training_command = f"""
  accelerate launch \\
  --deepspeed_config_file ds_config.json \\
  diffusers/examples/dreambooth/train_dreambooth.py \\
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \\
  --instance_data_dir="{data_dir}" \\
+ --instance_prompt="a high quality image" \\
+ --output_dir="./laion-model" \\
  --resolution=512 \\
  --train_batch_size=1 \\
  --gradient_accumulation_steps=1 \\
 
  --checkpointing_steps=100 \\
  --checkpoints_total_limit=1 \\
  --report_to="tensorboard" \\
+ --logging_dir="./laion-model/logs"
  """
+
+     print(f"\n✅ Dataset prepared in: {data_dir}")
+     print("🚀 Run this command to train:")
+     print(training_command)

+ if __name__ == "__main__":
+     main()
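
The new download_image / preprocess_image pair can be smoke-tested on a single URL before committing to the full LAION loop. A minimal sketch, assuming the module is importable as train_model (the URL below is a placeholder, not taken from the dataset):

import train_model

url = "https://example.com/sample.jpg"  # placeholder URL, replace with a real one
img = train_model.download_image(url)
processed = train_model.preprocess_image(img)

if processed is None:
    # Either the download failed, the image was below the size threshold,
    # or preprocessing raised an exception.
    print("Sample rejected")
else:
    processed.save("sample_preview.jpg", "JPEG", quality=95)
    print("Saved sample_preview.jpg")

Importing is safe because module-level execution is now behind the __main__ guard.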
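
Because every failed download or undersized image is silently skipped, it is worth confirming that the prepared folder actually contains matching image/caption pairs before launching training. A small hypothetical check (check_pairs is not part of the commit; the default path matches data_dir above):

import os

def check_pairs(data_dir="./laion_dataset"):
    """Count image/caption pairs written by prepare_dreambooth_data()."""
    images = sorted(f for f in os.listdir(data_dir) if f.endswith(".jpg"))
    pairs = [f for f in images if os.path.exists(os.path.join(data_dir, f[:-4] + ".txt"))]
    print(f"{len(images)} images, {len(pairs)} with captions in {data_dir}")
    return len(pairs)

if __name__ == "__main__":
    check_pairs()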
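
The accelerate command points --deepspeed_config_file at a ds_config.json that is not included in this commit. A minimal sketch of such a DeepSpeed config, assuming a single GPU with fp16 and ZeRO stage 2 and matching train_batch_size=1 and gradient_accumulation_steps=1 from the command (all values illustrative, adjust to your setup):

import json

# Hypothetical minimal DeepSpeed config; batch sizes must satisfy
# train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * num_gpus
ds_config = {
    "train_batch_size": 1,
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)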