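"""Train a FLUX LoRA with ai-toolkit from a Hugging Face dataset.

Rough flow, as implemented below: download the dataset snapshot, turn the
optional metadata.jsonl captions into per-image .txt files, point the bundled
config.yaml at the local dataset folder and at MODEL_REPO_ID, run ai-toolkit's
run.py, then tag the resulting model repo.
"""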
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET = os.environ.get("HF_DATASET")
#HF_DATASET = "DevWild/autotrain-pr0b0rk"
repo_id = os.environ.get("MODEL_REPO_ID")

from huggingface_hub import snapshot_download, delete_repo, metadata_update
import uuid
import json
import shutil  # used below when cleaning up the cloned ai-toolkit repo
import yaml
import subprocess
import sys
from typing import Optional

#from huggingface_hub import login
#HF_TOKEN = os.getenv("HF_TOKEN")
#if HF_TOKEN:
#    login(token=HF_TOKEN)
#else:
#    raise ValueError("HF_TOKEN environment variable not found!")


if not HF_TOKEN:
    raise ValueError("Missing HF_TOKEN")

if not HF_DATASET:
    raise ValueError("Missing HF_DATASET")

if not repo_id:
    raise ValueError("Missing MODEL_REPO_ID")

# Prevent running script.py twice
LOCKFILE = "/tmp/.script_lock"
if os.path.exists(LOCKFILE):
    print("πŸ” Script already ran once β€” skipping.")
    exit(0)

with open(LOCKFILE, "w") as f:
    f.write("lock")

print("πŸš€ Running script for the first time")

# START logging
print("πŸš€ ENV DEBUG START")
print("HF_TOKEN present?", bool(HF_TOKEN))
print("HF_DATASET:", HF_DATASET)
print("MODEL_REPO_ID:", repo_id)
print("πŸš€ ENV DEBUG END")


#dataset_dir = snapshot_download(HF_DATASET, token=HF_TOKEN)


def download_dataset(hf_dataset_path: str):
    random_id = str(uuid.uuid4())
    snapshot_download(
        repo_id=hf_dataset_path,
        token=HF_TOKEN,
        local_dir=f"/tmp/{random_id}",
        repo_type="dataset",
    )
    return f"/tmp/{random_id}"


def process_dataset(dataset_dir: str):
    # dataset dir consists of images, config.yaml and a metadata.jsonl (optional) with fields: file_name, prompt
    # generate .txt files with the same name as the images with the prompt as the content
    # remove metadata.jsonl
    # return the path to the processed dataset
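    # Example with hypothetical filenames: a metadata.jsonl line such as
    #   {"file_name": "img_001.jpg", "prompt": "a photo of a red house"}
    # yields an img_001.txt file containing "a photo of a red house".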

    # check if config.yaml exists
    if not os.path.exists(os.path.join(dataset_dir, "config.yaml")):
        raise ValueError("config.yaml does not exist")

    # check if metadata.jsonl exists
    if os.path.exists(os.path.join(dataset_dir, "metadata.jsonl")):
        metadata = []
        with open(os.path.join(dataset_dir, "metadata.jsonl"), "r") as f:
            for line in f:
                if len(line.strip()) > 0:
                    metadata.append(json.loads(line))
        for item in metadata:
            txt_path = os.path.join(dataset_dir, item["file_name"])
            txt_path = txt_path.rsplit(".", 1)[0] + ".txt"
            with open(txt_path, "w") as f:
                f.write(item["prompt"])

        # remove metadata.jsonl
        os.remove(os.path.join(dataset_dir, "metadata.jsonl"))

    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)

    # update config with new dataset
    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir

    with open(os.path.join(dataset_dir, "config.yaml"), "w") as f:
        yaml.dump(config, f)

    return dataset_dir


def run_training(hf_dataset_path: str):

    dataset_dir = download_dataset(hf_dataset_path)
    dataset_dir = process_dataset(dataset_dir)
    # Force repo_id override in config.yaml
    config_path = os.path.join(dataset_dir, "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    config["config"]["process"][0]["save"]["hf_repo_id"] = repo_id

    with open(config_path, "w") as f:
        yaml.dump(config, f)

    print("βœ… Updated config.yaml with MODEL_REPO_ID:", repo_id)


    # run training
    if not os.path.exists("ai-toolkit"):
        # Clone the toolkit (with submodules) first, then strip its git metadata.
        commands = "git clone https://github.com/DevW1ld/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
        subprocess.run(commands, shell=True, check=True)
        toolkit_src = "ai-toolkit"
        shutil.rmtree(os.path.join(toolkit_src, ".git"), ignore_errors=True)
        # .gitmodules is a file, so os.remove (not rmtree) is the right call here
        if os.path.exists(os.path.join(toolkit_src, ".gitmodules")):
            os.remove(os.path.join(toolkit_src, ".gitmodules"))




    # patch_ai_toolkit_typing()
    commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
    process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,)

    # Stream logs to Space output
    for line in process.stdout:
        sys.stdout.write(line)
        sys.stdout.flush()
    
    return process, dataset_dir

#def patch_ai_toolkit_typing():
#    config_path = "ai-toolkit/toolkit/config_modules.py"
#    if os.path.exists(config_path):
#        with open(config_path, "r") as f:
#            content = f.read()

#        content = content.replace("torch.Tensor | None", "Optional[torch.Tensor]")

#        with open(config_path, "w") as f:
#            f.write(content)
#        print("βœ… Patched ai-toolkit typing for torch.Tensor | None β†’ Optional[torch.Tensor]")
#    else:
#        print("⚠️ Could not patch config_modules.py β€” file not found")


if __name__ == "__main__":
    try:
        process, dataset_dir = run_training(HF_DATASET)
        exit_code = process.wait()  # wait for the training process to finish
        print("Training finished with exit code:", exit_code)

        if exit_code != 0:
            raise RuntimeError(f"Training failed with exit code {exit_code}")

        with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
            config = yaml.safe_load(f)
        #repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
        #repo_id = os.environ.get("MODEL_REPO_ID")
        #repo_id = os.getenv("MODEL_REPO_ID")
        #repo_id = "DevWild/suppab0rk"

        metadata = {
            "tags": [
                "autotrain",
                "spacerunner",
                "text-to-image",
                "flux",
                "lora",
                "diffusers",
                "template:sd-lora",
            ]
        }    
        metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
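        # metadata_update edits the model card metadata (the YAML front matter of
        # the repo's README.md); overwrite=True lets it replace fields that already exist.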

    finally:
        # Dataset cleanup is currently disabled; uncomment the next line to delete the source dataset repo.
        #delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
        print("SCRIPT FINISHED (dataset deletion is currently disabled)")