# train.py
import os
import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
# Pick your scheduler class
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
    AutoencoderKL,
)
from transformers import CLIPTextModel, CLIPTokenizer

# ─── 1) CONFIG ────────────────────────────────────────────────────────────────
DATA_DIR = os.getenv("DATA_DIR", "./data")
MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
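# All three paths are read from environment variables so they can be overridden
# per run (for example, when the script is launched inside a Space or container).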

# ─── 2) DOWNLOAD OR VERIFY BASE MODEL ─────────────────────────────────────────
if not os.path.isdir(MODEL_DIR):
    MODEL_DIR = snapshot_download(
        repo_id="HiDream-ai/HiDream-I1-Dev",
        local_dir=MODEL_DIR,
    )
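# snapshot_download pulls the full repository from the Hugging Face Hub into
# local_dir and returns that path, so the from_pretrained calls below can load
# each component from its subfolder without hitting the network again.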

# ─── 3) LOAD EACH PIPELINE COMPONENT ──────────────────────────────────────────
# 3a) Scheduler
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    MODEL_DIR,
    subfolder="scheduler",
)
# 3b) VAE
vae = AutoencoderKL.from_pretrained(
    MODEL_DIR,
    subfolder="vae",
    torch_dtype=torch.float16,
).to("cuda")
# 3c) Text encoder + tokenizer
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_DIR,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
).to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_DIR,
    subfolder="tokenizer",
)
# 3d) U-Net
unet = UNet2DConditionModel.from_pretrained(
    MODEL_DIR,
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")

# ─── 4) BUILD THE PIPELINE ────────────────────────────────────────────────────
# The constructor also expects safety_checker and feature_extractor; pass None
# for both and disable the checker explicitly since none are needed for training.
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to("cuda")

# ─── 5) APPLY LoRA ────────────────────────────────────────────────────────────
# Target the attention projection layers of the U-Net; task_type is left unset
# because a diffusion U-Net is not a causal language model.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    bias="none",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
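# Only the injected LoRA matrices are now trainable; the base U-Net weights stay
# frozen. PEFT can report the resulting parameter counts:
pipe.unet.print_trainable_parameters()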

# ─── 6) TRAINING LOOP (SIMULATED) ─────────────────────────────────────────────
print(f"Data at {DATA_DIR}")
for step in range(100):
    # … your real data loading + optimizer here …
    print(f"Training step {step + 1}/100")

# ─── 7) SAVE THE FINE-TUNED LoRA ──────────────────────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Save only the LoRA adapter weights (the PEFT-wrapped U-Net), not the full pipeline.
pipe.unet.save_pretrained(OUTPUT_DIR)
print("Done! Saved to", OUTPUT_DIR)