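# script.py: downloads a dataset repo from the Hub, prepares it for ai-toolkit,
# runs LoRA training, and tags the resulting model repo with AutoTrain metadata.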
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET = os.environ.get("HF_DATASET")
#HF_DATASET = "DevWild/autotrain-pr0b0rk"
repo_id = os.environ.get("MODEL_REPO_ID")
from huggingface_hub import snapshot_download, delete_repo, metadata_update
import uuid
import json
import yaml
import shutil
import subprocess
import sys
from typing import Optional
#from huggingface_hub import login
#HF_TOKEN = os.getenv("HF_TOKEN")
#if HF_TOKEN:
# login(token=HF_TOKEN)
#else:
# raise ValueError("HF_TOKEN environment variable not found!")
if not HF_TOKEN:
    raise ValueError("Missing HF_TOKEN")
if not HF_DATASET:
    raise ValueError("Missing HF_DATASET")
if not repo_id:
    raise ValueError("Missing MODEL_REPO_ID")
# Prevent running script.py twice
LOCKFILE = "/tmp/.script_lock"
if os.path.exists(LOCKFILE):
    print("Script already ran once - skipping.")
    exit(0)
with open(LOCKFILE, "w") as f:
    f.write("lock")
print("Running script for the first time")
# START logging
print("ENV DEBUG START")
print("HF_TOKEN present?", bool(HF_TOKEN))
print("HF_DATASET:", HF_DATASET)
print("MODEL_REPO_ID:", repo_id)
print("ENV DEBUG END")
#dataset_dir = snapshot_download(HF_DATASET, token=HF_TOKEN)
def download_dataset(hf_dataset_path: str):
    """Download the training dataset snapshot into a unique /tmp directory."""
    random_id = str(uuid.uuid4())
    snapshot_download(
        repo_id=hf_dataset_path,
        token=HF_TOKEN,
        local_dir=f"/tmp/{random_id}",
        repo_type="dataset",
    )
    return f"/tmp/{random_id}"
def process_dataset(dataset_dir: str):
    # dataset dir consists of images, config.yaml and a metadata.jsonl (optional) with fields: file_name, prompt
    # generate .txt files with the same name as the images with the prompt as the content
    # remove metadata.jsonl
    # return the path to the processed dataset
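    # Example metadata.jsonl line (hypothetical values):
    #   {"file_name": "img_001.jpg", "prompt": "photo of sks style"}
    # This produces img_001.txt containing the prompt next to the image.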
    # check if config.yaml exists
    if not os.path.exists(os.path.join(dataset_dir, "config.yaml")):
        raise ValueError("config.yaml does not exist")
    # check if metadata.jsonl exists
    if os.path.exists(os.path.join(dataset_dir, "metadata.jsonl")):
        metadata = []
        with open(os.path.join(dataset_dir, "metadata.jsonl"), "r") as f:
            for line in f:
                if len(line.strip()) > 0:
                    metadata.append(json.loads(line))
        for item in metadata:
            txt_path = os.path.join(dataset_dir, item["file_name"])
            txt_path = txt_path.rsplit(".", 1)[0] + ".txt"
            with open(txt_path, "w") as f:
                f.write(item["prompt"])
        # remove metadata.jsonl
        os.remove(os.path.join(dataset_dir, "metadata.jsonl"))
    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)
    # update config with new dataset
    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir
    with open(os.path.join(dataset_dir, "config.yaml"), "w") as f:
        yaml.dump(config, f)
    return dataset_dir
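# For reference, config.yaml is assumed to follow ai-toolkit's nested layout
# (hypothetical values shown):
#
#   config:
#     process:
#       - datasets:
#           - folder_path: /tmp/<uuid>        # rewritten by process_dataset()
#         save:
#           hf_repo_id: your-user/your-lora   # overridden by run_training()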
def run_training(hf_dataset_path: str):
    dataset_dir = download_dataset(hf_dataset_path)
    dataset_dir = process_dataset(dataset_dir)
    # Force repo_id override in config.yaml
    config_path = os.path.join(dataset_dir, "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    config["config"]["process"][0]["save"]["hf_repo_id"] = repo_id
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    print("Updated config.yaml with MODEL_REPO_ID:", repo_id)
    # run training
    if not os.path.exists("ai-toolkit"):
        commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
        subprocess.run(commands, shell=True)
        # drop git metadata so the toolkit checkout is treated as plain files
        shutil.rmtree(os.path.join("ai-toolkit", ".git"), ignore_errors=True)
        gitmodules_path = os.path.join("ai-toolkit", ".gitmodules")
        if os.path.exists(gitmodules_path):
            os.remove(gitmodules_path)
    # patch_ai_toolkit_typing()
    commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
    process = subprocess.Popen(
        commands,
        shell=True,
        cwd="ai-toolkit",
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    # Stream logs to Space output
    for line in process.stdout:
        sys.stdout.write(line)
        sys.stdout.flush()
    return process, dataset_dir
#def patch_ai_toolkit_typing():
#    config_path = "ai-toolkit/toolkit/config_modules.py"
#    if os.path.exists(config_path):
#        with open(config_path, "r") as f:
#            content = f.read()
#        content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
#        with open(config_path, "w") as f:
#            f.write(content)
#        print("Patched ai-toolkit typing: torch.Tensor | None -> Optional[torch.Tensor]")
#    else:
#        print("Could not patch config_modules.py - file not found")
if __name__ == "__main__":
    try:
        process, dataset_dir = run_training(HF_DATASET)
        # process.wait() # Wait for the training process to finish
        exit_code = process.wait()
        print("Training finished with exit code:", exit_code)
        if exit_code != 0:
            raise RuntimeError(f"Training failed with exit code {exit_code}")
        with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
            config = yaml.safe_load(f)
        #repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
        #repo_id = os.environ.get("MODEL_REPO_ID")
        #repo_id = os.getenv("MODEL_REPO_ID")
        #repo_id = "DevWild/suppab0rk"
        metadata = {
            "tags": [
                "autotrain",
                "spacerunner",
                "text-to-image",
                "flux",
                "lora",
                "diffusers",
                "template:sd-lora",
            ]
        }
        metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
    finally:
        #delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
        print("SCRIPT FINISHED, DATASET SHOULD BE DELETED")