# LoRa_Streamlit / train.py
# Author: ramimu · commit aff7e63 ("Update train.py")
import os

import torch
from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
from transformers import CLIPTextModel, CLIPTokenizer
# Hugging Face repo to fine-tune, and where the trained LoRA adapter is written.
MODEL_ID = "black-forest-labs/FLUX.1-dev"
# ``output_dir`` was used at the end of the script but never defined (NameError).
output_dir = "./fluxdev-lora"

# 1) Grab the model locally, exactly once. The original downloaded it twice and
#    built a throwaway pipeline in between. With the default ``token=None`` the
#    token from ``huggingface-cli login`` / HF_TOKEN is used (the original passed
#    ``use_auth_token=True``, the deprecated spelling, because the repo needs auth).
print("📥 Downloading Flux‑Dev model…")
model_path = snapshot_download(MODEL_ID, local_dir="./fluxdev-model")

# 2) Load the pipeline with the classes FLUX actually uses. FLUX.1-dev is not a
#    Stable Diffusion checkpoint:
#      - it has a ``transformer`` instead of a ``unet``;
#      - it has a second (T5) text encoder/tokenizer;
#      - it uses a flow-matching scheduler.
#    So StableDiffusionPipeline / UNet2DConditionModel / DPMSolverMultistepScheduler
#    cannot load it. FluxPipeline reads model_index.json and builds every component
#    with its correct class. bfloat16 follows the FLUX.1-dev model card.
print("🔄 Loading Flux pipeline…")
pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

# Freeze every base weight so only the LoRA parameters added below can train.
for component in (pipe.transformer, pipe.vae, pipe.text_encoder, pipe.text_encoder_2):
    component.requires_grad_(False)

# 3) Apply LoRA to the transformer's attention projections.
#    - PEFT cannot infer target modules for a diffusers model, so they are named
#      explicitly.
#    - task_type="CAUSAL_LM" would wrap the model as a language model, so
#      task_type is omitted.
#    - ``add_adapter`` injects the LoRA layers in place, so ``pipe.transformer``
#      keeps the type the pipeline expects (no PeftModel wrapper).
print("🧠 Applying LoRA…")
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
pipe.transformer.add_adapter(lora_config)
trainable = sum(p.numel() for p in pipe.transformer.parameters() if p.requires_grad)
print(f"   trainable LoRA parameters: {trainable:,}")

# 4) Your training loop (or dummy loop for illustration).
num_steps = 100
print("🚀 Starting fine‑tuning…")
for step in range(num_steps):
    print(f"Training step {step + 1}/{num_steps}")
    # …insert your actual data‑loader and loss/backprop here…

# 5) Save only the LoRA adapter. pipe.save_pretrained would write the entire
#    base pipeline rather than the LoRA weights the message below promises.
os.makedirs(output_dir, exist_ok=True)
FluxPipeline.save_lora_weights(
    output_dir,
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
)
print("✅ Done. LoRA weights in", output_dir)