'''
ostris/ai-toolkit on https://modal.com
This module provides the Modal app and main function for training FLUX LoRA models.
The main() function is meant to be called from hf_ui.py, not run directly.
'''

import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import sys

import modal
from dotenv import load_dotenv

# Load the .env file if it exists
load_dotenv()

import yaml
import traceback
import zipfile

sys.path.insert(0, "/root/ai-toolkit")  # must come before ANY torch or fastai imports
# import toolkit.cuda_malloc

# turn off diffusers telemetry until I can figure out how to make it opt-in
os.environ['DISABLE_TELEMETRY'] = 'YES'

# Declare the Hugging Face secret
hf_secret = modal.Secret.from_name("huggingface-secret")

# define the volume for storing model outputs, using "creating volumes lazily": https://modal.com/docs/guide/volumes
# you will find your model, samples and optimizer stored in: https://modal.com/storage/your-username/main/flux-lora-models
model_volume = modal.Volume.from_name("flux-lora-models", create_if_missing=True)

# modal_output, due to the "cannot mount volume on non-empty path" requirement
MOUNT_DIR = "/root/ai-toolkit/modal_output"

# define the container image for the app
image = (
    modal.Image.debian_slim(python_version="3.11")
    # install required system and pip packages, more about this modal approach: https://modal.com/docs/examples/dreambooth_app
    .apt_install("libgl1", "libglib2.0-0")
    .pip_install(
        "python-dotenv",
        "torch",
        "diffusers[torch]",
        "transformers",
        "ftfy",
        "torchvision",
        "oyaml",
        "opencv-python",
        "albumentations",
        "safetensors",
        "lycoris-lora==1.8.3",
        "flatten_json",
        "pyyaml",
        "tensorboard",
        "kornia",
        "invisible-watermark",
        "einops",
        "accelerate",
        "toml",
        "pydantic",
        "omegaconf",
        "k-diffusion",
        "open_clip_torch",
        "timm",
        "prodigyopt",
        "controlnet_aux==0.0.7",
        "bitsandbytes",
        "hf_transfer",
        "lpips",
        "pytorch_fid",
        "optimum-quanto",
        "sentencepiece",
        "huggingface_hub",
        "peft",
        "wandb",
    )
)

# Mount the code from the HF Space root directory
code_mount = modal.Mount.from_local_dir(
    local_path="/home/user/app",
    remote_path="/root/ai-toolkit",
)

# create the Modal app with the necessary mounts and volumes
app = modal.App(
    name="flux-lora-training",
    image=image,
    mounts=[code_mount],
    volumes={MOUNT_DIR: model_volume},
)

# Check if we have DEBUG_TOOLKIT in env
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
    # Enable torch anomaly detection for easier debugging
    import torch
    torch.autograd.set_detect_anomaly(True)

from toolkit.job import get_job


def print_end_message(jobs_completed, jobs_failed):
    failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
    completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"

    print("")
    print("========================================")
    print("Result:")
    if len(completed_string) > 0:
        print(f" - {completed_string}")
    if len(failure_string) > 0:
        print(f" - {failure_string}")
    print("========================================")


@app.function(
    # request a GPU with at least 24GB VRAM
    # more about modal GPUs: https://modal.com/docs/guide/gpu
    gpu="A100",  # gpu="H100"
    # more about modal timeouts: https://modal.com/docs/guide/timeouts
    timeout=7200,  # 2 hours, increase or decrease if needed
    secrets=[hf_secret],
)
def main(config_file_list_str: str, recover: bool = False, name: str = None):
    # Secrets are automatically injected into environment variables,
    # e.g. os.environ["HF_TOKEN"] and os.environ["WANDB_API_KEY"]
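    # For reference, a minimal config shape that passes the validation below
    # (field values are illustrative assumptions, not defaults):
    #
    #   config:
    #     name: my_flux_lora
    #     process:
    #       - type: sd_trainer
    #         training_folder: /root/ai-toolkit/modal_output
    #         datasets:
    #           - folder_path: /root/ai-toolkit/dataset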
    # parse the config string into a usable dict; despite the parameter name,
    # config_file_list_str carries a single YAML document, not a comma-separated
    # list of file paths
    config = None
    try:
        config = yaml.safe_load(config_file_list_str)
    except Exception as e:
        print(f"Error loading config file: {e}")
        traceback.print_exc()
        raise e

    jobs_completed = 0
    jobs_failed = 0

    print(f"Running {config['config']['name']}")

    try:
        # 1. validate the config to make sure required keys are present
        if 'config' not in config:
            raise ValueError("config file must have a `config` section")
        if 'process' not in config['config']:
            raise ValueError("config file must have a `process` section")
        if len(config['config']['process']) == 0:
            raise ValueError("config file must have at least one process")
        if 'type' not in config['config']['process'][0]:
            raise ValueError("config file process must have a `type`")
        if 'training_folder' not in config['config']['process'][0]:
            raise ValueError("config file process must have a `training_folder`")
        if not config['config']['process'][0]['training_folder'].startswith("/root/ai-toolkit"):
            raise ValueError("config file process training_folder path must start with /root/ai-toolkit")

        # find any datasets inside the process object and validate their paths
        datasets = config['config']['process'][0].get('datasets', None)
        if datasets is not None and isinstance(datasets, list):
            for dataset in datasets:
                if 'folder_path' in dataset:
                    if not dataset['folder_path'].startswith('/root/ai-toolkit'):
                        raise ValueError("config file process dataset folder_path must start with /root/ai-toolkit")

        job = get_job(config, name)

        # redirect training outputs to the mounted volume so they persist
        job.config['process'][0]['training_folder'] = MOUNT_DIR
        os.makedirs(MOUNT_DIR, exist_ok=True)
        print(f"Training outputs will be saved to: {MOUNT_DIR}")

        # handle zipped datasets
        datasets = config['config']['process'][0].get('datasets', None)
        if datasets is not None and isinstance(datasets, list):
            for dataset in datasets:
                dataset_path = dataset.get('folder_path', None)
                if dataset_path is not None:
                    # check whether the folder contains a zip file
                    for file in os.listdir(dataset_path):
                        if file.lower().endswith('.zip'):
                            zip_path = os.path.join(dataset_path, file)
                            # create a subfolder to extract into
                            extract_path = os.path.join(dataset_path, 'extracted')
                            os.makedirs(extract_path, exist_ok=True)

                            print(f"Extracting dataset zip file: {zip_path}")
                            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                                zip_ref.extractall(extract_path)

                            # point the dataset at the extracted folder
                            dataset['folder_path'] = extract_path
                            # remove the zip file after extraction
                            os.remove(zip_path)
                            print(f"Dataset extracted to: {extract_path}")
                            break

        # run the job
        job.run()

        # commit the volume after training so outputs are persisted
        model_volume.commit()

        job.cleanup()
        jobs_completed += 1
    except Exception as e:
        print(f"Error running job: {e}")
        # some exceptions (e.g. from requests) carry an HTTP response
        if 'response' in e.__dict__:
            print(f" - Response code: {e.response.status_code} text: {e.response.text}")
        jobs_failed += 1
        traceback.print_exc()
        if not recover:
            print_end_message(jobs_completed, jobs_failed)
            raise e

    print_end_message(jobs_completed, jobs_failed)
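
# Sketch of how hf_ui.py might invoke this function. This is an illustration
# under assumptions, not a verbatim excerpt: the module name `run_modal` and
# the `config_dict` variable are hypothetical.
#
#   import yaml
#   from run_modal import app, main
#
#   config_yaml = yaml.dump(config_dict)  # config_dict built by the UI form
#   with app.run():
#       main.remote(config_yaml, name=config_dict["config"]["name"])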