import os
import tempfile
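# Keep the HF download cache inside the Space's working directory and disable
# Gradio telemetry; both must be set before the libraries that read them are imported.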
os.environ["HF_HUB_CACHE"] = "cache"
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
import gradio as gr
from huggingface_hub import HfApi
from huggingface_hub import whoami
from huggingface_hub import ModelCard
from huggingface_hub import scan_cache_dir
from huggingface_hub import logging
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from apscheduler.schedulers.background import BackgroundScheduler
from textwrap import dedent
import mlx_lm
from mlx_lm import convert
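# Secrets/metadata injected by the Space environment; used by the scheduled restart below.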
HF_TOKEN = os.environ.get("HF_TOKEN")
SPACE_ID = os.environ.get("SPACE_ID")
# Bit widths accepted by mlx_lm.convert's q_bits argument; must cover every
# quantized option offered in the dropdown below.
QUANT_PARAMS = {
    "Q2": 2,
    "Q3": 3,
    "Q4": 4,
    "Q6": 6,
    "Q8": 8,
}
def list_files_in_folder(folder_path):
# List all files and directories in the specified folder
all_items = os.listdir(folder_path)
# Filter out only files
files = [item for item in all_items if os.path.isfile(os.path.join(folder_path, item))]
return files
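# Delete every cached model revision so repeated conversions don't fill the Space's disk.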
def clear_hf_cache_space():
scan = scan_cache_dir()
to_delete = []
for repo in scan.repos:
if repo.repo_type == "model":
to_delete.extend([rev.commit_hash for rev in repo.revisions])
scan.delete_revisions(*to_delete).execute()
print("Cache has been cleared")
def upload_to_hub(path, upload_repo, hf_path, token):
card = ModelCard.load(hf_path, token=token)
    card.data.tags = ["mlx", "mlx-my-repo"] if card.data.tags is None else card.data.tags + ["mlx", "mlx-my-repo"]
card.data.base_model = hf_path
card.text = dedent(
f"""
# {upload_repo}
The model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{mlx_lm.__version__}**.
## Use with mlx
```bash
pip install mlx-lm
```
```python
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
prompt="hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
"""
)
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi(token=token)
api.create_repo(repo_id=upload_repo, exist_ok=True)
files = list_files_in_folder(path)
print(files)
for file in files:
file_path = os.path.join(path, file)
print(f"Uploading file: {file_path}")
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=file,
repo_id=upload_repo,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
if oauth_token is None or oauth_token.token is None:
raise gr.Error("You must be logged in to use MLX-my-repo")
token = oauth_token.token
# Verify the token
username = None
try:
user_info = whoami(token=token)
username = user_info["name"]
print(f"✅ Logged in as {username}")
except Exception as e:
raise gr.Error(f"❌ Authentication failed: {e}")
try:
model_name = model_id.split('/')[-1]
        if q_method == "FP16":
            repo_name = f"{model_name}-fp16"
        else:
            q_bits = QUANT_PARAMS[q_method]
            repo_name = f"{model_name}-{q_bits}bit"
upload_repo = f"mlx-community/{repo_name}"
        # TemporaryDirectory requires its parent directory to exist
        os.makedirs("converted", exist_ok=True)
        with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
# The target directory must not exist
mlx_path = os.path.join(tmpdir, "mlx")
if q_method == "FP16":
convert(model_id, mlx_path=mlx_path, quantize=False, dtype="float16")
else:
convert(model_id, mlx_path=mlx_path, quantize=True, q_bits=q_bits)
print("Conversion done")
upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, token=token)
print("Upload done")
return (
f'Find your repo <a href="https://hf.co/{upload_repo}" target="_blank" style="text-decoration:underline">here</a>',
"llama.png",
)
except Exception as e:
raise gr.Error(f"❌ Error: {e}")
finally:
clear_hf_cache_space()
print("Folder cleaned up successfully!")
css="""/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
"""
# Create Gradio interface
with gr.Blocks(css=css) as demo:
gr.Markdown("You must be logged in to use MLX-my-repo.")
gr.LoginButton()
model_id = HuggingfaceHubSearch(
label="Hub Model ID",
placeholder="Search for model id on Huggingface",
search_type="model",
)
q_method = gr.Dropdown(
["FP16", "Q2", "Q3", "Q4", "Q6", "Q8"],
label="Conversion Method",
info="MLX conversion type (FP16 for float16, Q2–Q8 for quantized models)",
value="Q4",
filterable=False,
visible=True
)
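    # Wire the two inputs above into the conversion callback.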
iface = gr.Interface(
fn=process_model,
inputs=[
model_id,
q_method,
],
outputs=[
gr.Markdown(label="output"),
gr.Image(show_label=False),
],
title="Create your own MLX Models, blazingly fast ⚡!",
description="The space takes an HF repo as an input, converts it to MLX format (FP16 or quantized), and creates a Public/Private repo under your HF user namespace.",
api_name=False
)
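# Factory-reboot the Space on a schedule to reclaim any leftover disk and memory.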
def restart_space():
HfApi().restart_space(repo_id=SPACE_ID, token=HF_TOKEN, factory_reboot=True)
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", seconds=21600)  # restart every 6 hours
scheduler.start()
# Launch the interface
demo.launch()