# Spaces: Sleeping
import gradio as gr | |
import os | |
from huggingface_hub import HfApi | |
import subprocess | |
def create_config_yaml(
    model_name,
    model1,
    model1_layers,
    model2,
    model2_layers,
    merge_method,
    base_model,
    parameters,
    dtype,
):
    """Build a merge config as YAML, write it to ``config.yaml``, and return it.

    Args:
        model_name: Accepted so the signature matches the form fields; it is
            not written into the YAML.
        model1: Value emitted for the first source's ``model`` key.
        model1_layers: Value emitted for the first source's ``layer_range``
            key, e.g. ``[0, 32]``.
        model2: Value emitted for the second source's ``model`` key.
        model2_layers: Value emitted for the second source's ``layer_range``
            key.
        merge_method: Value emitted for the ``merge_method`` key.
        base_model: Value emitted for the ``base_model`` key.
        parameters: YAML snippet placed under the ``parameters`` key. It may
            span several lines; every line is indented so the whole snippet
            nests under the key. An empty value leaves the key without content.
        dtype: Value emitted for the ``dtype`` key.

    Returns:
        The YAML document as a string. It is returned even when writing
        ``config.yaml`` fails; the failure is printed instead of raised.
    """
    # Indent every line of the user-supplied snippet, not just the first one,
    # so multi-line parameters stay nested under ``parameters:``.
    parameter_lines = [f"  {line}" for line in str(parameters or "").splitlines()]

    # Nesting is expressed through indentation: each source is a list item
    # under ``sources``, and ``layer_range`` aligns with its sibling ``model``.
    lines = [
        "slices:",
        "  - sources:",
        f"      - model: {model1}",
        f"        layer_range: {model1_layers}",
        f"      - model: {model2}",
        f"        layer_range: {model2_layers}",
        f"merge_method: {merge_method}",
        f"base_model: {base_model}",
        "parameters:",
        *parameter_lines,
        f"dtype: {dtype}",
    ]
    yaml_config = "\n".join(lines) + "\n"

    print("Writing YAML config to 'config.yaml'...")
    try:
        with open("config.yaml", "w", encoding="utf-8") as f:
            f.write(yaml_config)
    except (OSError, UnicodeError) as e:
        # Best effort: the caller still receives the generated YAML.
        print(f"Error writing file: {e}")
    else:
        print("File 'config.yaml' written successfully.")
    return yaml_config
def execute_merge_command():
    """Run ``mergekit-yaml`` on ``config.yaml`` and report the outcome as text.

    The command is invoked with ``config.yaml`` and
    ``./output-model-directory`` as its two positional arguments.

    Returns:
        A string starting with ``Output:`` followed by the captured stdout
        when the command exits with status 0. Otherwise a string starting
        with ``Error:`` followed by the captured stderr, or by a description
        of why the command could not be started.
    """
    command = ["mergekit-yaml", "config.yaml", "./output-model-directory"]

    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as e:
        # Raised when the executable is missing or cannot be run; report it
        # in the same format as a failed run instead of letting it propagate.
        print("Error in executing command")
        return f"Error:\nCould not run '{command[0]}': {e}"

    if result.returncode == 0:
        print("Command executed successfully")
        return f"Output:\n{result.stdout}"

    print("Error in executing command")
    return f"Error:\n{result.stderr}"
# Function to push to HF Hub (for the third tab)
def push_to_hf_hub(model_name, yaml_config):
    """Create a model repo under ``arcee-ai`` and upload the YAML config to it.

    The config is written to ``merge/config.yaml`` and the ``merge`` folder
    is then uploaded to the repository.

    Args:
        model_name: Repository name; the repo id becomes
            ``arcee-ai/<model_name>``.
        yaml_config: Text written to ``merge/config.yaml`` before uploading.

    Returns:
        A status message: a success message containing the repo id, a hint
        when ``HF_TOKEN`` or the model name is missing, or the text of the
        exception raised while creating the repo, writing, or uploading.
    """
    # Username and API token setup
    username = "arcee-ai"
    api_token = os.getenv("HF_TOKEN")
    if api_token is None:
        return "Hugging Face API token not set. Please set the HF_TOKEN environment variable."

    model_name = (model_name or "").strip()
    if not model_name:
        # Without a name the repo id would be "arcee-ai/", which is invalid.
        return "Model name is empty. Please provide a model name before pushing."

    # Initialize HfApi with token
    api = HfApi(token=api_token)
    repo_id = f"{username}/{model_name}"
    upload_dir = "merge"

    try:
        # exist_ok lets a retry proceed after an earlier attempt created the
        # repository but failed before the upload finished.
        api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

        # The folder that gets uploaded must exist and must contain the
        # config, so write the file inside it rather than next to it.
        os.makedirs(upload_dir, exist_ok=True)
        config_path = os.path.join(upload_dir, "config.yaml")
        with open(config_path, "w", encoding="utf-8") as file:
            file.write(yaml_config)

        # Upload the contents of the 'merge' folder to the repository
        api.upload_folder(repo_id=repo_id, folder_path=upload_dir)
        return f"Successfully pushed to HF Hub: {repo_id}"
    except Exception as e:
        # UI boundary: surface the failure as text in the output box.
        return str(e)
def _create_and_stage_config(model_name, *config_fields):
    """Create the config, then return the model name and YAML for the push tab."""
    yaml_config = create_config_yaml(model_name, *config_fields)
    return model_name, yaml_config


# make sure to add the themes as well
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo")) as app:
    gr.Markdown("# Mergekit GUI")  # Title for your Gradio app

    with gr.Tab("Config YAML"):
        # Inputs for the YAML config
        with gr.Row():
            model_name_input = gr.Textbox(label="Model Name")
            model1_input = gr.Textbox(label="Model 1")
            model1_layers_input = gr.Textbox(
                label="Model 1 Layer Range", placeholder="[start, end]"
            )
            model2_input = gr.Textbox(label="Model 2")
            model2_layers_input = gr.Textbox(
                label="Model 2 Layer Range", placeholder="[start, end]"
            )
            merge_method_input = gr.Dropdown(
                label="Merge Method", choices=["slerp", "linear"]
            )
            base_model_input = gr.Textbox(label="Base Model")
            parameters_input = gr.Textbox(
                label="Parameters", placeholder="Formatted as a list of dicts"
            )
            dtype_input = gr.Textbox(label="Data Type", value="bfloat16")
        create_button = gr.Button("Create Config YAML")

    with gr.Tab("Merge"):
        # Placeholder for Merge tab contents
        # Not yet tested
        merge_output = gr.Textbox(label="Merge Output", interactive=False)
        merge_button = gr.Button("Execute Merge Command")
        merge_button.click(fn=execute_merge_command, inputs=[], outputs=merge_output)

    with gr.Tab("Push to HF Hub"):
        # Both fields are read-only; they are filled in by the create button.
        push_model_name_input = gr.Textbox(label="Model Name", interactive=False)
        push_yaml_config_input = gr.Textbox(label="YAML Config", interactive=False)
        push_output = gr.Textbox(label="Push Output", interactive=False)
        push_button = gr.Button("Push to HF Hub")
        push_button.click(
            fn=push_to_hf_hub,
            inputs=[push_model_name_input, push_yaml_config_input],
            outputs=push_output,
        )

    # Registered after the push tab exists so its read-only fields can be
    # used as outputs; otherwise the push handler only receives empty strings.
    create_button.click(
        fn=_create_and_stage_config,
        inputs=[
            model_name_input,
            model1_input,
            model1_layers_input,
            model2_input,
            model2_layers_input,
            merge_method_input,
            base_model_input,
            parameters_input,
            dtype_input,
        ],
        outputs=[push_model_name_input, push_yaml_config_input],
    )

app.launch()