import os
import torch
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
    UNet2DConditionModel
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model
MODEL_ID = "black-forest-labs/FLUX.1-dev"

# download (FLUX.1-dev is a gated repo: log in with `huggingface-cli login`
# or set HF_TOKEN; `token=True` replaces the deprecated `use_auth_token=True`)
model_path = snapshot_download(
    MODEL_ID,
    local_dir="./fluxdev-model",
    token=True
)
# later loading
# NOTE: FLUX.1-dev is not a Stable Diffusion checkpoint, so
# StableDiffusionPipeline will not find the components it expects here;
# the FLUX-specific load is sketched just below.
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    token=True
).to("cuda")
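
# --- FLUX-specific alternative (a minimal sketch, assuming a recent diffusers
# release that ships FluxPipeline): the FLUX.1-dev checkpoint stores a Flux
# transformer plus two text encoders, and FluxPipeline loads all of them in
# one call; bfloat16 is the dtype commonly used for this model.
from diffusers import FluxPipeline

flux_pipe = FluxPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16
).to("cuda")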
# 1) grab the model locally
print("📥 Downloading Flux-Dev model…")
model_path = snapshot_download(MODEL_ID, local_dir="./fluxdev-model")
# 2) load each piece with its correct subfolder
print("🔄 Loading scheduler…")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_path, subfolder="scheduler"
)

print("🔄 Loading VAE…")
vae = AutoencoderKL.from_pretrained(
    model_path, subfolder="vae", torch_dtype=torch.float16
)

print("🔄 Loading text encoder + tokenizer…")
text_encoder = CLIPTextModel.from_pretrained(
    model_path, subfolder="text_encoder", torch_dtype=torch.float16
)
tokenizer = CLIPTokenizer.from_pretrained(
    model_path, subfolder="tokenizer"
)
print("🔄 Loading U‑Net…") | |
unet = UNet2DConditionModel.from_pretrained( | |
model_path, subfolder="unet", torch_dtype=torch.float16 | |
) | |
# 3) assemble the pipeline
print("🛠 Assembling pipeline…")
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,           # required constructor args; pass None
    feature_extractor=None,        # when you are not using them
    requires_safety_checker=False
).to("cuda")
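
# Quick smoke test (optional; the prompt and filename are arbitrary examples):
# generate one image to confirm the assembled components work together.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("smoke_test.png")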
# 4) apply LoRA
print("🧠 Applying LoRA…")
# task_type="CAUSAL_LM" is meant for causal language models, not a UNet;
# for a diffusion UNet, target the attention projections instead
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"]
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
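
# Sanity check: get_peft_model returns a PeftModel, so only the injected LoRA
# matrices should show up as trainable (a small fraction of the UNet).
pipe.unet.print_trainable_parameters()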
# 5) your training loop (or dummy loop for illustration)
print("🚀 Starting fine-tuning…")
for step in range(100):
    print(f"Training step {step+1}/100")
    # …insert your actual data-loader and loss/backprop here…
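    # For reference, a standard denoising-loss step looks roughly like this
    # (illustrative only; `batch`, `text_embeds` and `optimizer` are
    # placeholders you must supply from your own data pipeline):
    #   latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
    #   latents = latents * vae.config.scaling_factor
    #   noise = torch.randn_like(latents)
    #   t = torch.randint(0, scheduler.config.num_train_timesteps,
    #                     (latents.shape[0],), device=latents.device)
    #   noisy_latents = scheduler.add_noise(latents, noise, t)
    #   pred = pipe.unet(noisy_latents, t, encoder_hidden_states=text_embeds).sample
    #   loss = torch.nn.functional.mse_loss(pred.float(), noise.float())
    #   loss.backward(); optimizer.step(); optimizer.zero_grad()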
output_dir = "./flux-lora-output"   # example path; change to taste
os.makedirs(output_dir, exist_ok=True)
# pipe.unet is a PeftModel, so this writes just the LoRA adapter weights
pipe.unet.save_pretrained(output_dir)
print("✅ Done. LoRA weights in", output_dir)