|
'''

ostris/ai-toolkit on https://modal.com

This module provides the Modal app and main function for training FLUX LoRA models.

The main() function is meant to be called from hf_ui.py, not run directly.

'''

import os

# Enable hf_transfer's accelerated downloads. This must be set before any
# Hugging Face library is imported, which is why it sits above the rest of
# the imports.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import sys

import modal

from dotenv import load_dotenv

# Pull environment variables (e.g. HF credentials) from a local .env file,
# if one exists; a missing file is silently ignored by load_dotenv().
load_dotenv()

import yaml

import traceback

import zipfile

# Make the ai-toolkit source tree (mounted into the container at this path)
# importable; `from toolkit.job import get_job` below relies on this.
sys.path.insert(0, "/root/ai-toolkit")
|
|
|
|
|
|
|
|
|
# Opt out of library telemetry collection.
os.environ['DISABLE_TELEMETRY'] = 'YES'

# Hugging Face credentials, stored as a Modal secret named "huggingface-secret";
# injected into the training function via `secrets=[hf_secret]` below.
hf_secret = modal.Secret.from_name("huggingface-secret")

# Persistent Modal volume that holds trained LoRA models across runs.
model_volume = modal.Volume.from_name("flux-lora-models", create_if_missing=True)

# Container path where `model_volume` is mounted. Training output written here
# survives the run once model_volume.commit() is called.
MOUNT_DIR = "/root/ai-toolkit/modal_output"

# Training image: Debian slim + Python 3.11 plus all ai-toolkit dependencies.
image = (
    modal.Image.debian_slim(python_version="3.11")
    # System libraries required by opencv-python at import time.
    .apt_install("libgl1", "libglib2.0-0")
    .pip_install(
        "python-dotenv",
        "torch",
        "diffusers[torch]",
        "transformers",
        "ftfy",
        "torchvision",
        "oyaml",
        "opencv-python",
        "albumentations",
        "safetensors",
        "lycoris-lora==1.8.3",
        "flatten_json",
        "pyyaml",
        "tensorboard",
        "kornia",
        "invisible-watermark",
        "einops",
        "accelerate",
        "toml",
        "pydantic",
        "omegaconf",
        "k-diffusion",
        "open_clip_torch",
        "timm",
        "prodigyopt",
        "controlnet_aux==0.0.7",
        "bitsandbytes",
        "hf_transfer",
        "lpips",
        "pytorch_fid",
        "optimum-quanto",
        "sentencepiece",
        "huggingface_hub",
        "peft",
        "wandb",
    )
)

# Mount the local ai-toolkit checkout into the container at /root/ai-toolkit.
# NOTE(review): local_path is hard-coded; presumably the app is always launched
# from /home/user/app — confirm against the deployment environment.
code_mount = modal.Mount.from_local_dir(
    local_path="/home/user/app",
    remote_path="/root/ai-toolkit"
)

# The Modal app: every function container gets the image, the mounted source
# tree, and the model volume mounted at MOUNT_DIR.
app = modal.App(name="flux-lora-training", image=image, mounts=[code_mount], volumes={MOUNT_DIR: model_volume})

# Opt-in debugging: DEBUG_TOOLKIT=1 turns on autograd anomaly detection,
# which makes backward passes slower but pinpoints NaN/inf sources.
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
    import torch
    torch.autograd.set_detect_anomaly(True)

import argparse
# Deliberately imported after sys.path was extended with /root/ai-toolkit:
# the `toolkit` package lives inside the mounted source tree.
from toolkit.job import get_job
|
|
|
def print_end_message(jobs_completed: int, jobs_failed: int) -> None:
    """Print a summary banner with completed/failed job counts.

    Args:
        jobs_completed: number of jobs that finished successfully.
        jobs_failed: number of jobs that raised; the failure line is
            printed only when this is greater than zero.
    """
    failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
    completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"

    print("")
    print("========================================")
    print("Result:")
    # completed_string is always non-empty (even "0 completed jobs"), so the
    # original `if len(...) > 0` guard around it was dead code and is removed.
    print(f" - {completed_string}")
    if failure_string:
        print(f" - {failure_string}")
    print("========================================")
|
|
|
|
|
def _validate_config(config: dict) -> None:
    """Validate the minimal structure of a parsed training config.

    Raises:
        ValueError: describing the first missing or invalid field.
    """
    if 'config' not in config:
        raise ValueError("config file must have a `config` section")
    if 'process' not in config['config']:
        raise ValueError("config file must have a `process` section")
    if len(config['config']['process']) == 0:
        raise ValueError("config file must have at least one process")
    process = config['config']['process'][0]
    if 'type' not in process:
        raise ValueError("config file process must have a `type`")
    if 'training_folder' not in process:
        raise ValueError("config file process must have a `training_folder`")
    if not process['training_folder'].startswith("/root/ai-toolkit"):
        raise ValueError("config file process training_folder path must start with /root/ai-toolkit")

    # Dataset folders must also live under the mounted toolkit tree.
    datasets = process.get('datasets', None)
    if datasets is not None and isinstance(datasets, list):
        for dataset in datasets:
            if 'folder_path' in dataset:
                if not dataset['folder_path'].startswith('/root/ai-toolkit'):
                    raise ValueError("config file process dataset folder_path must start with /root/ai-toolkit")


def _extract_zip_datasets(config: dict) -> None:
    """Unpack uploaded dataset archives in place.

    For each dataset folder, extracts the first ``.zip`` file found into an
    ``extracted/`` subfolder, repoints the dataset's ``folder_path`` at that
    subfolder, and deletes the archive. Mutates ``config`` in place.
    """
    datasets = config['config']['process'][0].get('datasets', None)
    if datasets is None or not isinstance(datasets, list):
        return
    for dataset in datasets:
        dataset_path = dataset.get('folder_path', None)
        if dataset_path is None:
            continue
        for file in os.listdir(dataset_path):
            if not file.lower().endswith('.zip'):
                continue
            zip_path = os.path.join(dataset_path, file)
            extract_path = os.path.join(dataset_path, 'extracted')
            os.makedirs(extract_path, exist_ok=True)

            print(f"Extracting dataset zip file: {zip_path}")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)

            dataset['folder_path'] = extract_path
            os.remove(zip_path)
            print(f"Dataset extracted to: {extract_path}")
            # Only the first archive per dataset folder is processed
            # (matches the original behavior).
            break


@app.function(
    gpu="A100",
    timeout=7200,
    secrets=[hf_secret]
)
def main(config_file_list_str: str, recover: bool = False, name: str = None):
    """Run a FLUX LoRA training job inside a Modal GPU container.

    Args:
        config_file_list_str: YAML text of a single training config.
        recover: when True, a job failure is logged and counted instead of
            re-raised, so the end-of-run summary is still printed.
        name: optional job name forwarded to ``get_job``.

    Raises:
        Exception: re-raised from YAML parsing, or from the job itself
            when ``recover`` is False.
    """
    try:
        config = yaml.safe_load(config_file_list_str)
    except Exception as e:
        print(f"Error loading config file: {e}")
        traceback.print_exc()
        raise  # bare raise preserves the original traceback

    jobs_completed = 0
    jobs_failed = 0

    try:
        # Validate BEFORE reading config['config']['name']: previously the
        # name was printed before (and outside) any validation, so a
        # malformed config raised a raw, uncaught KeyError instead of the
        # intended descriptive ValueError.
        _validate_config(config)
        print(f"Running {config['config']['name']}")

        job = get_job(config, name)

        # Force output into the mounted volume so results persist.
        job.config['process'][0]['training_folder'] = MOUNT_DIR
        os.makedirs(MOUNT_DIR, exist_ok=True)
        print(f"Training outputs will be saved to: {MOUNT_DIR}")

        _extract_zip_datasets(config)

        job.run()

        # Persist training outputs written under MOUNT_DIR to the volume.
        model_volume.commit()

        job.cleanup()
        jobs_completed += 1

    except Exception as e:
        print(f"Error running job: {e}")
        # Some exceptions (e.g. from HTTP clients) carry a response object;
        # surface its status and body for debugging.
        if 'response' in e.__dict__:
            print(f" - Response code: {e.response.status_code} text: {e.response.text}")
        jobs_failed += 1
        traceback.print_exc()
        if not recover:
            print_end_message(jobs_completed, jobs_failed)
            raise

    print_end_message(jobs_completed, jobs_failed)