John6666 commited on
Commit
50bf1df
·
verified ·
1 Parent(s): a6a09d9

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +68 -46
utils.py CHANGED
@@ -8,6 +8,7 @@ import re
8
  import urllib.parse
9
  import subprocess
10
  import time
 
11
 
12
 
13
  def get_token():
@@ -25,6 +26,17 @@ def set_token(token):
25
  print(f"Error: Failed to save token.")
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
28
  def get_user_agent():
29
  return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
30
 
@@ -113,31 +125,33 @@ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
113
 
114
 
115
  def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
116
- url = url.strip()
117
- if "drive.google.com" in url:
118
- original_dir = os.getcwd()
119
- os.chdir(directory)
120
- os.system(f"gdown --fuzzy {url}")
121
- os.chdir(original_dir)
122
- elif "huggingface.co" in url:
123
- url = url.replace("?download=true", "")
124
- if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
125
- download_hf_file(directory, url)
126
- elif "civitai.com" in url:
127
- if "?" in url:
128
- url = url.split("?")[0]
129
- if civitai_api_key:
130
- url = url + f"?token={civitai_api_key}"
131
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
 
 
132
  else:
133
- print("You need an API key to download Civitai models.")
134
- else:
135
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
136
 
137
 
138
  def get_local_file_list(dir_path):
139
  file_list = []
140
- for file in Path(dir_path).glob("**/*.*"):
141
  if file.is_file():
142
  file_path = str(file)
143
  file_list.append(file_path)
@@ -145,30 +159,30 @@ def get_local_file_list(dir_path):
145
 
146
 
147
  def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
148
- if not "http" in url and is_repo_name(url) and not Path(url).exists():
149
- print(f"Use HF Repo: {url}")
150
- new_file = url
151
- elif not "http" in url and Path(url).exists():
152
- print(f"Use local file: {url}")
153
- new_file = url
154
- elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
155
- print(f"File to download alreday exists: {url}")
156
- new_file = f"{temp_dir}/{url.split('/')[-1]}"
157
- else:
158
- print(f"Start downloading: {url}")
159
- before = get_local_file_list(temp_dir)
160
- try:
161
  download_thing(temp_dir, url.strip(), civitai_key)
162
- except Exception:
 
 
163
  print(f"Download failed: {url}")
164
  return ""
165
- after = get_local_file_list(temp_dir)
166
- new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
167
- if not new_file:
168
- print(f"Download failed: {url}")
169
  return ""
170
- print(f"Download completed: {url}")
171
- return new_file
172
 
173
 
174
  def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
@@ -183,17 +197,13 @@ def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=T
183
  return False
184
 
185
 
186
- def upload_repo(repo_id: str, dir_path: str, is_private: bool, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
187
  hf_token = get_token()
188
  api = HfApi(token=hf_token)
189
  try:
190
  progress(0, desc="Start uploading...")
191
  api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
192
- for path in Path(dir_path).glob("*"):
193
- if path.is_dir():
194
- api.upload_folder(repo_id=repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
195
- elif path.is_file():
196
- api.upload_file(repo_id=repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
197
  progress(1, desc="Uploaded.")
198
  return get_hf_url(repo_id, "model")
199
  except Exception as e:
@@ -201,6 +211,18 @@ def upload_repo(repo_id: str, dir_path: str, is_private: bool, progress=gr.Progr
201
  return ""
202
 
203
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  HF_SUBFOLDER_NAME = ["None", "user_repo"]
205
 
206
 
 
8
  import urllib.parse
9
  import subprocess
10
  import time
11
+ from typing import Any
12
 
13
 
14
  def get_token():
 
26
  print(f"Error: Failed to save token.")
27
 
28
 
29
+ def get_state(state: dict, key: str):
30
+ if key in state.keys(): return state[key]
31
+ else:
32
+ print(f"State '{key}' not found.")
33
+ return None
34
+
35
+
36
+ def set_state(state: dict, key: str, value: Any):
37
+ state[key] = value
38
+
39
+
40
  def get_user_agent():
41
  return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
42
 
 
125
 
126
 
127
  def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
128
+ try:
129
+ url = url.strip()
130
+ if "drive.google.com" in url:
131
+ original_dir = os.getcwd()
132
+ os.chdir(directory)
133
+ subprocess.run(f"gdown --fuzzy {url}", shell=True)
134
+ os.chdir(original_dir)
135
+ elif "huggingface.co" in url:
136
+ url = url.replace("?download=true", "")
137
+ if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
138
+ download_hf_file(directory, url)
139
+ elif "civitai.com" in url:
140
+ if civitai_api_key:
141
+ url = f"'{url}&token={civitai_api_key}'" if "?" in url else f"{url}?token={civitai_api_key}"
142
+ print(f"Downloading {url}")
143
+ subprocess.run(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}", shell=True)
144
+ else:
145
+ print("You need an API key to download Civitai models.")
146
  else:
147
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
148
+ except Exception as e:
149
+ print(f"Failed to download: {e}")
150
 
151
 
152
  def get_local_file_list(dir_path):
153
  file_list = []
154
+ for file in Path(dir_path).glob("*/*.*"):
155
  if file.is_file():
156
  file_path = str(file)
157
  file_list.append(file_path)
 
159
 
160
 
161
  def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
162
+ try:
163
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
164
+ print(f"Use HF Repo: {url}")
165
+ new_file = url
166
+ elif not "http" in url and Path(url).exists():
167
+ print(f"Use local file: {url}")
168
+ new_file = url
169
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
170
+ print(f"File to download alreday exists: {url}")
171
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
172
+ else:
173
+ print(f"Start downloading: {url}")
174
+ before = get_local_file_list(temp_dir)
175
  download_thing(temp_dir, url.strip(), civitai_key)
176
+ after = get_local_file_list(temp_dir)
177
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
178
+ if not new_file:
179
  print(f"Download failed: {url}")
180
  return ""
181
+ print(f"Download completed: {url}")
182
+ return new_file
183
+ except Exception as e:
184
+ print(f"Download failed: {url} {e}")
185
  return ""
 
 
186
 
187
 
188
  def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
 
197
  return False
198
 
199
 
200
+ def upload_repo(repo_id: str, dir_path: str, is_private: bool, is_pr: bool=False, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
201
  hf_token = get_token()
202
  api = HfApi(token=hf_token)
203
  try:
204
  progress(0, desc="Start uploading...")
205
  api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
206
+ api.upload_folder(repo_id=repo_id, folder_path=dir_path, path_in_repo="", create_pr=is_pr, token=hf_token)
 
 
 
 
207
  progress(1, desc="Uploaded.")
208
  return get_hf_url(repo_id, "model")
209
  except Exception as e:
 
211
  return ""
212
 
213
 
214
+ def gate_repo(repo_id: str, gated_str: str, repo_type: str="model"):
215
+ hf_token = get_token()
216
+ api = HfApi(token=hf_token)
217
+ try:
218
+ if gated_str == "auto": gated = "auto"
219
+ elif gated_str == "manual": gated = "manual"
220
+ else: gated = False
221
+ api.update_repo_settings(repo_id=repo_id, gated=gated, repo_type=repo_type, token=hf_token)
222
+ except Exception as e:
223
+ print(f"Error: Failed to update settings {repo_id}. {e}")
224
+
225
+
226
  HF_SUBFOLDER_NAME = ["None", "user_repo"]
227
 
228