from datasets import Dataset, load_from_disk, concatenate_datasets
from diffusers import AutoencoderKL
from torchvision.transforms import ToTensor, Normalize
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer
import torch
import os
import gc
import numpy as np
from PIL import Image
import random
import shutil
import time
from datetime import timedelta

# --- Configuration ---
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 20
min_size = 640     # lower bound for the shorter crop side, px
max_size = 1152    # upper bound for the longer crop side, px
step = 64          # crop sides are snapped down to multiples of this
img_share = 0.1    # share of batches conditioned on image embeddings instead of captions
empty_share = 0.1  # probability of replacing a caption with "" (caption dropout)
limit = 0          # 0 = process all images

folder_path = "/workspace/all2"
save_path = "/workspace/1152p2"
os.makedirs(save_path, exist_ok=True)


def clear_cuda_memory():
    """Print peak GPU memory used since the last call, then release cached memory."""
    if torch.cuda.is_available():
        used_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"Peak GPU memory: {used_gb:.2f} GB")
        # Reset the peak counter so each report covers only the latest stretch of work.
        torch.cuda.reset_peak_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()


def load_models():
    print("Loading models...")
    vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae", torch_dtype=dtype).to(device).eval()
    model = AutoModel.from_pretrained("visheratin/mexma-siglip", torch_dtype=dtype,
                                      trust_remote_code=True, optimized=True).to(device).eval()
    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip", use_fast=True)
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
    return vae, model, processor, tokenizer


vae, model, processor, tokenizer = load_models()
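
# Optional sanity check (a sketch, not part of the pipeline). Assuming the
# 16-channel VAE downsamples by a factor of 8, a 640x1152 input should produce
# latents of shape (1, 16, 80, 144). Uncomment to verify before a long run.
# with torch.inference_mode():
#     dummy = torch.zeros(1, 3, 640, 1152, device=device, dtype=dtype)
#     print(vae.encode(dummy).latent_dist.mode().shape)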


def get_image_transform(min_size=256, max_size=512, step=64):
    def transform(img, dry_run=False):
        original_width, original_height = img.size

        # Scale so the longer side equals max_size, preserving aspect ratio.
        if original_width >= original_height:
            new_width = max_size
            new_height = int(max_size * original_height / original_width)
        else:
            new_height = max_size
            new_width = int(max_size * original_width / original_height)

        # If that pushed the shorter side below min_size, rescale so the
        # shorter side equals min_size instead.
        if new_height < min_size or new_width < min_size:
            if original_width <= original_height:
                new_width = min_size
                new_height = int(min_size * original_height / original_width)
            else:
                new_height = min_size
                new_width = int(min_size * original_width / original_height)

        # Snap crop sides down to multiples of `step`, clamped to [min_size, max_size].
        crop_width = min(max_size, (new_width // step) * step)
        crop_height = min(max_size, (new_height // step) * step)
        crop_width = max(min_size, crop_width)
        crop_height = max(min_size, crop_height)

        if dry_run:
            return crop_width, crop_height

        img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)

        # Crop biased toward the top third of the frame rather than dead center.
        top = (new_height - crop_height) // 3
        left = 0
        img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
        final_width, final_height = img_cropped.size

        # Map pixel values to [-1, 1], as expected by the VAE.
        img_tensor = ToTensor()(img_cropped)
        img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
        return img_tensor, img_cropped, final_width, final_height

    return transform
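
# Worked example: a 1920x1080 source scales to 1152x648 (long side = max_size),
# then 648 is snapped down to the nearest multiple of step (64):
# transform = get_image_transform(min_size=640, max_size=1152, step=64)
# transform(Image.new("RGB", (1920, 1080)), dry_run=True)  # -> (1152, 640)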


def encode_images_batch(images, processor, model):
    """Encode PIL images into pooled image embeddings of shape (B, 1, D)."""
    pixel_values = torch.stack([
        processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        for img in images
    ]).to(device, dtype)

    with torch.inference_mode():
        image_embeddings = model.vision_model(pixel_values).pooler_output

    return image_embeddings.unsqueeze(1).cpu().numpy()
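
# Note (an assumption worth verifying for this particular processor): HF image
# processors generally accept a list of images in a single call, which would
# avoid the per-image loop above and should be equivalent:
# pixel_values = processor(images=images, return_tensors="pt")["pixel_values"].to(device, dtype)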


def encode_texts_batch(texts, tokenizer, model):
    """Encode captions into text embeddings of shape (B, 1, D)."""
    with torch.inference_mode():
        text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",
                                   max_length=512, truncation=True).to(device)
        text_embeddings = model.encode_texts(text_tokenized.input_ids, text_tokenized.attention_mask)
    return text_embeddings.unsqueeze(1).cpu().numpy()


def maybe_empty_label(label, prob=0.01):
    """Randomly replace a caption with an empty string (caption dropout)."""
    return "" if random.random() < prob else label


def encode_to_latents(images, texts):
    transform = get_image_transform(min_size, max_size, step)
    try:
        latents_list = []
        widths = []
        heights = []
        pil_images = []
        # Track the captions of successfully encoded images, so all returned
        # columns stay the same length even if some images fail.
        kept_texts = []

        for img, text in zip(images, texts):
            try:
                transformed_img, pil_img, final_width, final_height = transform(img)
                pil_images.append(pil_img)
                widths.append(final_width)
                heights.append(final_height)
                kept_texts.append(text)

                img_tensor = transformed_img.unsqueeze(0).to(device, dtype)
                with torch.no_grad():
                    posterior = vae.encode(img_tensor).latent_dist.mode()
                    z = (posterior - vae.config.shift_factor) * vae.config.scaling_factor
                latents_list.append(z.cpu().numpy())
            except Exception as e:
                print(f"VAE encoding error: {e}")
                continue

        latents = np.concatenate(latents_list, axis=0)

        # Condition a random img_share of batches on image embeddings; the rest
        # use caption embeddings, with a fraction of captions dropped to "".
        if random.random() < img_share:
            embeddings = encode_images_batch(pil_images, processor, model)
        else:
            text_labels_with_empty = [maybe_empty_label(lbl, empty_share) for lbl in kept_texts]
            embeddings = encode_texts_batch(text_labels_with_empty, tokenizer, model)

        return {
            "vae": latents,
            "embeddings": embeddings,
            "text": kept_texts,
            "width": widths,
            "height": heights
        }
    except Exception as e:
        print(f"Fatal error in encode_to_latents: {e}")
        raise
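
# Note on the normalization above: latents are stored as
# z = (mode - shift_factor) * scaling_factor, so a consumer of this dataset
# would invert it before vae.decode() (illustrative sketch):
# def denormalize_latents(z, vae):
#     return z / vae.config.scaling_factor + vae.config.shift_factor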


def process_folder(folder_path, limit=None):
    """
    Recursively walk folder_path and all subdirectories, collecting paths to
    images and their matching .txt caption files. Unreadable images (and their
    captions) are deleted on the spot.
    """
    image_paths = []
    text_paths = []
    width = []
    height = []
    transform = get_image_transform(min_size, max_size, step)

    for root, dirs, files in os.walk(folder_path):
        for filename in files:
            if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                image_path = os.path.join(root, filename)
                try:
                    with Image.open(image_path) as img:
                        # dry_run only computes the target crop size.
                        w, h = transform(img, dry_run=True)
                except Exception as e:
                    print(f"Error opening {image_path}: {e}")
                    os.remove(image_path)
                    text_path = os.path.splitext(image_path)[0] + ".txt"
                    if os.path.exists(text_path):
                        os.remove(text_path)
                    continue

                text_path = os.path.splitext(image_path)[0] + ".txt"

                if os.path.exists(text_path) and min(w, h) > 0:
                    image_paths.append(image_path)
                    text_paths.append(text_path)
                    width.append(w)
                    height.append(h)

                if limit and limit > 0 and len(image_paths) >= limit:
                    print(f"Reached the limit of {limit} images")
                    return image_paths, text_paths, width, height

    print(f"Found {len(image_paths)} images with caption files")
    return image_paths, text_paths, width, height
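
# Expected layout: each image is paired with a same-named caption file, e.g.
#   /workspace/all2/photos/0001.jpg
#   /workspace/all2/photos/0001.txt   <- caption text for 0001.jpg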


def process_in_chunks(image_paths, text_paths, width, height, chunk_size=50000, batch_size=1):
    total_files = len(image_paths)
    start_time = time.time()
    chunks = range(0, total_files, chunk_size)
    processed = 0

    for chunk_idx, start in enumerate(chunks, 1):
        end = min(start + chunk_size, total_files)
        chunk_image_paths = image_paths[start:end]
        chunk_text_paths = text_paths[start:end]
        chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
        chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)

        # Read the caption for every image; fall back to "" on read errors.
        chunk_texts = []
        for text_path in chunk_text_paths:
            try:
                with open(text_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                chunk_texts.append(text)
            except Exception as e:
                print(f"Error reading {text_path}: {e}")
                chunk_texts.append("")

        # Bucket images by target crop size so every batch has uniform latent shapes.
        size_groups = {}
        for i in range(len(chunk_image_paths)):
            size_key = (chunk_widths[i], chunk_heights[i])
            if size_key not in size_groups:
                size_groups[size_key] = {"image_paths": [], "texts": []}
            size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
            size_groups[size_key]["texts"].append(chunk_texts[i])

        for size_key, group_data in size_groups.items():
            print(f"Processing group of size {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} images")

            group_dataset = Dataset.from_dict({
                "image_path": group_data["image_paths"],
                "text": group_data["texts"]
            })

            processed_group = group_dataset.map(
                lambda examples: encode_to_latents(
                    [Image.open(path) for path in examples["image_path"]],
                    examples["text"]
                ),
                batched=True,
                batch_size=batch_size,
                remove_columns=["image_path"],
                desc=f"Processing size group {size_key[0]}x{size_key[1]}"
            )

            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()

            # ETA from a running count of processed images.
            processed += len(group_data["image_paths"])
            elapsed = time.time() - start_time
            if processed > 0:
                remaining = (elapsed / processed) * (total_files - processed)
                elapsed_str = str(timedelta(seconds=int(elapsed)))
                remaining_str = str(timedelta(seconds=int(remaining)))
                print(f"ETA: elapsed {elapsed_str}, remaining {remaining_str}, "
                      f"progress {processed}/{total_files} ({processed/total_files:.1%})")


def combine_chunks(temp_path, final_path):
    """Concatenate the processed chunk datasets into the final dataset."""
    chunks = sorted([
        os.path.join(temp_path, d)
        for d in os.listdir(temp_path)
        if d.startswith("chunk_")
    ])

    datasets = [load_from_disk(chunk) for chunk in chunks]
    combined = concatenate_datasets(datasets)
    combined.save_to_disk(final_path)

    print(f"✅ Dataset saved to: {final_path}")


# --- Main pipeline ---
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

image_paths, text_paths, width, height = process_folder(folder_path, limit)
print(f"Found {len(image_paths)} images in total")

process_in_chunks(image_paths, text_paths, width, height, chunk_size=50000, batch_size=batch_size)

combine_chunks(temp_path, save_path)

try:
    shutil.rmtree(temp_path)
    print(f"✅ Temporary folder {temp_path} removed")
except Exception as e:
    print(f"⚠️ Error removing temporary folder: {e}")