import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.modeling_utils import PreTrainedModel
from huggingface_hub import list_models, create_repo, HfApi
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import torch
from torch.nn.utils import prune


# Function to fetch open-weight LLM models from the Hugging Face Hub
def fetch_open_weight_models():
    models = list_models()
    return models


# Function to prune a model and push the result to the Hugging Face Hub
def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,  # Adjust dtype as needed
        )

        # Calculate the target number of parameters
        original_num_parameters = llm_model.num_parameters()
        target_num_parameters = int(original_num_parameters * (target_size / 100))

        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters)

        # Unstructured pruning zeroes weights rather than deleting them, so count
        # the remaining non-zero parameters to reflect the effective pruned size
        pruned_num_parameters = int(sum((p != 0).sum().item() for p in pruned_model.parameters()))

        # Save the pruned model to a Hugging Face repository under the token owner's namespace
        api = HfApi()
        username = api.whoami(token=hf_write_token)["name"]
        repo_id = f"{username}/{repo_name}"
        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_id, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_id, token=hf_write_token)

        # Create a visualization comparing original and pruned parameter counts
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(["Original", "Pruned"], [original_num_parameters, pruned_num_parameters])
        ax.set_ylabel("Number of Parameters")
        ax.set_title("Model Size Comparison")
        buf = BytesIO()
        fig.savefig(buf, format="png")
        plt.close(fig)
        buf.seek(0)
        comparison_image = Image.open(buf)  # gr.Image accepts a PIL image directly

        return f"Pruned model saved to the Hugging Face Hub in repository {repo_id}", comparison_image
    except Exception as e:
        return f"Error: {e}", None


# Pruning function (random unstructured pruning; swap in a merge-kit approach as needed)
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
    """Prunes a model by randomly zeroing weights in its linear and convolutional layers.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.

    Returns:
        PreTrainedModel: The pruned model.
""" # Define the pruning method pruning_method = "unstructured" # Calculate the pruning amount amount = 1 - (target_num_parameters / sum(p.numel() for p in model.parameters())) # Prune the model using the selected method for name, module in model.named_modules(): if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): prune.random_unstructured(module, name="weight", amount=amount) # Remove the pruned weights for name, module in model.named_modules(): if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): prune.remove(module, name="weight") return model # Function to create a Gradio interface def create_interface(): with gr.Blocks() as demo: gr.Markdown("## Create a Smaller LLM") # Input for model name llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True) # Input for target model size target_size = gr.Slider( label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True, ) # Input for Hugging Face write token hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password") # Input for repository name repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True) # Output for pruning status pruning_status = gr.Textbox(label="Pruning Status", interactive=False) # Button to start pruning prune_button = gr.Button("Prune Model") # Output for visualization visualization = gr.Image(label="Model Size Comparison", interactive=False) # Connect components prune_button.click( fn=prune_model, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization], ) # Example usage of the pruned model (optional) text_input = gr.Textbox(label="Input Text") text_output = gr.Textbox(label="Generated Text") # Generate text button generate_button = gr.Button("Generate Text") def generate_text(text, repo_name): try: # Load the pruned model and tokenizer tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token) model = AutoModelForCausalLM.from_pretrained(repo_name, use_auth_token=hf_write_token) # Use the pipeline for text generation generator = pipeline("text-generation", model=model, tokenizer=tokenizer) generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"] return generated_text except Exception as e: return f"Error: {e}" generate_button.click(fn=generate_text, inputs=[text_input, repo_name], outputs=text_output) return demo # Create and launch the Gradio interface demo = create_interface() demo.launch(share=True)