mergekit-gui / app.py
Hjgugugjhuhjggg's picture
Update app.py
234304a verified
raw
history blame
8.8 kB
import os
import pathlib
import random
import string
import tempfile
import time
import threading
from typing import Iterable, List
import gradio as gr
import huggingface_hub
import torch
import yaml
from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
from mergekit.config import MergeConfiguration
from huggingface_hub import spaces
# Whether torch reports a usable CUDA device; picks the flag set appended below.
has_gpu = torch.cuda.is_available()

# Base mergekit command line. It refers to "config.yaml" and "merge" by relative
# name, so it is meant to be executed from the directory that holds the config.
cli = "mergekit-yaml config.yaml merge --copy-tokenizer"
if has_gpu:
    cli += " --cuda --low-cpu-memory --allow-crimes"
else:
    cli += " --allow-crimes --out-shard-size 1B --lazy-unpickle"
# Markdown shown at the top of the page. Blank lines separate the title, the
# tagline and the instructions so each renders as its own paragraph.
MARKDOWN_DESCRIPTION = """
# mergekit-gui

The fastest way to perform a model merge 🔥

Specify a YAML configuration file (see examples below) and a HF token and this app will perform the merge and upload the merged model to your user profile.
"""
# Markdown reference section shown below the form: configuration fields, the
# supported merge methods and the citation. The BibTeX entry is fenced so it is
# rendered verbatim instead of being reflowed as a paragraph.
MARKDOWN_ARTICLE = """
___

## Merge Configuration

[Mergekit](https://github.com/arcee-ai/mergekit) configurations are YAML documents specifying the operations to perform in order to produce your merged model.
Below are the primary elements of a configuration file:

- `merge_method`: Specifies the method to use for merging models. See [Merge Methods](https://github.com/arcee-ai/mergekit#merge-methods) for a list.
- `slices`: Defines slices of layers from different models to be used. This field is mutually exclusive with `models`.
- `models`: Defines entire models to be used for merging. This field is mutually exclusive with `slices`.
- `base_model`: Specifies the base model used in some merging methods.
- `parameters`: Holds various parameters such as weights and densities, which can also be specified at different levels of the configuration.
- `dtype`: Specifies the data type used for the merging operation.
- `tokenizer_source`: Determines how to construct a tokenizer for the merged model.

## Merge Methods

A quick overview of the currently supported merge methods:

| Method | `merge_method` value | Multi-Model | Uses base model |
| -------------------------------------------------------------------------------------------- | -------------------- | ----------- | --------------- |
| Linear ([Model Soups](https://arxiv.org/abs/2203.05482)) | `linear` | ✅ | ❌ |
| SLERP | `slerp` | ❌ | ✅ |
| [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `task_arithmetic` | ✅ | ✅ |
| [TIES](https://arxiv.org/abs/2306.01708) | `ties` | ✅ | ✅ |
| [DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708) | `dare_ties` | ✅ | ✅ |
| [DARE](https://arxiv.org/abs/2311.03099) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `dare_linear` | ✅ | ✅ |
| Passthrough | `passthrough` | ❌ | ❌ |
| [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |

## Citation

This GUI is powered by [Arcee's MergeKit](https://arxiv.org/abs/2403.13257).
If you use it in your research, please cite the following paper:

```bibtex
@article{goddard2024arcee,
  title={Arcee's MergeKit: A Toolkit for Merging Large Language Models},
  author={Goddard, Charles and Siriwardhana, Shamane and Ehghaghi, Malikeh and Meyers, Luke and Karpukhin, Vlad and Benedict, Brian and McQuade, Mark and Solawetz, Jacob},
  journal={arXiv preprint arXiv:2403.13257},
  year={2024}
}
```

This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb)).
"""
# One single-element row per bundled example config (the shape passed to
# gr.Examples with a single input). Sorted because Path.glob() yields entries
# in arbitrary, filesystem-dependent order and the example list should be stable.
examples = [[str(path)] for path in sorted(pathlib.Path("examples").glob("*.yaml"))]

# Fallback token read from the environment; None when the variable is unset.
COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
def _default_repo_name(merge_method, profile_name, is_community_model):
    """Build a repo id for a merge whose repo name was left empty.

    The name part is ``mergekit-<merge_method>-<7 random lowercase letters>``
    (uniqueness is best-effort only).

    Args:
        merge_method: Merge method taken from the validated configuration.
        profile_name: Optional Hub profile name; may be empty.
        is_community_model: True when the upload goes to the community org.

    Returns:
        ``<profile_name>/<name>`` for uploads made with the user's own token,
        so the repo lands in that namespace. For community uploads the profile
        is folded into the name (``<profile_name>-<name>``) because the caller
        prepends the ``mergekit-community/`` namespace itself. Without a
        profile name, just ``<name>``.
    """
    suffix = "".join(random.choices(string.ascii_lowercase, k=7))
    name = f"mergekit-{merge_method}-{suffix}".replace("/", "-").strip("-")
    if not profile_name:
        return name
    if is_community_model:
        return f"{profile_name}-{name}".replace("/", "-").strip("-")
    return f"{profile_name}/{name}"


def merge_process(yaml_config, hf_token, repo_name, profile_name, logs_queue):
    """Validate the config, run the merge, upload the result, and report logs.

    Intended to run in a worker thread started by ``merge``. Every new log
    entry is appended to ``logs_queue``; the list is only ever appended to, so
    a reader may poll its length and copy it without extra locking.

    NOTE(review): this relies on ``LogsViewRunner.log``, ``run_command`` and
    ``run_python`` handing back the cumulative list of ``Log`` entries (the
    latter two as generators that must be iterated for the work to happen),
    as in gradio_logsview's documented usage -- confirm against the pinned
    version.

    Args:
        yaml_config: Merge configuration as YAML text.
        hf_token: Hub write token; when empty, the community token is used.
        repo_name: Target repo id; when empty, a random one is generated.
        profile_name: Optional Hub profile used as namespace for generated names.
        logs_queue: Mutable list that receives the log entries in order.

    Raises:
        gr.Error: No user token was given and ``COMMUNITY_HF_TOKEN`` is unset.
    """
    runner = LogsViewRunner()
    forwarded = 0  # number of runner log entries already copied to logs_queue

    def publish(logs):
        # The runner hands back its whole history; forward only the new tail.
        nonlocal forwarded
        fresh = list(logs[forwarded:])
        logs_queue.extend(fresh)
        forwarded += len(fresh)

    def log(message, level="INFO"):
        publish(runner.log(message, level=level))

    # Normalise optional text inputs so None and stray whitespace count as empty.
    hf_token = (hf_token or "").strip()
    repo_name = (repo_name or "").strip()
    profile_name = (profile_name or "").strip()

    if not yaml_config:
        log("Empty yaml, pick an example below", level="ERROR")
        return
    try:
        merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
    except Exception as e:
        log(f"Invalid yaml {e}", level="ERROR")
        return

    is_community_model = False
    if not hf_token:
        # Without a user token only the community namespace can be written to.
        if "/" in repo_name and not repo_name.startswith("mergekit-community/"):
            log(
                f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
                level="ERROR",
            )
            return
        log("No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization.")
        is_community_model = True
        if not COMMUNITY_HF_TOKEN:
            raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
        hf_token = COMMUNITY_HF_TOKEN

    api = huggingface_hub.HfApi(token=hf_token)

    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
        merged_path = pathlib.Path(tmpdirname) / "merged"
        merged_path.mkdir(parents=True, exist_ok=True)
        config_path = merged_path / "config.yaml"
        config_path.write_text(yaml_config, encoding="utf-8")
        log(f"Merge configuration saved in {config_path}")

        if not repo_name:
            log("No repo name provided. Generating a random one.")
            repo_name = _default_repo_name(merge_config.merge_method, profile_name, is_community_model)

        if is_community_model and not repo_name.startswith("mergekit-community/"):
            repo_name = f"mergekit-community/{repo_name}"

        try:
            log(f"Creating repo {repo_name}")
            repo_url = api.create_repo(repo_name, exist_ok=True)
            log(f"Repo created: {repo_url}")
        except Exception as e:
            log(f"Error creating repo {e}", level="ERROR")
            return

        # Keep the HF cache inside the temporary directory for this run.
        tmp_env = os.environ.copy()
        tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
        # Build the argv as a list so a temp path containing whitespace stays one argument.
        command = cli.split() + ["--lora-merge-cache", f"{tmpdirname}/.lora_cache"]
        for logs in runner.run_command(command, cwd=merged_path, env=tmp_env):
            publish(logs)

        if runner.exit_code != 0:
            log("Merge failed. Deleting repo as no model is uploaded.", level="ERROR")
            # NOTE(review): create_repo(exist_ok=True) also succeeds for a repo that
            # already existed, in which case this deletes that repo -- confirm intended.
            api.delete_repo(repo_url.repo_id)
            return

        log("Model merged successfully. Uploading to HF.")
        for logs in runner.run_python(api.upload_folder, repo_id=repo_url.repo_id, folder_path=merged_path / "merge"):
            publish(logs)
        log(f"Model successfully uploaded to HF: {repo_url.repo_id}")


def merge(yaml_config, hf_token, repo_name, profile_name):
    """Run ``merge_process`` in a background thread and stream its logs.

    Generator used as the click handler. Each yielded value is a fresh list
    holding every log entry produced so far, so the consumer can replace its
    content on each update without losing earlier lines.

    Args:
        yaml_config: Merge configuration as YAML text.
        hf_token: Hub write token, possibly empty.
        repo_name: Target repo id, possibly empty.
        profile_name: Hub profile name, possibly empty.

    Yields:
        The cumulative list of log entries whenever new entries have arrived.

    Raises:
        Exception: Whatever ``merge_process`` raised in the worker thread
            (for example ``gr.Error``), re-raised here once the worker has
            finished so the failure reaches the caller instead of being lost.
    """
    logs_queue = []
    failures = []

    def worker():
        # Exceptions do not cross thread boundaries by themselves; capture them.
        try:
            merge_process(yaml_config, hf_token, repo_name, profile_name, logs_queue)
        except Exception as exc:
            failures.append(exc)

    thread = threading.Thread(target=worker)
    thread.start()

    shown = 0  # number of entries included in the last yielded snapshot
    while thread.is_alive():
        if len(logs_queue) > shown:
            snapshot = logs_queue[:]
            shown = len(snapshot)
            yield snapshot
        time.sleep(0.1)
    thread.join()

    # Flush entries appended between the last poll and the worker's exit.
    if len(logs_queue) > shown:
        yield logs_queue[:]
    if failures:
        raise failures[0]
# Page layout: description, a row with the config editor next to the
# credential/name fields, the merge button, the log view, then examples and docs.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)

    with gr.Row():
        # Hidden field that receives the path of the example that was clicked.
        filename = gr.Textbox(label="filename", visible=False)
        config = gr.Code(label="config.yaml", language="yaml", lines=10)
        with gr.Column():
            token = gr.Textbox(
                label="HF Write Token",
                info="https://hf.co/settings/token",
                type="password",
                lines=1,
                placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
            )
            repo_name = gr.Textbox(
                label="Repo name",
                lines=1,
                placeholder="Optional. Will create a random name if empty.",
            )
            profile_name = gr.Textbox(
                label="Hugging Face Profile Name",
                lines=1,
                placeholder="Enter your Hugging Face profile name.",
            )

    button = gr.Button("Merge", variant="primary")
    logs = LogsView(label="Terminal output")

    # Clicking an example passes its path through as a 1-tuple to the code editor.
    # NOTE(review): presumably relies on gr.Code accepting a (filepath,) tuple and
    # loading that file -- confirm against the installed Gradio version.
    gr.Examples(
        examples,
        inputs=[filename],
        outputs=[config],
        fn=lambda example_path: (example_path,),
        run_on_click=True,
        label="Examples",
    )
    gr.Markdown(MARKDOWN_ARTICLE)

    button.click(
        fn=merge,
        inputs=[config, token, repo_name, profile_name],
        outputs=[logs],
    )
# NOTE(review): `spaces.GPU` decorates the launcher itself rather than a function
# doing GPU work; presumably it is here so the Space has a GPU-decorated
# function registered -- confirm that this is intended.
@spaces.GPU
def launch():
    """Start the Gradio app defined in ``demo`` with ``share=True``."""
    demo.launch(share=True)
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    launch()