Spaces:
Running
Running
Update script.py
Browse files
script.py
CHANGED
@@ -1,9 +1,20 @@
|
|
1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from huggingface_hub import snapshot_download, delete_repo, metadata_update
|
3 |
import uuid
|
4 |
import json
|
5 |
import yaml
|
6 |
import subprocess
|
|
|
|
|
7 |
|
8 |
#from huggingface_hub import login
|
9 |
#HF_TOKEN = os.getenv("HF_TOKEN")
|
@@ -12,9 +23,36 @@ import subprocess
|
|
12 |
#else:
|
13 |
# raise ValueError("HF_TOKEN environment variable not found!")
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
def download_dataset(hf_dataset_path: str):
|
@@ -70,39 +108,80 @@ def run_training(hf_dataset_path: str):
|
|
70 |
|
71 |
dataset_dir = download_dataset(hf_dataset_path)
|
72 |
dataset_dir = process_dataset(dataset_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
# run training
|
75 |
if not os.path.exists("ai-toolkit"):
|
76 |
-
commands = "git clone https://github.com/
|
77 |
subprocess.run(commands, shell=True)
|
78 |
|
|
|
79 |
commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
|
80 |
-
process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ)
|
81 |
|
|
|
|
|
|
|
|
|
|
|
82 |
return process, dataset_dir
|
83 |
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
|
86 |
-
process, dataset_dir = run_training(HF_DATASET)
|
87 |
-
process.wait() # Wait for the training process to finish
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
"
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
os.environ["HF_HOME"] = "/tmp/huggingface"
|
3 |
+
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
|
4 |
+
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
|
5 |
+
|
6 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
7 |
+
#HF_DATASET = os.environ.get("HF_DATASET")
|
8 |
+
HF_DATASET = "DevWild/autotrain-pr0b0rk"
|
9 |
+
repo_id = os.environ.get("MODEL_REPO_ID")
|
10 |
+
|
11 |
from huggingface_hub import snapshot_download, delete_repo, metadata_update
|
12 |
import uuid
|
13 |
import json
|
14 |
import yaml
|
15 |
import subprocess
|
16 |
+
import sys
|
17 |
+
from typing import Optional
|
18 |
|
19 |
#from huggingface_hub import login
|
20 |
#HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
23 |
#else:
|
24 |
# raise ValueError("HF_TOKEN environment variable not found!")
|
25 |
|
26 |
+
|
27 |
+
if not HF_TOKEN:
|
28 |
+
raise ValueError("Missing HF_TOKEN")
|
29 |
+
|
30 |
+
if not HF_DATASET:
|
31 |
+
raise ValueError("Missing HF_DATASET")
|
32 |
+
|
33 |
+
if not repo_id:
|
34 |
+
raise ValueError("Missing MODEL_REPO_ID")
|
35 |
+
|
36 |
+
# Prevent running script.py twice
|
37 |
+
LOCKFILE = "/tmp/.script_lock"
|
38 |
+
if os.path.exists(LOCKFILE):
|
39 |
+
print("π Script already ran once β skipping.")
|
40 |
+
exit(0)
|
41 |
+
|
42 |
+
with open(LOCKFILE, "w") as f:
|
43 |
+
f.write("lock")
|
44 |
+
|
45 |
+
print("π Running script for the first time")
|
46 |
+
|
47 |
+
# START logging
|
48 |
+
print("π ENV DEBUG START")
|
49 |
+
print("HF_TOKEN present?", bool(HF_TOKEN))
|
50 |
+
print("HF_DATASET:", HF_DATASET)
|
51 |
+
print("MODEL_REPO_ID:", repo_id)
|
52 |
+
print("π ENV DEBUG END")
|
53 |
+
|
54 |
+
|
55 |
+
#dataset_dir = snapshot_download(HF_DATASET, token=HF_TOKEN)
|
56 |
|
57 |
|
58 |
def download_dataset(hf_dataset_path: str):
|
|
|
108 |
|
109 |
dataset_dir = download_dataset(hf_dataset_path)
|
110 |
dataset_dir = process_dataset(dataset_dir)
|
111 |
+
# Force repo_id override in config.yaml
|
112 |
+
config_path = os.path.join(dataset_dir, "config.yaml")
|
113 |
+
with open(config_path, "r") as f:
|
114 |
+
config = yaml.safe_load(f)
|
115 |
+
|
116 |
+
config["config"]["process"][0]["save"]["hf_repo_id"] = repo_id
|
117 |
+
|
118 |
+
with open(config_path, "w") as f:
|
119 |
+
yaml.dump(config, f)
|
120 |
+
|
121 |
+
print("β
Updated config.yaml with MODEL_REPO_ID:", repo_id)
|
122 |
+
|
123 |
|
124 |
# run training
|
125 |
if not os.path.exists("ai-toolkit"):
|
126 |
+
commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
|
127 |
subprocess.run(commands, shell=True)
|
128 |
|
129 |
+
patch_ai_toolkit_typing()
|
130 |
commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
|
131 |
+
process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,)
|
132 |
|
133 |
+
# Stream logs to Space output
|
134 |
+
for line in process.stdout:
|
135 |
+
sys.stdout.write(line)
|
136 |
+
sys.stdout.flush()
|
137 |
+
|
138 |
return process, dataset_dir
|
139 |
|
140 |
+
def patch_ai_toolkit_typing():
|
141 |
+
config_path = "ai-toolkit/toolkit/config_modules.py"
|
142 |
+
if os.path.exists(config_path):
|
143 |
+
with open(config_path, "r") as f:
|
144 |
+
content = f.read()
|
145 |
|
146 |
+
content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
|
|
|
|
|
147 |
|
148 |
+
with open(config_path, "w") as f:
|
149 |
+
f.write(content)
|
150 |
+
print("β
Patched ai-toolkit typing for torch.Tensor | None β Optional[torch.Tensor]")
|
151 |
+
else:
|
152 |
+
print("β οΈ Could not patch config_modules.py β file not found")
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
try:
|
157 |
+
process, dataset_dir = run_training(HF_DATASET)
|
158 |
+
# process.wait() # Wait for the training process to finish
|
159 |
+
exit_code = process.wait()
|
160 |
+
print("Training finished with exit code:", exit_code)
|
161 |
+
|
162 |
+
if exit_code != 0:
|
163 |
+
raise RuntimeError(f"Training failed with exit code {exit_code}")
|
164 |
+
|
165 |
+
with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
|
166 |
+
config = yaml.safe_load(f)
|
167 |
+
#repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
|
168 |
+
#repo_id = os.environ.get("MODEL_REPO_ID")
|
169 |
+
#repo_id = os.getenv("MODEL_REPO_ID")
|
170 |
+
#repo_id = "DevWild/suppab0rk"
|
171 |
+
|
172 |
+
metadata = {
|
173 |
+
"tags": [
|
174 |
+
"autotrain",
|
175 |
+
"spacerunner",
|
176 |
+
"text-to-image",
|
177 |
+
"flux",
|
178 |
+
"lora",
|
179 |
+
"diffusers",
|
180 |
+
"template:sd-lora",
|
181 |
+
]
|
182 |
+
}
|
183 |
+
metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
|
184 |
+
|
185 |
+
finally:
|
186 |
+
#delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
|
187 |
+
print("SCRIPT FINISHED, DATASET SHOULD BE DELETED")
|