djuna commited on
Commit
2d36e6d
·
verified ·
1 Parent(s): 9235353

feat: Choosable CLI, Custom Output Shard Size, LORA extraction

Browse files
Files changed (2) hide show
  1. app.py +142 -29
  2. requirements.txt +1 -0
app.py CHANGED
@@ -11,6 +11,7 @@ import gradio as gr
11
  import huggingface_hub
12
  import torch
13
  import yaml
 
14
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
15
  from mergekit.config import MergeConfiguration
16
 
@@ -43,7 +44,7 @@ has_gpu = torch.cuda.is_available()
43
  # )
44
 
45
  cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
46
- " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --out-shard-size 1B --lazy-unpickle"
47
  )
48
 
49
  MARKDOWN_DESCRIPTION = """
@@ -111,17 +112,19 @@ examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
111
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
112
 
113
 
114
- def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
115
  runner = LogsViewRunner()
116
 
117
  if not yaml_config:
118
  yield runner.log("Empty yaml, pick an example below", level="ERROR")
119
  return
120
- try:
121
- merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
122
- except Exception as e:
123
- yield runner.log(f"Invalid yaml {e}", level="ERROR")
124
- return
 
 
125
 
126
  is_community_model = False
127
  if not hf_token:
@@ -170,7 +173,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
170
  # Set tmp HF_HOME to avoid filling up disk Space
171
  tmp_env = os.environ.copy() # taken from https://stackoverflow.com/a/4453495
172
  tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
173
- full_cli = cli + f" --lora-merge-cache {tmpdirname}/.lora_cache"
174
  yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
175
 
176
  if runner.exit_code != 0:
@@ -187,27 +190,139 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
187
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
188
 
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  with gr.Blocks() as demo:
191
  gr.Markdown(MARKDOWN_DESCRIPTION)
192
 
193
- with gr.Row():
194
- filename = gr.Textbox(visible=False, label="filename")
195
- config = gr.Code(language="yaml", lines=10, label="config.yaml")
196
- with gr.Column():
197
- token = gr.Textbox(
198
- lines=1,
199
- label="HF Write Token",
200
- info="https://hf.co/settings/token",
201
- type="password",
202
- placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
203
- )
204
- repo_name = gr.Textbox(
205
- lines=1,
206
- label="Repo name",
207
- placeholder="Optional. Will create a random name if empty.",
208
- )
209
- button = gr.Button("Merge", variant="primary")
210
- logs = LogsView(label="Terminal output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  gr.Examples(
212
  examples,
213
  fn=lambda s: (s,),
@@ -218,11 +333,9 @@ with gr.Blocks() as demo:
218
  )
219
  gr.Markdown(MARKDOWN_ARTICLE)
220
 
221
- button.click(fn=merge, inputs=[config, token, repo_name], outputs=[logs])
222
-
223
 
224
  # Run garbage collection every hour to keep the community org clean.
225
- # Empty models might exists if the merge fails abruptly (e.g. if user leaves the Space).
226
  def _garbage_collect_every_hour():
227
  while True:
228
  try:
 
11
  import huggingface_hub
12
  import torch
13
  import yaml
14
+ import bitsandbytes
15
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
16
  from mergekit.config import MergeConfiguration
17
 
 
44
  # )
45
 
46
  cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
47
+ " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --lazy-unpickle"
48
  )
49
 
50
  MARKDOWN_DESCRIPTION = """
 
112
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
113
 
114
 
115
+ def merge(program: str, yaml_config: str, out_shard_size: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
116
  runner = LogsViewRunner()
117
 
118
  if not yaml_config:
119
  yield runner.log("Empty yaml, pick an example below", level="ERROR")
120
  return
121
+ # TODO: validate moe config and mega config?
122
+ if program not in ("mergekit-moe", "mergekit-mega"):
123
+ try:
124
+ merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
125
+ except Exception as e:
126
+ yield runner.log(f"Invalid yaml {e}", level="ERROR")
127
+ return
128
 
129
  is_community_model = False
130
  if not hf_token:
 
173
  # Set tmp HF_HOME to avoid filling up disk Space
174
  tmp_env = os.environ.copy() # taken from https://stackoverflow.com/a/4453495
175
  tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
176
+ full_cli = f"{program} {cli} --lora-merge-cache {tmpdirname}/.lora_cache --out-shard-size {out_shard_size}"
177
  yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
178
 
179
  if runner.exit_code != 0:
 
190
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
191
 
192
 
193
+ def extract(finetuned_model: str, base_model: str, rank: int, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
194
+ runner = LogsViewRunner()
195
+ if not finetuned_model or not base_model:
196
+ yield runner.log("All field should be filled")
197
+
198
+ is_community_model = False
199
+ if not hf_token:
200
+ if "/" in repo_name and not repo_name.startswith("mergekit-community/"):
201
+ yield runner.log(
202
+ f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
203
+ level="ERROR",
204
+ )
205
+ return
206
+ yield runner.log(
207
+ "No HF token provided. Your lora will be uploaded to the https://huggingface.co/mergekit-community organization."
208
+ )
209
+ is_community_model = True
210
+ if not COMMUNITY_HF_TOKEN:
211
+ raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
212
+ hf_token = COMMUNITY_HF_TOKEN
213
+
214
+ api = huggingface_hub.HfApi(token=hf_token)
215
+
216
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
217
+ tmpdir = pathlib.Path(tmpdirname)
218
+ merged_path = tmpdir / "merged"
219
+ merged_path.mkdir(parents=True, exist_ok=True)
220
+
221
+ if not repo_name:
222
+ yield runner.log("No repo name provided. Generating a random one.")
223
+ repo_name = "lora"
224
+ # Make repo_name "unique" (no need to be extra careful on uniqueness)
225
+ repo_name += "-" + "".join(random.choices(string.ascii_lowercase, k=7))
226
+ repo_name = repo_name.replace("/", "-").strip("-")
227
+
228
+ if is_community_model and not repo_name.startswith("mergekit-community/"):
229
+ repo_name = f"mergekit-community/{repo_name}"
230
+
231
+ try:
232
+ yield runner.log(f"Creating repo {repo_name}")
233
+ repo_url = api.create_repo(repo_name, exist_ok=True)
234
+ yield runner.log(f"Repo created: {repo_url}")
235
+ except Exception as e:
236
+ yield runner.log(f"Error creating repo {e}", level="ERROR")
237
+ return
238
+
239
+ # Set tmp HF_HOME to avoid filling up disk Space
240
+ tmp_env = os.environ.copy() # taken from https://stackoverflow.com/a/4453495
241
+ tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
242
+ full_cli = f"mergekit-extract-lora {finetuned_model} {base_model} lora --rank={rank}"
243
+ yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
244
+
245
+ if runner.exit_code != 0:
246
+ yield runner.log("Lora extraction failed. Deleting repo as no lora is uploaded.", level="ERROR")
247
+ api.delete_repo(repo_url.repo_id)
248
+ return
249
+
250
+ yield runner.log("Lora extracted successfully. Uploading to HF.")
251
+ yield from runner.run_python(
252
+ api.upload_folder,
253
+ repo_id=repo_url.repo_id,
254
+ folder_path=merged_path / "lora",
255
+ )
256
+ yield runner.log(f"Lora successfully uploaded to HF: {repo_url.repo_id}")
257
+
258
+
259
  with gr.Blocks() as demo:
260
  gr.Markdown(MARKDOWN_DESCRIPTION)
261
 
262
+ with gr.Tabs():
263
+ with gr.TabItem("Merge Model"):
264
+ with gr.Row():
265
+ filename = gr.Textbox(visible=False, label="filename")
266
+ config = gr.Code(language="yaml", lines=10, label="config.yaml")
267
+ with gr.Column():
268
+ program = gr.Dropdown(
269
+ ["mergekit-yaml", "mergekit-mega", "mergekit-moe"],
270
+ label="Mergekit Command",
271
+ info="Choose CLI",
272
+ )
273
+ out_shard_size = gr.Dropdown(
274
+ ["500M", "1B", "2B", "3B", "4B", "5B"],
275
+ label="Output Shard Size",
276
+ value="500M",
277
+ )
278
+ token = gr.Textbox(
279
+ lines=1,
280
+ label="HF Write Token",
281
+ info="https://hf.co/settings/token",
282
+ type="password",
283
+ placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
284
+ )
285
+ repo_name = gr.Textbox(
286
+ lines=1,
287
+ label="Repo name",
288
+ placeholder="Optional. Will create a random name if empty.",
289
+ )
290
+ button = gr.Button("Merge", variant="primary")
291
+ logs = LogsView(label="Terminal output")
292
+ button.click(fn=merge, inputs=[program, config, out_shard_size, token, repo_name], outputs=[logs])
293
+
294
+ with gr.TabItem("LORA Extraction"):
295
+ with gr.Row():
296
+ with gr.Column():
297
+ finetuned_model = gr.Textbox(
298
+ lines=1,
299
+ label="Finetuned Model",
300
+ )
301
+ base_model = gr.Textbox(
302
+ lines=1,
303
+ label="Base Model",
304
+ )
305
+ rank = gr.Dropdown(
306
+ [32, 64, 128],
307
+ label="Rank level",
308
+ value=32,
309
+ )
310
+ with gr.Column():
311
+ token = gr.Textbox(
312
+ lines=1,
313
+ label="HF Write Token",
314
+ info="https://hf.co/settings/token",
315
+ type="password",
316
+ placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
317
+ )
318
+ repo_name = gr.Textbox(
319
+ lines=1,
320
+ label="Repo name",
321
+ placeholder="Optional. Will create a random name if empty.",
322
+ )
323
+ button = gr.Button("Extract LORA", variant="primary")
324
+ logs = LogsView(label="Terminal output")
325
+ button.click(fn=extract, inputs=[finetuned_model, base_model, rank, token, repo_name], outputs=[logs])
326
  gr.Examples(
327
  examples,
328
  fn=lambda s: (s,),
 
333
  )
334
  gr.Markdown(MARKDOWN_ARTICLE)
335
 
 
 
336
 
337
  # Run garbage collection every hour to keep the community org clean.
338
+ # Empty models might exist if the merge fails abruptly (e.g. if user leaves the Space).
339
  def _garbage_collect_every_hour():
340
  while True:
341
  try:
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch
 
2
  git+https://github.com/arcee-ai/mergekit.git
3
  # see https://huggingface.co/spaces/Wauplin/gradio_logsview
4
  gradio_logsview@https://huggingface.co/spaces/Wauplin/gradio_logsview/resolve/main/gradio_logsview-0.0.5-py3-none-any.whl
 
1
  torch
2
+ bitsandbytes
3
  git+https://github.com/arcee-ai/mergekit.git
4
  # see https://huggingface.co/spaces/Wauplin/gradio_logsview
5
  gradio_logsview@https://huggingface.co/spaces/Wauplin/gradio_logsview/resolve/main/gradio_logsview-0.0.5-py3-none-any.whl