File size: 3,811 Bytes
1c907cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee7e1de
 
1c907cb
 
ee7e1de
1c907cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import sys
import shutil
import argparse
import importlib
import subprocess

def parse_args():
    """Parse the command-line options of this setup script.

    Options:
        --lang  language code of the cell scripts to fetch (default "en").
        --repo  name of the repository to download from (mandatory).

    Returns:
        argparse.Namespace with the attributes ``lang`` and ``repo``.
    """
    cli = argparse.ArgumentParser(description='Script configuration')
    # The help text below is Russian for: 'Language code, default "en"'.
    cli.add_argument('--lang', type=str, default='en',
                     help='Код языка, по умолчанию "en"')
    cli.add_argument('--repo', type=str, required=True,
                     help='Repository Name')
    namespace = cli.parse_args()
    return namespace

def detect_environment():
    """Work out which hosted notebook runtime this script is running in.

    The runtime is recognised by a marker environment variable: COLAB_GPU for
    Google Colab, then KAGGLE_URL_BASE for Kaggle (checked in that order).
    A machine with at most 20 GiB of physical RAM is flagged as ``free_plan``.
    On Colab that flag also selects /root instead of /content as the root.

    Returns:
        (environment name, root path, free_plan) for a supported runtime, or
        (None, None, None) after printing an error message otherwise.
    """
    ram_gib = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') / (1024 ** 3)
    free_plan = (ram_gib <= 20)

    # (marker variable, display name, root path), in priority order.
    candidates = [
        ('COLAB_GPU', 'Google Colab', "/root" if free_plan else "/content"),
        ('KAGGLE_URL_BASE', 'Kaggle', "/kaggle/working/content"),
    ]
    detected = next(
        ((label, root) for marker, label, root in candidates if marker in os.environ),
        None,
    )
    if detected is None:
        print("\033[31mError: an unsupported runtime environment was detected.\n\033[34mSupported environments:\033[0m Google Colab, Kaggle")
        return None, None, None

    label, root = detected
    return label, root, free_plan

def setup_module_folder(root_path):
    """Ensure ``<root_path>/modules`` exists and is on the import path.

    The folder is created if needed. It is appended to ``sys.path`` only when
    it is not already listed, so repeated calls do not add duplicates.
    """
    target = os.path.join(root_path, "modules")
    os.makedirs(target, exist_ok=True)
    if target in sys.path:
        return
    sys.path.append(target)

def clear_module_cache(modules_folder):
    """Evict every module loaded from *modules_folder* from ``sys.modules``.

    A module is evicted when its ``__file__`` lies inside *modules_folder*.
    The match respects path-component boundaries, so a sibling directory
    such as ``<modules_folder>_old`` is not affected. Modules without a
    string ``__file__`` (built-ins, namespace packages) are left alone.
    Afterwards the import system's finder caches are invalidated so that
    freshly downloaded files are picked up on the next import.
    """
    # The trailing separator makes the prefix test stop at a directory
    # boundary: "/content/modules/" does not match "/content/modules_x/a.py".
    prefix = os.path.join(os.path.normpath(modules_folder), '')
    # Snapshot the items so sys.modules can be mutated safely below.
    stale = [
        name
        for name, module in list(sys.modules.items())
        if isinstance(getattr(module, '__file__', None), str)
        and os.path.normpath(module.__file__).startswith(prefix)
    ]
    for name in stale:
        sys.modules.pop(name, None)

    importlib.invalidate_caches()

def _download_file(file_url, file_path):
    """Fetch *file_url* into *file_path* with wget; return True on success."""
    try:
        # Argument list, no shell: values derived from CLI arguments are never
        # interpreted by a shell, and paths containing spaces work.
        result = subprocess.run(['wget', '-q', file_url, '-O', file_path], check=False)
    except OSError:
        # e.g. the wget executable is not available on this machine.
        return False
    if result.returncode == 0:
        return True
    # wget -O may leave an empty or partial file behind on failure; remove it
    # so a broken download is not mistaken for a real one later on.
    if os.path.exists(file_path):
        os.remove(file_path)
    return False


def download_files(root_path, lang, repo):
    """Download the CSS, cell scripts and helper modules into *root_path*.

    Each destination folder under *root_path* is deleted and recreated, then
    every listed file is fetched from
    ``https://huggingface.co/NagisaNao/<repo>/resolve/main/<url folder>/``.

    Args:
        root_path: base directory that receives the destination folders.
        lang: language code selecting the ``files_cells/python/<lang>`` scripts.
        repo: repository name under the ``NagisaNao`` account.

    Files that fail to download are listed in an error message instead of
    being silently ignored.
    """
    print("Please wait for the files to download... 👀", end='', flush=True)
    files_dict = {  # save folder name | url folder path | files list
        'CSS': {'CSS': ['main_widgets.css', 'auto_cleaner.css', 'dl_display_result.css']},
        'file_cell': {f'files_cells/python/{lang}': [f'widgets_{lang}.py', f'downloading_{lang}.py', f'launch_{lang}.py', f'auto_cleaner_{lang}.py']},
        'file_cell/special': {'special': ['dl_display_results.py']},
        'modules': {'modules': ['models_data.py', 'directory_setup.py']}
    }
    failed = []
    for folder, contents in files_dict.items():
        folder_path = os.path.join(root_path, folder)
        # Start from an empty folder so stale files from earlier runs are gone.
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        os.makedirs(folder_path)
        for path_url, files in contents.items():
            for file in files:
                file_url = f"https://huggingface.co/NagisaNao/{repo}/resolve/main/{path_url}/{file}"
                file_path = os.path.join(folder_path, file)
                if not _download_file(file_url, file_path):
                    failed.append(f"{path_url}/{file}")

    if failed:
        print(f"\r\033[31mError: {len(failed)} file(s) failed to download:\033[0m" + " "*30)
        for name in failed:
            print(f"  - {name}")
    else:
        print("\rDone! Now you can run the cells below. ☄️" + " "*30)

def main():
    """Run the setup: detect the runtime, fetch files, export settings.

    If the runtime is unsupported, nothing else is done; detect_environment
    has already printed the error. Otherwise the files are downloaded, stale
    cached modules are evicted, the modules folder is made importable, and
    ENV_NAME / ROOT_PATH / WEBUI_PATH / FREE_PLAN are exported to os.environ.
    """
    args = parse_args()
    lang, repo = args.lang, args.repo

    env, root_path, free_plan = detect_environment()
    if not (env and root_path):
        return

    webui_path = f"{root_path}/sdw"
    download_files(root_path, lang, repo)
    clear_module_cache(os.path.join(root_path, "modules"))
    setup_module_folder(root_path)

    # Export settings as process-wide environment variables
    # (presumably read by the downloaded cell scripts — confirm).
    os.environ.update({
        'ENV_NAME': env,
        'ROOT_PATH': root_path,
        'WEBUI_PATH': webui_path,
        'FREE_PLAN': str(bool(free_plan)),  # 'True' / 'False'
    })

    print(f"Runtime environment: \033[33m{env}\033[0m")
    if env == "Google Colab":
        print(f"Colab Pro subscription: \033[34m{not free_plan}\033[0m")
        print(f"File location: \033[32m{root_path}\033[0m")

    # Any repository other than 'fast_repo' is treated as a test build.
    if repo != 'fast_repo':
        print('\n\033[31mWARNING: Test mode is used, there may be errors in use!\033[0m')


if __name__ == "__main__":
    main()