Spaces:
Running
Running
File size: 4,937 Bytes
08e5ef1 7edda8b 2bede7c 4c4c78d 5fd1a0a 7edda8b 2bede7c 75b770e 08e5ef1 aa85862 ac97e5b 08e5ef1 1fba392 925d15e 08e5ef1 2bede7c 859a065 ac97e5b eefa44d 925d15e 7686e09 aa85862 ac97e5b 859a065 ac97e5b 892a74e aa85862 7c36326 aa85862 5696fee eefa44d aa85862 ac97e5b eefa44d 9781999 eefa44d ac97e5b aa85862 9781999 5696fee 9781999 eefa44d aa85862 9781999 00dc59f 2bede7c 00dc59f 098f871 ec000c3 3ad22ce 4c4c78d 3ad22ce 4c4c78d 098f871 3ad22ce 098f871 4c4c78d 892a74e 3ad22ce 4c4c78d 3ad22ce 098f871 c360795 3ad22ce 2bede7c 925d15e 098f871 925d15e b31944c 925d15e 2bede7c c360795 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import os
import shutil
import subprocess
import signal
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
import gradio as gr
from huggingface_hub import create_repo, HfApi
from huggingface_hub import snapshot_download
from huggingface_hub import whoami
from huggingface_hub import ModelCard
from huggingface_hub import login
from huggingface_hub import scan_cache_dir
from huggingface_hub import logging
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from apscheduler.schedulers.background import BackgroundScheduler
from textwrap import dedent
import mlx_lm
from mlx_lm import convert
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
HF_TOKEN = os.environ.get("HF_TOKEN")
def clear_cache():
    """Delete every cached model revision from the local Hugging Face cache.

    Scans the cache with ``scan_cache_dir``, collects the commit hashes of all
    revisions belonging to repos of type ``"model"``, and executes the
    resulting deletion strategy. Repos of other types are left untouched.
    """
    scan = scan_cache_dir()
    # delete_revisions() takes commit hashes as separate string arguments, so
    # the hashes are gathered into one flat list (not a list per repo).
    to_delete = [
        rev.commit_hash
        for repo in scan.repos
        if repo.repo_type == "model"
        for rev in repo.revisions
    ]
    if to_delete:
        # delete_revisions() only builds a deletion plan; execute() is what
        # actually removes the files from disk.
        scan.delete_revisions(*to_delete).execute()
    print("Cache has been cleared")
def upload_to_hub(path, upload_repo, hf_path, token):
    """Write a model card into *path* and upload the folder to the Hub.

    The card of the source model is loaded, tagged with ``mlx``, given
    ``hf_path`` as its base model, and its text is replaced with usage
    instructions for the converted model. The card is saved as ``README.md``
    inside *path*, then the whole folder is uploaded to ``upload_repo``.

    Args:
        path: Local folder containing the converted model files.
        upload_repo: Destination repo id (``user/name``) on the Hub.
        hf_path: Repo id of the original model the conversion started from.
        token: Hub access token used to create the repo and upload the files.
    """
    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.data.base_model = hf_path
    # The template is indented with the code and flattened by dedent(); the
    # body of the ``if`` in the Python example keeps its relative indentation
    # so the published snippet is valid Python.
    card.text = dedent(
        f"""
        # {upload_repo}
        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{mlx_lm.__version__}**.
        ## Use with mlx
        ```bash
        pip install mlx-lm
        ```
        ```python
        from mlx_lm import load, generate
        model, tokenizer = load("{upload_repo}")
        prompt="hello"
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [{{"role": "user", "content": prompt}}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    # ``logging`` here is huggingface_hub.logging (see the file's imports).
    logging.set_verbosity_info()

    # Authenticate with the caller's token so the repo is created and written
    # with the caller's credentials rather than any ambient login.
    api = HfApi(token=token)
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    # NOTE(review): multi_commits / multi_commits_verbose are kept as in the
    # original call; confirm the pinned huggingface_hub version still accepts them.
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
    """Quantize *model_id* with MLX and publish it under the user's namespace.

    Args:
        model_id: Hub repo id of the source model (e.g. ``org/name``).
        q_method: Quantization choice from the UI, ``"Q4"`` or ``"Q8"``.
        oauth_token: OAuth token of the logged-in user, or ``None`` when
            nobody is logged in (presumably injected by Gradio based on the
            type annotation; confirm against the Gradio OAuth docs).

    Returns:
        A ``(message, image_path)`` tuple: an HTML link to the new repo with
        ``"llama.png"`` on success, or an error message with ``"error.png"``
        when conversion or upload fails.

    Raises:
        ValueError: If no user is logged in.
    """
    # Check the object itself before touching ``.token``: the annotation
    # allows ``None``.
    if oauth_token is None or oauth_token.token is None:
        raise ValueError("You must be logged in to use MLX-my-repo")

    model_name = model_id.split("/")[-1]
    username = whoami(oauth_token.token)["name"]
    mlx_path = "mlx_model"

    try:
        # Translate the dropdown value into the bit width passed to convert().
        bits_by_method = {"Q4": 4, "Q8": 8}
        if q_method not in bits_by_method:
            raise ValueError(f"Unsupported quantization method: {q_method}")

        target_repo = username + "/" + model_name + "-mlx"
        new_repo_url = f"https://huggingface.co/{target_repo}"

        convert(
            model_id,
            mlx_path=mlx_path,
            quantize=True,
            q_bits=bits_by_method[q_method],
        )
        upload_to_hub(
            path=mlx_path,
            upload_repo=target_repo,
            hf_path=model_id,
            token=oauth_token.token,
        )
        return (
            f'Find your repo <a href=\'{new_repo_url}\' target="_blank" style="text-decoration:underline">here</a>',
            "llama.png",
        )
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return (f"Error: {e}", "error.png")
    finally:
        # Always free disk space, whether the run succeeded or failed.
        shutil.rmtree(mlx_path, ignore_errors=True)
        clear_cache()
        print("Folder cleaned up successfully!")
# Page-level stylesheet: lets the Gradio container scroll vertically.
css = (
    "/* Custom CSS to allow scrolling */\n"
    ".gradio-container {overflow-y: auto;}\n"
)

# Assemble the UI: login prompt, model search box, quantization selector,
# and the Interface that wires them to process_model.
with gr.Blocks(css=css) as demo:
    gr.Markdown("You must be logged in to use MLX-my-repo.")
    gr.LoginButton(min_width=250)

    # Search box used to pick the source model on the Hub.
    model_id = HuggingfaceHubSearch(
        search_type="model",
        label="Hub Model ID",
        placeholder="Search for model id on Huggingface",
    )

    # Quantization selector; defaults to Q4.
    q_method = gr.Dropdown(
        choices=["Q4", "Q8"],
        value="Q4",
        label="Quantization Method",
        info="MLX quantization type",
        visible=True,
        filterable=False,
    )

    # Outputs are a status message followed by an illustrative image.
    iface = gr.Interface(
        title="Create your own MLX Quants, blazingly fast ⚡!",
        description="The space takes an HF repo as an input, quantizes it and creates a Public/ Private repo containing the selected quant under your HF user namespace.",
        fn=process_model,
        inputs=[model_id, q_method],
        outputs=[
            gr.Markdown(label="output"),
            gr.Image(show_label=False),
        ],
        api_name=False,
    )
def restart_space(repo_id: str | None = None):
    """Factory-reboot the Space that hosts this app.

    Args:
        repo_id: Space to restart. When omitted, the ``SPACE_ID`` environment
            variable is used if set, otherwise ``"reach-vb/mlx-my-repo"``.
    """
    # Falling back to SPACE_ID lets a duplicated Space restart itself rather
    # than the original one.
    target = repo_id or os.environ.get("SPACE_ID") or "reach-vb/mlx-my-repo"
    HfApi().restart_space(repo_id=target, token=HF_TOKEN, factory_reboot=True)
# Background job that calls restart_space every six hours.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", seconds=6 * 60 * 60)
scheduler.start()

# Serve the UI: one request processed at a time, at most five waiting.
demo.queue(max_size=5, default_concurrency_limit=1)
demo.launch(show_api=False, debug=True)