Commit
•
a3f0757
1
Parent(s):
cadb289
simplified-ui (#3)
Browse files- Simplified UI (17a686c02af144dbc69aabd538b946bb34cb3c4e)
Co-authored-by: Multimodal AI art <[email protected]>
- app.py +21 -12
- app_training.py +64 -64
- app_upload.py +5 -1
- trainer.py +13 -8
- uploader.py +14 -12
app.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
from __future__ import annotations
|
4 |
|
5 |
import os
|
|
|
6 |
|
7 |
import gradio as gr
|
8 |
import torch
|
@@ -13,27 +14,31 @@ from app_upload import create_upload_demo
|
|
13 |
from inference import InferencePipeline
|
14 |
from trainer import Trainer
|
15 |
|
16 |
-
TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)
|
17 |
|
18 |
ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
|
19 |
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
|
20 |
-
|
|
|
21 |
|
22 |
-
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
|
23 |
'''
|
24 |
|
25 |
if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
|
26 |
SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
|
27 |
else:
|
28 |
SETTINGS = 'Settings'
|
29 |
-
|
|
|
|
|
|
|
30 |
<center>
|
31 |
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
|
32 |
-
You can use "T4 small/medium"
|
33 |
</center>
|
34 |
'''
|
35 |
|
36 |
-
HF_TOKEN_NOT_SPECIFIED_WARNING = f'''
|
37 |
<center>
|
38 |
You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
|
39 |
You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
|
@@ -54,23 +59,27 @@ pipe = InferencePipeline(HF_TOKEN)
|
|
54 |
trainer = Trainer(HF_TOKEN)
|
55 |
|
56 |
with gr.Blocks(css='style.css') as demo:
|
57 |
-
if
|
58 |
show_warning(SHARED_UI_WARNING)
|
59 |
-
|
60 |
show_warning(CUDA_NOT_AVAILABLE_WARNING)
|
61 |
-
|
62 |
-
show_warning(
|
|
|
63 |
|
64 |
gr.Markdown(TITLE)
|
65 |
with gr.Tabs():
|
66 |
with gr.TabItem('Train'):
|
67 |
create_training_demo(trainer, pipe)
|
68 |
-
with gr.TabItem('
|
69 |
create_inference_demo(pipe, HF_TOKEN)
|
70 |
with gr.TabItem('Upload'):
|
71 |
gr.Markdown('''
|
72 |
- You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
|
73 |
''')
|
74 |
create_upload_demo(HF_TOKEN)
|
|
|
|
|
|
|
75 |
|
76 |
-
demo.queue(max_size=1).launch(share=False)
|
|
|
3 |
from __future__ import annotations
|
4 |
|
5 |
import os
|
6 |
+
from subprocess import getoutput
|
7 |
|
8 |
import gradio as gr
|
9 |
import torch
|
|
|
14 |
from inference import InferencePipeline
|
15 |
from trainer import Trainer
|
16 |
|
17 |
+
TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) UI'
|
18 |
|
19 |
ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
|
20 |
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
|
21 |
+
GPU_DATA = getoutput('nvidia-smi')
|
22 |
+
SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
|
23 |
|
24 |
+
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
|
25 |
'''
|
26 |
|
27 |
if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
|
28 |
SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
|
29 |
else:
|
30 |
SETTINGS = 'Settings'
|
31 |
+
|
32 |
+
INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
|
33 |
+
|
34 |
+
CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
|
35 |
<center>
|
36 |
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
|
37 |
+
You can use "T4 small/medium" to run this demo.
|
38 |
</center>
|
39 |
'''
|
40 |
|
41 |
+
HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
|
42 |
<center>
|
43 |
You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
|
44 |
You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
|
|
|
59 |
trainer = Trainer(HF_TOKEN)
|
60 |
|
61 |
with gr.Blocks(css='style.css') as demo:
|
62 |
+
if SPACE_ID == ORIGINAL_SPACE_ID:
|
63 |
show_warning(SHARED_UI_WARNING)
|
64 |
+
elif not torch.cuda.is_available():
|
65 |
show_warning(CUDA_NOT_AVAILABLE_WARNING)
|
66 |
+
elif(not "T4" in GPU_DATA):
|
67 |
+
show_warning(INVALID_GPU_WARNING)
|
68 |
+
|
69 |
|
70 |
gr.Markdown(TITLE)
|
71 |
with gr.Tabs():
|
72 |
with gr.TabItem('Train'):
|
73 |
create_training_demo(trainer, pipe)
|
74 |
+
with gr.TabItem('Run'):
|
75 |
create_inference_demo(pipe, HF_TOKEN)
|
76 |
with gr.TabItem('Upload'):
|
77 |
gr.Markdown('''
|
78 |
- You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
|
79 |
''')
|
80 |
create_upload_demo(HF_TOKEN)
|
81 |
+
|
82 |
+
if not HF_TOKEN:
|
83 |
+
show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
|
84 |
|
85 |
+
demo.queue(max_size=1).launch(share=False)
|
app_training.py
CHANGED
@@ -13,6 +13,7 @@ from trainer import Trainer
|
|
13 |
|
14 |
def create_training_demo(trainer: Trainer,
|
15 |
pipe: InferencePipeline | None = None) -> gr.Blocks:
|
|
|
16 |
with gr.Blocks() as demo:
|
17 |
with gr.Row():
|
18 |
with gr.Column():
|
@@ -24,82 +25,80 @@ def create_training_demo(trainer: Trainer,
|
|
24 |
max_lines=1,
|
25 |
placeholder='A man is surfing')
|
26 |
gr.Markdown('''
|
27 |
-
- Upload a video and write a
|
28 |
''')
|
|
|
|
|
29 |
with gr.Box():
|
30 |
-
gr.Markdown('
|
31 |
-
output_model_name = gr.Text(label='Name of your model',
|
32 |
-
max_lines=1)
|
33 |
-
delete_existing_model = gr.Checkbox(
|
34 |
-
label='Delete existing model of the same name',
|
35 |
-
value=False)
|
36 |
-
validation_prompt = gr.Text(label='Validation Prompt')
|
37 |
-
with gr.Box():
|
38 |
-
gr.Markdown('Upload Settings')
|
39 |
with gr.Row():
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
upload_to = gr.Radio(
|
48 |
label='Upload to',
|
49 |
choices=[_.value for _ in UploadTarget],
|
50 |
value=UploadTarget.MODEL_LIBRARY.value)
|
51 |
-
|
52 |
-
- By default, trained models will be uploaded to [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (see [this example model](https://huggingface.co/{SAMPLE_MODEL_REPO})).
|
53 |
-
- You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{{your_username}}/{{model_name}}.
|
54 |
-
''')
|
55 |
-
|
56 |
-
with gr.Box():
|
57 |
-
gr.Markdown('Training Parameters')
|
58 |
-
with gr.Row():
|
59 |
-
base_model = gr.Text(label='Base Model',
|
60 |
-
value='CompVis/stable-diffusion-v1-4',
|
61 |
-
max_lines=1)
|
62 |
-
resolution = gr.Dropdown(choices=['512', '768'],
|
63 |
-
value='512',
|
64 |
-
label='Resolution',
|
65 |
-
visible=False)
|
66 |
-
num_training_steps = gr.Number(
|
67 |
-
label='Number of Training Steps', value=300, precision=0)
|
68 |
-
learning_rate = gr.Number(label='Learning Rate',
|
69 |
-
value=0.000035)
|
70 |
-
gradient_accumulation = gr.Number(
|
71 |
-
label='Number of Gradient Accumulation',
|
72 |
-
value=1,
|
73 |
-
precision=0)
|
74 |
-
seed = gr.Slider(label='Seed',
|
75 |
-
minimum=0,
|
76 |
-
maximum=100000,
|
77 |
-
step=1,
|
78 |
-
value=0)
|
79 |
-
fp16 = gr.Checkbox(label='FP16', value=True)
|
80 |
-
use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
|
81 |
-
checkpointing_steps = gr.Number(label='Checkpointing Steps',
|
82 |
-
value=1000,
|
83 |
-
precision=0)
|
84 |
-
validation_epochs = gr.Number(label='Validation Epochs',
|
85 |
-
value=100,
|
86 |
-
precision=0)
|
87 |
-
gr.Markdown('''
|
88 |
-
- The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
|
89 |
-
- It takes a few minutes to download the base model first.
|
90 |
-
- Expected time to train a model for 300 steps: 20 minutes with T4, 8 minutes with A10G, (4 minutes with A100)
|
91 |
-
- It takes a few minutes to upload your trained model.
|
92 |
-
- You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
|
93 |
-
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
|
94 |
-
''')
|
95 |
-
|
96 |
remove_gpu_after_training = gr.Checkbox(
|
97 |
label='Remove GPU after training',
|
98 |
value=False,
|
99 |
interactive=bool(os.getenv('SPACE_ID')),
|
100 |
visible=False)
|
101 |
run_button = gr.Button('Start Training')
|
102 |
-
|
103 |
with gr.Box():
|
104 |
gr.Markdown('Output message')
|
105 |
output_message = gr.Markdown()
|
@@ -111,7 +110,7 @@ def create_training_demo(trainer: Trainer,
|
|
111 |
training_video,
|
112 |
training_prompt,
|
113 |
output_model_name,
|
114 |
-
|
115 |
validation_prompt,
|
116 |
base_model,
|
117 |
resolution,
|
@@ -128,6 +127,7 @@ def create_training_demo(trainer: Trainer,
|
|
128 |
delete_existing_repo,
|
129 |
upload_to,
|
130 |
remove_gpu_after_training,
|
|
|
131 |
],
|
132 |
outputs=output_message)
|
133 |
return demo
|
|
|
13 |
|
14 |
def create_training_demo(trainer: Trainer,
|
15 |
pipe: InferencePipeline | None = None) -> gr.Blocks:
|
16 |
+
hf_token = os.getenv('HF_TOKEN')
|
17 |
with gr.Blocks() as demo:
|
18 |
with gr.Row():
|
19 |
with gr.Column():
|
|
|
25 |
max_lines=1,
|
26 |
placeholder='A man is surfing')
|
27 |
gr.Markdown('''
|
28 |
+
- Upload a video and write a `Training Prompt` that describes the video.
|
29 |
''')
|
30 |
+
|
31 |
+
with gr.Column():
|
32 |
with gr.Box():
|
33 |
+
gr.Markdown('Training Parameters')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
with gr.Row():
|
35 |
+
base_model = gr.Text(label='Base Model',
|
36 |
+
value='CompVis/stable-diffusion-v1-4',
|
37 |
+
max_lines=1)
|
38 |
+
resolution = gr.Dropdown(choices=['512', '768'],
|
39 |
+
value='512',
|
40 |
+
label='Resolution',
|
41 |
+
visible=False)
|
42 |
+
|
43 |
+
input_token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=False if hf_token else True)
|
44 |
+
with gr.Accordion("Advanced settings", open=False):
|
45 |
+
num_training_steps = gr.Number(
|
46 |
+
label='Number of Training Steps', value=300, precision=0)
|
47 |
+
learning_rate = gr.Number(label='Learning Rate',
|
48 |
+
value=0.000035)
|
49 |
+
gradient_accumulation = gr.Number(
|
50 |
+
label='Number of Gradient Accumulation',
|
51 |
+
value=1,
|
52 |
+
precision=0)
|
53 |
+
seed = gr.Slider(label='Seed',
|
54 |
+
minimum=0,
|
55 |
+
maximum=100000,
|
56 |
+
step=1,
|
57 |
+
randomize=True,
|
58 |
+
value=0)
|
59 |
+
fp16 = gr.Checkbox(label='FP16', value=True)
|
60 |
+
use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
|
61 |
+
checkpointing_steps = gr.Number(label='Checkpointing Steps',
|
62 |
+
value=1000,
|
63 |
+
precision=0)
|
64 |
+
validation_epochs = gr.Number(label='Validation Epochs',
|
65 |
+
value=100,
|
66 |
+
precision=0)
|
67 |
+
gr.Markdown('''
|
68 |
+
- The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
|
69 |
+
- Expected time to train a model for 300 steps: ~20 minutes with T4
|
70 |
+
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
|
71 |
+
''')
|
72 |
+
|
73 |
+
with gr.Row():
|
74 |
+
with gr.Column():
|
75 |
+
gr.Markdown('Output Model')
|
76 |
+
output_model_name = gr.Text(label='Name of your model',
|
77 |
+
placeholder='The surfer man',
|
78 |
+
max_lines=1)
|
79 |
+
validation_prompt = gr.Text(label='Validation Prompt', placeholder='prompt to test the model, e.g: a dog is surfing')
|
80 |
+
with gr.Column():
|
81 |
+
gr.Markdown('Upload Settings')
|
82 |
+
with gr.Row():
|
83 |
+
upload_to_hub = gr.Checkbox(
|
84 |
+
label='Upload model to Hub', value=True)
|
85 |
+
use_private_repo = gr.Checkbox(label='Private',
|
86 |
+
value=True)
|
87 |
+
delete_existing_repo = gr.Checkbox(
|
88 |
+
label='Delete existing repo of the same name',
|
89 |
+
value=False)
|
90 |
upload_to = gr.Radio(
|
91 |
label='Upload to',
|
92 |
choices=[_.value for _ in UploadTarget],
|
93 |
value=UploadTarget.MODEL_LIBRARY.value)
|
94 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
remove_gpu_after_training = gr.Checkbox(
|
96 |
label='Remove GPU after training',
|
97 |
value=False,
|
98 |
interactive=bool(os.getenv('SPACE_ID')),
|
99 |
visible=False)
|
100 |
run_button = gr.Button('Start Training')
|
101 |
+
|
102 |
with gr.Box():
|
103 |
gr.Markdown('Output message')
|
104 |
output_message = gr.Markdown()
|
|
|
110 |
training_video,
|
111 |
training_prompt,
|
112 |
output_model_name,
|
113 |
+
delete_existing_repo,
|
114 |
validation_prompt,
|
115 |
base_model,
|
116 |
resolution,
|
|
|
127 |
delete_existing_repo,
|
128 |
upload_to,
|
129 |
remove_gpu_after_training,
|
130 |
+
input_token
|
131 |
],
|
132 |
outputs=output_message)
|
133 |
return demo
|
app_upload.py
CHANGED
@@ -20,6 +20,7 @@ class ModelUploader(Uploader):
|
|
20 |
upload_to: str,
|
21 |
private: bool,
|
22 |
delete_existing_repo: bool,
|
|
|
23 |
) -> str:
|
24 |
if not folder_path:
|
25 |
raise ValueError
|
@@ -38,7 +39,8 @@ class ModelUploader(Uploader):
|
|
38 |
repo_name,
|
39 |
organization=organization,
|
40 |
private=private,
|
41 |
-
delete_existing_repo=delete_existing_repo
|
|
|
42 |
|
43 |
|
44 |
def load_local_model_list() -> dict:
|
@@ -68,6 +70,7 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
|
|
68 |
choices=[_.value for _ in UploadTarget],
|
69 |
value=UploadTarget.MODEL_LIBRARY.value)
|
70 |
model_name = gr.Textbox(label='Model Name')
|
|
|
71 |
upload_button = gr.Button('Upload')
|
72 |
gr.Markdown(f'''
|
73 |
- You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
|
@@ -86,6 +89,7 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
|
|
86 |
upload_to,
|
87 |
use_private_repo,
|
88 |
delete_existing_repo,
|
|
|
89 |
],
|
90 |
outputs=output_message)
|
91 |
|
|
|
20 |
upload_to: str,
|
21 |
private: bool,
|
22 |
delete_existing_repo: bool,
|
23 |
+
input_token: str | None = None,
|
24 |
) -> str:
|
25 |
if not folder_path:
|
26 |
raise ValueError
|
|
|
39 |
repo_name,
|
40 |
organization=organization,
|
41 |
private=private,
|
42 |
+
delete_existing_repo=delete_existing_repo,
|
43 |
+
input_token=input_token)
|
44 |
|
45 |
|
46 |
def load_local_model_list() -> dict:
|
|
|
70 |
choices=[_.value for _ in UploadTarget],
|
71 |
value=UploadTarget.MODEL_LIBRARY.value)
|
72 |
model_name = gr.Textbox(label='Model Name')
|
73 |
+
input_token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=False if hf_token else True)
|
74 |
upload_button = gr.Button('Upload')
|
75 |
gr.Markdown(f'''
|
76 |
- You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
|
|
|
89 |
upload_to,
|
90 |
use_private_repo,
|
91 |
delete_existing_repo,
|
92 |
+
input_token,
|
93 |
],
|
94 |
outputs=output_message)
|
95 |
|
trainer.py
CHANGED
@@ -20,12 +20,12 @@ from utils import save_model_card
|
|
20 |
sys.path.append('Tune-A-Video')
|
21 |
|
22 |
URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
|
23 |
-
|
|
|
24 |
|
25 |
class Trainer:
|
26 |
def __init__(self, hf_token: str | None = None):
|
27 |
self.hf_token = hf_token
|
28 |
-
self.api = HfApi(token=hf_token)
|
29 |
self.model_uploader = ModelUploader(hf_token)
|
30 |
|
31 |
self.checkpoint_dir = pathlib.Path('checkpoints')
|
@@ -42,10 +42,10 @@ class Trainer:
|
|
42 |
cwd=org_dir)
|
43 |
return model_dir.as_posix()
|
44 |
|
45 |
-
def join_model_library_org(self) -> None:
|
46 |
subprocess.run(
|
47 |
shlex.split(
|
48 |
-
f'curl -X POST -H "Authorization: Bearer {
|
49 |
))
|
50 |
|
51 |
def run(
|
@@ -70,7 +70,10 @@ class Trainer:
|
|
70 |
delete_existing_repo: bool,
|
71 |
upload_to: str,
|
72 |
remove_gpu_after_training: bool,
|
|
|
73 |
) -> str:
|
|
|
|
|
74 |
if not torch.cuda.is_available():
|
75 |
raise gr.Error('CUDA is not available.')
|
76 |
if training_video is None:
|
@@ -94,7 +97,7 @@ class Trainer:
|
|
94 |
output_dir.mkdir(parents=True)
|
95 |
|
96 |
if upload_to_hub:
|
97 |
-
self.join_model_library_org()
|
98 |
|
99 |
config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
|
100 |
config.pretrained_model_path = self.download_base_model(base_model)
|
@@ -143,14 +146,16 @@ class Trainer:
|
|
143 |
repo_name=output_model_name,
|
144 |
upload_to=upload_to,
|
145 |
private=use_private_repo,
|
146 |
-
delete_existing_repo=delete_existing_repo
|
|
|
147 |
print(upload_message)
|
148 |
message = message + '\n' + upload_message
|
149 |
|
150 |
if remove_gpu_after_training:
|
151 |
space_id = os.getenv('SPACE_ID')
|
152 |
if space_id:
|
153 |
-
self.
|
154 |
-
|
|
|
155 |
|
156 |
return message
|
|
|
20 |
sys.path.append('Tune-A-Video')
|
21 |
|
22 |
URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
|
23 |
+
ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
|
24 |
+
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
|
25 |
|
26 |
class Trainer:
|
27 |
def __init__(self, hf_token: str | None = None):
|
28 |
self.hf_token = hf_token
|
|
|
29 |
self.model_uploader = ModelUploader(hf_token)
|
30 |
|
31 |
self.checkpoint_dir = pathlib.Path('checkpoints')
|
|
|
42 |
cwd=org_dir)
|
43 |
return model_dir.as_posix()
|
44 |
|
45 |
+
def join_model_library_org(self, token: str) -> None:
|
46 |
subprocess.run(
|
47 |
shlex.split(
|
48 |
+
f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
|
49 |
))
|
50 |
|
51 |
def run(
|
|
|
70 |
delete_existing_repo: bool,
|
71 |
upload_to: str,
|
72 |
remove_gpu_after_training: bool,
|
73 |
+
input_token: str,
|
74 |
) -> str:
|
75 |
+
if SPACE_ID == ORIGINAL_SPACE_ID:
|
76 |
+
raise gr.Error('This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU')
|
77 |
if not torch.cuda.is_available():
|
78 |
raise gr.Error('CUDA is not available.')
|
79 |
if training_video is None:
|
|
|
97 |
output_dir.mkdir(parents=True)
|
98 |
|
99 |
if upload_to_hub:
|
100 |
+
self.join_model_library_org(self.hf_token if self.hf_token else input_token)
|
101 |
|
102 |
config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
|
103 |
config.pretrained_model_path = self.download_base_model(base_model)
|
|
|
146 |
repo_name=output_model_name,
|
147 |
upload_to=upload_to,
|
148 |
private=use_private_repo,
|
149 |
+
delete_existing_repo=delete_existing_repo,
|
150 |
+
input_token=input_token)
|
151 |
print(upload_message)
|
152 |
message = message + '\n' + upload_message
|
153 |
|
154 |
if remove_gpu_after_training:
|
155 |
space_id = os.getenv('SPACE_ID')
|
156 |
if space_id:
|
157 |
+
api = HfApi(token=self.hf_token if self.hf_token else input_token)
|
158 |
+
api.request_space_hardware(repo_id=space_id,
|
159 |
+
hardware='cpu-basic')
|
160 |
|
161 |
return message
|
uploader.py
CHANGED
@@ -5,10 +5,7 @@ from huggingface_hub import HfApi
|
|
5 |
|
6 |
class Uploader:
|
7 |
def __init__(self, hf_token: str | None):
|
8 |
-
self.
|
9 |
-
|
10 |
-
def get_username(self) -> str:
|
11 |
-
return self.api.whoami()['name']
|
12 |
|
13 |
def upload(self,
|
14 |
folder_path: str,
|
@@ -16,25 +13,30 @@ class Uploader:
|
|
16 |
organization: str = '',
|
17 |
repo_type: str = 'model',
|
18 |
private: bool = True,
|
19 |
-
delete_existing_repo: bool = False
|
|
|
|
|
|
|
|
|
20 |
if not folder_path:
|
21 |
raise ValueError
|
22 |
if not repo_name:
|
23 |
raise ValueError
|
24 |
if not organization:
|
25 |
-
organization =
|
|
|
26 |
repo_id = f'{organization}/{repo_name}'
|
27 |
if delete_existing_repo:
|
28 |
try:
|
29 |
self.api.delete_repo(repo_id, repo_type=repo_type)
|
30 |
except Exception:
|
31 |
pass
|
32 |
-
try:
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
url = f'https://huggingface.co/{repo_id}'
|
39 |
message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
|
40 |
except Exception as e:
|
|
|
5 |
|
6 |
class Uploader:
|
7 |
def __init__(self, hf_token: str | None):
|
8 |
+
self.hf_token = hf_token
|
|
|
|
|
|
|
9 |
|
10 |
def upload(self,
|
11 |
folder_path: str,
|
|
|
13 |
organization: str = '',
|
14 |
repo_type: str = 'model',
|
15 |
private: bool = True,
|
16 |
+
delete_existing_repo: bool = False,
|
17 |
+
input_token: str | None = None) -> str:
|
18 |
+
|
19 |
+
api = HfApi(token=self.hf_token if self.hf_token else input_token)
|
20 |
+
|
21 |
if not folder_path:
|
22 |
raise ValueError
|
23 |
if not repo_name:
|
24 |
raise ValueError
|
25 |
if not organization:
|
26 |
+
organization = api.whoami()['name']
|
27 |
+
|
28 |
repo_id = f'{organization}/{repo_name}'
|
29 |
if delete_existing_repo:
|
30 |
try:
|
31 |
self.api.delete_repo(repo_id, repo_type=repo_type)
|
32 |
except Exception:
|
33 |
pass
|
34 |
+
try:
|
35 |
+
api.create_repo(repo_id, repo_type=repo_type, private=private)
|
36 |
+
api.upload_folder(repo_id=repo_id,
|
37 |
+
folder_path=folder_path,
|
38 |
+
path_in_repo='.',
|
39 |
+
repo_type=repo_type)
|
40 |
url = f'https://huggingface.co/{repo_id}'
|
41 |
message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
|
42 |
except Exception as e:
|