DevWild committed
Commit b1ec5b4 · verified · 1 Parent(s): 7c7f4b7

Update script.py

Files changed (1)
  1. script.py +107 -28
script.py CHANGED
@@ -1,9 +1,20 @@
 import os
+os.environ["HF_HOME"] = "/tmp/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
+os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
+
+HF_TOKEN = os.environ.get("HF_TOKEN")
+#HF_DATASET = os.environ.get("HF_DATASET")
+HF_DATASET = "DevWild/autotrain-pr0b0rk"
+repo_id = os.environ.get("MODEL_REPO_ID")
+
 from huggingface_hub import snapshot_download, delete_repo, metadata_update
 import uuid
 import json
 import yaml
 import subprocess
+import sys
+from typing import Optional
 
 #from huggingface_hub import login
 #HF_TOKEN = os.getenv("HF_TOKEN")
@@ -12,9 +23,36 @@ import subprocess
 #else:
 #    raise ValueError("HF_TOKEN environment variable not found!")
 
-HF_TOKEN = os.environ.get("HF_TOKEN")
-HF_DATASET = os.environ.get("HF_DATASET")
-dataset_dir = snapshot_download(HF_DATASET)
+
+if not HF_TOKEN:
+    raise ValueError("Missing HF_TOKEN")
+
+if not HF_DATASET:
+    raise ValueError("Missing HF_DATASET")
+
+if not repo_id:
+    raise ValueError("Missing MODEL_REPO_ID")
+
+# Prevent running script.py twice
+LOCKFILE = "/tmp/.script_lock"
+if os.path.exists(LOCKFILE):
+    print("🔁 Script already ran once — skipping.")
+    exit(0)
+
+with open(LOCKFILE, "w") as f:
+    f.write("lock")
+
+print("🚀 Running script for the first time")
+
+# START logging
+print("🚀 ENV DEBUG START")
+print("HF_TOKEN present?", bool(HF_TOKEN))
+print("HF_DATASET:", HF_DATASET)
+print("MODEL_REPO_ID:", repo_id)
+print("🚀 ENV DEBUG END")
+
+
+#dataset_dir = snapshot_download(HF_DATASET, token=HF_TOKEN)
 
 
 def download_dataset(hf_dataset_path: str):
@@ -70,39 +108,80 @@ def run_training(hf_dataset_path: str):
 
     dataset_dir = download_dataset(hf_dataset_path)
     dataset_dir = process_dataset(dataset_dir)
+    # Force repo_id override in config.yaml
+    config_path = os.path.join(dataset_dir, "config.yaml")
+    with open(config_path, "r") as f:
+        config = yaml.safe_load(f)
+
+    config["config"]["process"][0]["save"]["hf_repo_id"] = repo_id
+
+    with open(config_path, "w") as f:
+        yaml.dump(config, f)
+
+    print("✅ Updated config.yaml with MODEL_REPO_ID:", repo_id)
+
 
     # run training
     if not os.path.exists("ai-toolkit"):
-        commands = "git clone https://github.com/ostris/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
+        commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
         subprocess.run(commands, shell=True)
 
+    patch_ai_toolkit_typing()
     commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
-    process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ)
+    process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
 
+    # Stream logs to Space output
+    for line in process.stdout:
+        sys.stdout.write(line)
+        sys.stdout.flush()
+
     return process, dataset_dir
 
-
-if __name__ == "__main__":
-    process, dataset_dir = run_training(HF_DATASET)
-    process.wait()  # Wait for the training process to finish
-
-    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
-        config = yaml.safe_load(f)
-    #repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
-    #repo_id = os.environ.get("MODEL_REPO_ID")
-    repo_id = os.getenv("MODEL_REPO_ID")
-    #MODEL_REPO_ID = os.getenv("MODEL_REPO_ID")
-
-    metadata = {
-        "tags": [
-            "autotrain",
-            "spacerunner",
-            "text-to-image",
-            "flux",
-            "lora",
-            "diffusers",
-            "template:sd-lora",
-        ]
-    }
-    metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
-    delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
+def patch_ai_toolkit_typing():
+    config_path = "ai-toolkit/toolkit/config_modules.py"
+    if os.path.exists(config_path):
+        with open(config_path, "r") as f:
+            content = f.read()
+
+        content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")
+
+        with open(config_path, "w") as f:
+            f.write(content)
+        print("✅ Patched ai-toolkit typing for torch.Tensor | None → Optional[torch.Tensor]")
+    else:
+        print("⚠️ Could not patch config_modules.py — file not found")
+
+
+if __name__ == "__main__":
+    try:
+        process, dataset_dir = run_training(HF_DATASET)
+        # process.wait()  # Wait for the training process to finish
+        exit_code = process.wait()
+        print("Training finished with exit code:", exit_code)
+
+        if exit_code != 0:
+            raise RuntimeError(f"Training failed with exit code {exit_code}")
+
+        with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
+            config = yaml.safe_load(f)
+        #repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
+        #repo_id = os.environ.get("MODEL_REPO_ID")
+        #repo_id = os.getenv("MODEL_REPO_ID")
+        #repo_id = "DevWild/suppab0rk"
+
+        metadata = {
+            "tags": [
+                "autotrain",
+                "spacerunner",
+                "text-to-image",
+                "flux",
+                "lora",
+                "diffusers",
+                "template:sd-lora",
+            ]
+        }
+        metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
+
+    finally:
+        #delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
+        print("SCRIPT FINISHED, DATASET SHOULD BE DELETED")
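Note on the run-once guard added in this commit: checking os.path.exists(LOCKFILE) and then creating the file are two separate steps, so two processes starting at the same moment could both pass the check. A minimal sketch of an atomic alternative using os.open with O_CREAT | O_EXCL (same lockfile path as the script; the atomic-open variant is a suggestion, not what the commit ships):

import os
import sys

LOCKFILE = "/tmp/.script_lock"

try:
    # O_CREAT | O_EXCL makes check-and-create a single atomic syscall:
    # it raises FileExistsError if the lockfile is already there.
    fd = os.open(LOCKFILE, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
except FileExistsError:
    print("Script already ran once; skipping.")
    sys.exit(0)

with os.fdopen(fd, "w") as f:
    f.write("lock")

print("Running script for the first time")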