hysts HF staff multimodalart HF staff commited on
Commit
a3f0757
1 Parent(s): cadb289

simplified-ui (#3)

Browse files

- Simplified UI (17a686c02af144dbc69aabd538b946bb34cb3c4e)


Co-authored-by: Multimodal AI art <[email protected]>

Files changed (5) hide show
  1. app.py +21 -12
  2. app_training.py +64 -64
  3. app_upload.py +5 -1
  4. trainer.py +13 -8
  5. uploader.py +14 -12
app.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import os
 
6
 
7
  import gradio as gr
8
  import torch
@@ -13,27 +14,31 @@ from app_upload import create_upload_demo
13
  from inference import InferencePipeline
14
  from trainer import Trainer
15
 
16
- TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) Training UI'
17
 
18
  ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
19
  SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
- SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU. (Please note that there seems to be an issue with training on the A10G GPU now. The model doesn't learn anything when trained on A10G. Training on T4 works perfectly fine and inference works fine on both.)
 
21
 
22
- <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
23
  '''
24
 
25
  if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
26
  SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
27
  else:
28
  SETTINGS = 'Settings'
29
- CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
 
 
 
30
  <center>
31
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
32
- You can use "T4 small/medium" or "A10G small/large" to run this demo.
33
  </center>
34
  '''
35
 
36
- HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
37
  <center>
38
  You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
39
  You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
@@ -54,23 +59,27 @@ pipe = InferencePipeline(HF_TOKEN)
54
  trainer = Trainer(HF_TOKEN)
55
 
56
  with gr.Blocks(css='style.css') as demo:
57
- if os.getenv('IS_SHARED_UI'):
58
  show_warning(SHARED_UI_WARNING)
59
- if not torch.cuda.is_available():
60
  show_warning(CUDA_NOT_AVAILABLE_WARNING)
61
- if not HF_TOKEN:
62
- show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
 
63
 
64
  gr.Markdown(TITLE)
65
  with gr.Tabs():
66
  with gr.TabItem('Train'):
67
  create_training_demo(trainer, pipe)
68
- with gr.TabItem('Test'):
69
  create_inference_demo(pipe, HF_TOKEN)
70
  with gr.TabItem('Upload'):
71
  gr.Markdown('''
72
  - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
73
  ''')
74
  create_upload_demo(HF_TOKEN)
 
 
 
75
 
76
- demo.queue(max_size=1).launch(share=False)
 
3
  from __future__ import annotations
4
 
5
  import os
6
+ from subprocess import getoutput
7
 
8
  import gradio as gr
9
  import torch
 
14
  from inference import InferencePipeline
15
  from trainer import Trainer
16
 
17
+ TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) UI'
18
 
19
  ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
20
  SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
21
+ GPU_DATA = getoutput('nvidia-smi')
22
+ SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
23
 
24
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
25
  '''
26
 
27
  if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
28
  SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
29
  else:
30
  SETTINGS = 'Settings'
31
+
32
+ INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
33
+
34
+ CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
35
  <center>
36
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
37
+ You can use "T4 small/medium" to run this demo.
38
  </center>
39
  '''
40
 
41
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
42
  <center>
43
  You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
44
  You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
 
59
  trainer = Trainer(HF_TOKEN)
60
 
61
  with gr.Blocks(css='style.css') as demo:
62
+ if SPACE_ID == ORIGINAL_SPACE_ID:
63
  show_warning(SHARED_UI_WARNING)
64
+ elif not torch.cuda.is_available():
65
  show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
+ elif(not "T4" in GPU_DATA):
67
+ show_warning(INVALID_GPU_WARNING)
68
+
69
 
70
  gr.Markdown(TITLE)
71
  with gr.Tabs():
72
  with gr.TabItem('Train'):
73
  create_training_demo(trainer, pipe)
74
+ with gr.TabItem('Run'):
75
  create_inference_demo(pipe, HF_TOKEN)
76
  with gr.TabItem('Upload'):
77
  gr.Markdown('''
78
  - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
79
  ''')
80
  create_upload_demo(HF_TOKEN)
81
+
82
+ if not HF_TOKEN:
83
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
84
 
85
+ demo.queue(max_size=1).launch(share=False)
app_training.py CHANGED
@@ -13,6 +13,7 @@ from trainer import Trainer
13
 
14
  def create_training_demo(trainer: Trainer,
15
  pipe: InferencePipeline | None = None) -> gr.Blocks:
 
16
  with gr.Blocks() as demo:
17
  with gr.Row():
18
  with gr.Column():
@@ -24,82 +25,80 @@ def create_training_demo(trainer: Trainer,
24
  max_lines=1,
25
  placeholder='A man is surfing')
26
  gr.Markdown('''
27
- - Upload a video and write a prompt describing the video.
28
  ''')
 
 
29
  with gr.Box():
30
- gr.Markdown('Output Model')
31
- output_model_name = gr.Text(label='Name of your model',
32
- max_lines=1)
33
- delete_existing_model = gr.Checkbox(
34
- label='Delete existing model of the same name',
35
- value=False)
36
- validation_prompt = gr.Text(label='Validation Prompt')
37
- with gr.Box():
38
- gr.Markdown('Upload Settings')
39
  with gr.Row():
40
- upload_to_hub = gr.Checkbox(
41
- label='Upload model to Hub', value=True)
42
- use_private_repo = gr.Checkbox(label='Private',
43
- value=True)
44
- delete_existing_repo = gr.Checkbox(
45
- label='Delete existing repo of the same name',
46
- value=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  upload_to = gr.Radio(
48
  label='Upload to',
49
  choices=[_.value for _ in UploadTarget],
50
  value=UploadTarget.MODEL_LIBRARY.value)
51
- gr.Markdown(f'''
52
- - By default, trained models will be uploaded to [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (see [this example model](https://huggingface.co/{SAMPLE_MODEL_REPO})).
53
- - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{{your_username}}/{{model_name}}.
54
- ''')
55
-
56
- with gr.Box():
57
- gr.Markdown('Training Parameters')
58
- with gr.Row():
59
- base_model = gr.Text(label='Base Model',
60
- value='CompVis/stable-diffusion-v1-4',
61
- max_lines=1)
62
- resolution = gr.Dropdown(choices=['512', '768'],
63
- value='512',
64
- label='Resolution',
65
- visible=False)
66
- num_training_steps = gr.Number(
67
- label='Number of Training Steps', value=300, precision=0)
68
- learning_rate = gr.Number(label='Learning Rate',
69
- value=0.000035)
70
- gradient_accumulation = gr.Number(
71
- label='Number of Gradient Accumulation',
72
- value=1,
73
- precision=0)
74
- seed = gr.Slider(label='Seed',
75
- minimum=0,
76
- maximum=100000,
77
- step=1,
78
- value=0)
79
- fp16 = gr.Checkbox(label='FP16', value=True)
80
- use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
81
- checkpointing_steps = gr.Number(label='Checkpointing Steps',
82
- value=1000,
83
- precision=0)
84
- validation_epochs = gr.Number(label='Validation Epochs',
85
- value=100,
86
- precision=0)
87
- gr.Markdown('''
88
- - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
89
- - It takes a few minutes to download the base model first.
90
- - Expected time to train a model for 300 steps: 20 minutes with T4, 8 minutes with A10G, (4 minutes with A100)
91
- - It takes a few minutes to upload your trained model.
92
- - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
93
- - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
94
- ''')
95
-
96
  remove_gpu_after_training = gr.Checkbox(
97
  label='Remove GPU after training',
98
  value=False,
99
  interactive=bool(os.getenv('SPACE_ID')),
100
  visible=False)
101
  run_button = gr.Button('Start Training')
102
-
103
  with gr.Box():
104
  gr.Markdown('Output message')
105
  output_message = gr.Markdown()
@@ -111,7 +110,7 @@ def create_training_demo(trainer: Trainer,
111
  training_video,
112
  training_prompt,
113
  output_model_name,
114
- delete_existing_model,
115
  validation_prompt,
116
  base_model,
117
  resolution,
@@ -128,6 +127,7 @@ def create_training_demo(trainer: Trainer,
128
  delete_existing_repo,
129
  upload_to,
130
  remove_gpu_after_training,
 
131
  ],
132
  outputs=output_message)
133
  return demo
 
13
 
14
  def create_training_demo(trainer: Trainer,
15
  pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ hf_token = os.getenv('HF_TOKEN')
17
  with gr.Blocks() as demo:
18
  with gr.Row():
19
  with gr.Column():
 
25
  max_lines=1,
26
  placeholder='A man is surfing')
27
  gr.Markdown('''
28
+ - Upload a video and write a `Training Prompt` that describes the video.
29
  ''')
30
+
31
+ with gr.Column():
32
  with gr.Box():
33
+ gr.Markdown('Training Parameters')
 
 
 
 
 
 
 
 
34
  with gr.Row():
35
+ base_model = gr.Text(label='Base Model',
36
+ value='CompVis/stable-diffusion-v1-4',
37
+ max_lines=1)
38
+ resolution = gr.Dropdown(choices=['512', '768'],
39
+ value='512',
40
+ label='Resolution',
41
+ visible=False)
42
+
43
+ input_token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=False if hf_token else True)
44
+ with gr.Accordion("Advanced settings", open=False):
45
+ num_training_steps = gr.Number(
46
+ label='Number of Training Steps', value=300, precision=0)
47
+ learning_rate = gr.Number(label='Learning Rate',
48
+ value=0.000035)
49
+ gradient_accumulation = gr.Number(
50
+ label='Number of Gradient Accumulation',
51
+ value=1,
52
+ precision=0)
53
+ seed = gr.Slider(label='Seed',
54
+ minimum=0,
55
+ maximum=100000,
56
+ step=1,
57
+ randomize=True,
58
+ value=0)
59
+ fp16 = gr.Checkbox(label='FP16', value=True)
60
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
61
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
62
+ value=1000,
63
+ precision=0)
64
+ validation_epochs = gr.Number(label='Validation Epochs',
65
+ value=100,
66
+ precision=0)
67
+ gr.Markdown('''
68
+ - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
69
+ - Expected time to train a model for 300 steps: ~20 minutes with T4
70
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
71
+ ''')
72
+
73
+ with gr.Row():
74
+ with gr.Column():
75
+ gr.Markdown('Output Model')
76
+ output_model_name = gr.Text(label='Name of your model',
77
+ placeholder='The surfer man',
78
+ max_lines=1)
79
+ validation_prompt = gr.Text(label='Validation Prompt', placeholder='prompt to test the model, e.g: a dog is surfing')
80
+ with gr.Column():
81
+ gr.Markdown('Upload Settings')
82
+ with gr.Row():
83
+ upload_to_hub = gr.Checkbox(
84
+ label='Upload model to Hub', value=True)
85
+ use_private_repo = gr.Checkbox(label='Private',
86
+ value=True)
87
+ delete_existing_repo = gr.Checkbox(
88
+ label='Delete existing repo of the same name',
89
+ value=False)
90
  upload_to = gr.Radio(
91
  label='Upload to',
92
  choices=[_.value for _ in UploadTarget],
93
  value=UploadTarget.MODEL_LIBRARY.value)
94
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  remove_gpu_after_training = gr.Checkbox(
96
  label='Remove GPU after training',
97
  value=False,
98
  interactive=bool(os.getenv('SPACE_ID')),
99
  visible=False)
100
  run_button = gr.Button('Start Training')
101
+
102
  with gr.Box():
103
  gr.Markdown('Output message')
104
  output_message = gr.Markdown()
 
110
  training_video,
111
  training_prompt,
112
  output_model_name,
113
+ delete_existing_repo,
114
  validation_prompt,
115
  base_model,
116
  resolution,
 
127
  delete_existing_repo,
128
  upload_to,
129
  remove_gpu_after_training,
130
+ input_token
131
  ],
132
  outputs=output_message)
133
  return demo
app_upload.py CHANGED
@@ -20,6 +20,7 @@ class ModelUploader(Uploader):
20
  upload_to: str,
21
  private: bool,
22
  delete_existing_repo: bool,
 
23
  ) -> str:
24
  if not folder_path:
25
  raise ValueError
@@ -38,7 +39,8 @@ class ModelUploader(Uploader):
38
  repo_name,
39
  organization=organization,
40
  private=private,
41
- delete_existing_repo=delete_existing_repo)
 
42
 
43
 
44
  def load_local_model_list() -> dict:
@@ -68,6 +70,7 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
68
  choices=[_.value for _ in UploadTarget],
69
  value=UploadTarget.MODEL_LIBRARY.value)
70
  model_name = gr.Textbox(label='Model Name')
 
71
  upload_button = gr.Button('Upload')
72
  gr.Markdown(f'''
73
  - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
@@ -86,6 +89,7 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
86
  upload_to,
87
  use_private_repo,
88
  delete_existing_repo,
 
89
  ],
90
  outputs=output_message)
91
 
 
20
  upload_to: str,
21
  private: bool,
22
  delete_existing_repo: bool,
23
+ input_token: str | None = None,
24
  ) -> str:
25
  if not folder_path:
26
  raise ValueError
 
39
  repo_name,
40
  organization=organization,
41
  private=private,
42
+ delete_existing_repo=delete_existing_repo,
43
+ input_token=input_token)
44
 
45
 
46
  def load_local_model_list() -> dict:
 
70
  choices=[_.value for _ in UploadTarget],
71
  value=UploadTarget.MODEL_LIBRARY.value)
72
  model_name = gr.Textbox(label='Model Name')
73
+ input_token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=False if hf_token else True)
74
  upload_button = gr.Button('Upload')
75
  gr.Markdown(f'''
76
  - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
 
89
  upload_to,
90
  use_private_repo,
91
  delete_existing_repo,
92
+ input_token,
93
  ],
94
  outputs=output_message)
95
 
trainer.py CHANGED
@@ -20,12 +20,12 @@ from utils import save_model_card
20
  sys.path.append('Tune-A-Video')
21
 
22
  URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
-
 
24
 
25
  class Trainer:
26
  def __init__(self, hf_token: str | None = None):
27
  self.hf_token = hf_token
28
- self.api = HfApi(token=hf_token)
29
  self.model_uploader = ModelUploader(hf_token)
30
 
31
  self.checkpoint_dir = pathlib.Path('checkpoints')
@@ -42,10 +42,10 @@ class Trainer:
42
  cwd=org_dir)
43
  return model_dir.as_posix()
44
 
45
- def join_model_library_org(self) -> None:
46
  subprocess.run(
47
  shlex.split(
48
- f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
49
  ))
50
 
51
  def run(
@@ -70,7 +70,10 @@ class Trainer:
70
  delete_existing_repo: bool,
71
  upload_to: str,
72
  remove_gpu_after_training: bool,
 
73
  ) -> str:
 
 
74
  if not torch.cuda.is_available():
75
  raise gr.Error('CUDA is not available.')
76
  if training_video is None:
@@ -94,7 +97,7 @@ class Trainer:
94
  output_dir.mkdir(parents=True)
95
 
96
  if upload_to_hub:
97
- self.join_model_library_org()
98
 
99
  config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
100
  config.pretrained_model_path = self.download_base_model(base_model)
@@ -143,14 +146,16 @@ class Trainer:
143
  repo_name=output_model_name,
144
  upload_to=upload_to,
145
  private=use_private_repo,
146
- delete_existing_repo=delete_existing_repo)
 
147
  print(upload_message)
148
  message = message + '\n' + upload_message
149
 
150
  if remove_gpu_after_training:
151
  space_id = os.getenv('SPACE_ID')
152
  if space_id:
153
- self.api.request_space_hardware(repo_id=space_id,
154
- hardware='cpu-basic')
 
155
 
156
  return message
 
20
  sys.path.append('Tune-A-Video')
21
 
22
  URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
24
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
25
 
26
  class Trainer:
27
  def __init__(self, hf_token: str | None = None):
28
  self.hf_token = hf_token
 
29
  self.model_uploader = ModelUploader(hf_token)
30
 
31
  self.checkpoint_dir = pathlib.Path('checkpoints')
 
42
  cwd=org_dir)
43
  return model_dir.as_posix()
44
 
45
+ def join_model_library_org(self, token: str) -> None:
46
  subprocess.run(
47
  shlex.split(
48
+ f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
49
  ))
50
 
51
  def run(
 
70
  delete_existing_repo: bool,
71
  upload_to: str,
72
  remove_gpu_after_training: bool,
73
+ input_token: str,
74
  ) -> str:
75
+ if SPACE_ID == ORIGINAL_SPACE_ID:
76
+ raise gr.Error('This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU')
77
  if not torch.cuda.is_available():
78
  raise gr.Error('CUDA is not available.')
79
  if training_video is None:
 
97
  output_dir.mkdir(parents=True)
98
 
99
  if upload_to_hub:
100
+ self.join_model_library_org(self.hf_token if self.hf_token else input_token)
101
 
102
  config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
103
  config.pretrained_model_path = self.download_base_model(base_model)
 
146
  repo_name=output_model_name,
147
  upload_to=upload_to,
148
  private=use_private_repo,
149
+ delete_existing_repo=delete_existing_repo,
150
+ input_token=input_token)
151
  print(upload_message)
152
  message = message + '\n' + upload_message
153
 
154
  if remove_gpu_after_training:
155
  space_id = os.getenv('SPACE_ID')
156
  if space_id:
157
+ api = HfApi(token=self.hf_token if self.hf_token else input_token)
158
+ api.request_space_hardware(repo_id=space_id,
159
+ hardware='cpu-basic')
160
 
161
  return message
uploader.py CHANGED
@@ -5,10 +5,7 @@ from huggingface_hub import HfApi
5
 
6
  class Uploader:
7
  def __init__(self, hf_token: str | None):
8
- self.api = HfApi(token=hf_token)
9
-
10
- def get_username(self) -> str:
11
- return self.api.whoami()['name']
12
 
13
  def upload(self,
14
  folder_path: str,
@@ -16,25 +13,30 @@ class Uploader:
16
  organization: str = '',
17
  repo_type: str = 'model',
18
  private: bool = True,
19
- delete_existing_repo: bool = False) -> str:
 
 
 
 
20
  if not folder_path:
21
  raise ValueError
22
  if not repo_name:
23
  raise ValueError
24
  if not organization:
25
- organization = self.get_username()
 
26
  repo_id = f'{organization}/{repo_name}'
27
  if delete_existing_repo:
28
  try:
29
  self.api.delete_repo(repo_id, repo_type=repo_type)
30
  except Exception:
31
  pass
32
- try:
33
- self.api.create_repo(repo_id, repo_type=repo_type, private=private)
34
- self.api.upload_folder(repo_id=repo_id,
35
- folder_path=folder_path,
36
- path_in_repo='.',
37
- repo_type=repo_type)
38
  url = f'https://huggingface.co/{repo_id}'
39
  message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
40
  except Exception as e:
 
5
 
6
  class Uploader:
7
  def __init__(self, hf_token: str | None):
8
+ self.hf_token = hf_token
 
 
 
9
 
10
  def upload(self,
11
  folder_path: str,
 
13
  organization: str = '',
14
  repo_type: str = 'model',
15
  private: bool = True,
16
+ delete_existing_repo: bool = False,
17
+ input_token: str | None = None) -> str:
18
+
19
+ api = HfApi(token=self.hf_token if self.hf_token else input_token)
20
+
21
  if not folder_path:
22
  raise ValueError
23
  if not repo_name:
24
  raise ValueError
25
  if not organization:
26
+ organization = api.whoami()['name']
27
+
28
  repo_id = f'{organization}/{repo_name}'
29
  if delete_existing_repo:
30
  try:
31
  self.api.delete_repo(repo_id, repo_type=repo_type)
32
  except Exception:
33
  pass
34
+ try:
35
+ api.create_repo(repo_id, repo_type=repo_type, private=private)
36
+ api.upload_folder(repo_id=repo_id,
37
+ folder_path=folder_path,
38
+ path_in_repo='.',
39
+ repo_type=repo_type)
40
  url = f'https://huggingface.co/{repo_id}'
41
  message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
42
  except Exception as e: