import os

# Redirect Hugging Face caches to /tmp, which is writable inside the Space container
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET = os.environ.get("HF_DATASET")
#HF_DATASET = "DevWild/autotrain-pr0b0rk"
repo_id = os.environ.get("MODEL_REPO_ID")
from huggingface_hub import snapshot_download, delete_repo, metadata_update
import uuid
import json
import yaml
import shutil
import subprocess
import sys
from typing import Optional
#from huggingface_hub import login
#HF_TOKEN = os.getenv("HF_TOKEN")
#if HF_TOKEN:
#    login(token=HF_TOKEN)
#else:
#    raise ValueError("HF_TOKEN environment variable not found!")

if not HF_TOKEN:
    raise ValueError("Missing HF_TOKEN")
if not HF_DATASET:
    raise ValueError("Missing HF_DATASET")
if not repo_id:
    raise ValueError("Missing MODEL_REPO_ID")
# Prevent running script.py twice
LOCKFILE = "/tmp/.script_lock"
if os.path.exists(LOCKFILE):
    print("Script already ran once, skipping.")
    sys.exit(0)

with open(LOCKFILE, "w") as f:
    f.write("lock")

print("Running script for the first time")

# START logging
print("ENV DEBUG START")
print("HF_TOKEN present?", bool(HF_TOKEN))
print("HF_DATASET:", HF_DATASET)
print("MODEL_REPO_ID:", repo_id)
print("ENV DEBUG END")
#dataset_dir = snapshot_download(HF_DATASET, token=HF_TOKEN)

def download_dataset(hf_dataset_path: str):
    # Download the dataset repo into a unique temp directory and return its local path
    random_id = str(uuid.uuid4())
    snapshot_download(
        repo_id=hf_dataset_path,
        token=HF_TOKEN,
        local_dir=f"/tmp/{random_id}",
        repo_type="dataset",
    )
    return f"/tmp/{random_id}"
def process_dataset(dataset_dir: str):
    # dataset dir consists of images, config.yaml and an optional metadata.jsonl with fields: file_name, prompt
    # generate .txt files with the same name as the images, with the prompt as the content
    # remove metadata.jsonl
    # return the path to the processed dataset
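    # Example metadata.jsonl line (illustrative file name and prompt, not real data):
    #   {"file_name": "0001.jpg", "prompt": "photo of a sks dog sitting on a couch"}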
    # check if config.yaml exists
    if not os.path.exists(os.path.join(dataset_dir, "config.yaml")):
        raise ValueError("config.yaml does not exist")

    # check if metadata.jsonl exists
    if os.path.exists(os.path.join(dataset_dir, "metadata.jsonl")):
        metadata = []
        with open(os.path.join(dataset_dir, "metadata.jsonl"), "r") as f:
            for line in f:
                if len(line.strip()) > 0:
                    metadata.append(json.loads(line))
        for item in metadata:
            txt_path = os.path.join(dataset_dir, item["file_name"])
            txt_path = txt_path.rsplit(".", 1)[0] + ".txt"
            with open(txt_path, "w") as f:
                f.write(item["prompt"])
        # remove metadata.jsonl
        os.remove(os.path.join(dataset_dir, "metadata.jsonl"))

    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)
    # update config with the new dataset folder
    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir
    with open(os.path.join(dataset_dir, "config.yaml"), "w") as f:
        yaml.dump(config, f)
    return dataset_dir
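
# Sketch of the config.yaml layout this script assumes, inferred from the keys accessed
# in process_dataset and run_training (a real ai-toolkit config carries more fields):
#
#   config:
#     process:
#       - datasets:
#           - folder_path: /tmp/<uuid>        # overwritten in process_dataset
#         save:
#           hf_repo_id: <user>/<model-repo>   # overwritten in run_training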
def run_training(hf_dataset_path: str):
    dataset_dir = download_dataset(hf_dataset_path)
    dataset_dir = process_dataset(dataset_dir)

    # Force repo_id override in config.yaml
    config_path = os.path.join(dataset_dir, "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    config["config"]["process"][0]["save"]["hf_repo_id"] = repo_id
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    print("Updated config.yaml with MODEL_REPO_ID:", repo_id)

    # run training
    toolkit_src = "ai-toolkit"
    if not os.path.exists(toolkit_src):
        commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
        subprocess.run(commands, shell=True)
        # strip git metadata from the fresh clone; it is not needed for training
        shutil.rmtree(os.path.join(toolkit_src, ".git"), ignore_errors=True)
        shutil.rmtree(os.path.join(toolkit_src, ".gitmodules"), ignore_errors=True)

    # patch_ai_toolkit_typing()
    commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
    process = subprocess.Popen(
        commands,
        shell=True,
        cwd="ai-toolkit",
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    # Stream logs to the Space output
    for line in process.stdout:
        sys.stdout.write(line)
        sys.stdout.flush()
    return process, dataset_dir
#def patch_ai_toolkit_typing():
#    config_path = "ai-toolkit/toolkit/config_modules.py"
#    if os.path.exists(config_path):
#        with open(config_path, "r") as f:
#            content = f.read()
#        content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
#        with open(config_path, "w") as f:
#            f.write(content)
#        print("Patched ai-toolkit typing: torch.Tensor | None -> Optional[torch.Tensor]")
#    else:
#        print("Could not patch config_modules.py, file not found")
if __name__ == "__main__":
    try:
        process, dataset_dir = run_training(HF_DATASET)
        # process.wait()  # Wait for the training process to finish
        exit_code = process.wait()
        print("Training finished with exit code:", exit_code)
        if exit_code != 0:
            raise RuntimeError(f"Training failed with exit code {exit_code}")

        with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
            config = yaml.safe_load(f)
        #repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
        #repo_id = os.environ.get("MODEL_REPO_ID")
        #repo_id = os.getenv("MODEL_REPO_ID")
        #repo_id = "DevWild/suppab0rk"

        metadata = {
            "tags": [
                "autotrain",
                "spacerunner",
                "text-to-image",
                "flux",
                "lora",
                "diffusers",
                "template:sd-lora",
            ]
        }
        # Write the tags into the model repo's card metadata on the Hub
        metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
    finally:
        #delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
        print("SCRIPT FINISHED, DATASET SHOULD BE DELETED")