import gradio as gr
import shutil
import traceback
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- Configuration for output paths ---
# This directory temporarily stores the quantized models on the Space.
OUTPUT_DIR = Path("quantized_models_output")
OUTPUT_DIR.mkdir(exist_ok=True)

# --- The core quantization function ---
def quantize_model(model_id_or_path: str, quantization_level: str) -> str:
    """
    Loads an AI model (from the Hugging Face Hub or a local path), quantizes it
    at the requested level, and saves the quantized model. The output directory
    is then zipped for easier download.

    Args:
        model_id_or_path: The Hugging Face model ID (e.g., "stabilityai/stablelm-zephyr-3b")
            or a local path to a model directory (less common on HF Spaces, but useful
            if you pre-upload models to the Space itself).
        quantization_level: The desired quantization (e.g., '8-bit (INT8)', '4-bit (INT4)').

    Returns:
        The path to the zipped quantized model directory, which Gradio serves for download.
    """
    if not model_id_or_path:
        raise gr.Error("Please provide a Hugging Face Model ID or a path to a local model directory.")

    print(f"[{model_id_or_path}] Attempting to quantize model.")
    print(f"[{model_id_or_path}] Desired quantization level: {quantization_level}")

    # Create a unique, filesystem-safe name for the quantized model directory
    safe_model_name = model_id_or_path.replace('/', '__').replace('\\', '__').replace('.', '_')
    safe_level = quantization_level.replace(' ', '_').replace('(', '').replace(')', '')
    quantized_model_base_name = f"quantized_{safe_model_name}_{safe_level}"
    quantized_model_save_path = OUTPUT_DIR / quantized_model_base_name

    try:
        # Determine the quantization configuration based on the selection
        bnb_config = None
        if "8-bit" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for 8-bit quantization (LLM.int8()).")
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,  # bitsandbytes 8-bit loading uses the LLM.int8() scheme
            )
        elif "4-bit" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for 4-bit quantization (NF4).")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,  # Nested quantization for extra memory savings
                bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            )
        elif "FP16" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for FP16 (half precision).")
            # FP16 only needs `torch_dtype` in `from_pretrained`; no BitsAndBytesConfig is required.
        else:
            raise gr.Error(f"Unsupported quantization level: {quantization_level}")

        # --- Load Model and Tokenizer ---
        print(f"[{model_id_or_path}] Loading model and tokenizer from: {model_id_or_path}...")

        # Pick the torch_dtype based on GPU availability and quantization level
        load_torch_dtype = torch.float32  # Default on CPU
        if torch.cuda.is_available():
            if "FP16" in quantization_level:
                load_torch_dtype = torch.float16
            elif bnb_config is not None:
                load_torch_dtype = torch.bfloat16  # dtype for the non-quantized modules in 4/8-bit loads

        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            quantization_config=bnb_config,  # None for FP16; set for 4/8-bit
            device_map="auto",               # Automatically places layers on available devices (CPU/GPU)
            torch_dtype=load_torch_dtype,
            # trust_remote_code=True  # Uncomment ONLY if you trust the model and it ships custom code
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
        print(f"[{model_id_or_path}] Model and tokenizer loaded successfully.")

        # --- Save the Quantized Model ---
        # Clean up any previous output for this model/level combination
        if quantized_model_save_path.exists():
            print(f"[{model_id_or_path}] Cleaning up previous output directory: {quantized_model_save_path}")
            shutil.rmtree(quantized_model_save_path)

        model.save_pretrained(quantized_model_save_path)
        tokenizer.save_pretrained(quantized_model_save_path)
        print(f"[{model_id_or_path}] Quantized model and tokenizer saved to: {quantized_model_save_path}")

        # Zip the directory for easy download.
        # shutil.make_archive appends the .zip extension automatically.
        zip_file_path = shutil.make_archive(
            base_name=str(quantized_model_save_path),
            format="zip",
            root_dir=str(quantized_model_save_path),
        )
        print(f"[{model_id_or_path}] Quantized model zipped to: {zip_file_path}")

        # Return the path to the zip; the gr.File output component makes it downloadable.
        return zip_file_path

    except Exception as e:
        print(f"[{model_id_or_path}] An error occurred during quantization: {e}")
        traceback.print_exc()  # Full traceback in the Space logs for debugging
        raise gr.Error(
            f"Quantization failed! Error: {e}. Check the Hugging Face Space logs for details. "
            "Ensure you have a CUDA-enabled GPU for 8/4-bit quantization, "
            "and that the model is compatible."
        )

# --- Gradio Interface Definition ---
iface = gr.Interface(
    fn=quantize_model,
    inputs=[
        gr.Textbox(
            label="Hugging Face Model ID (e.g., stabilityai/stablelm-zephyr-3b)",
            placeholder="Enter a model ID from the Hugging Face Hub (e.g., meta-llama/Llama-2-7b-hf)",
        ),
        gr.Dropdown(
            choices=["8-bit (INT8)", "4-bit (INT4)", "FP16 (Half-Precision)"],
            label="Select Quantization Level",
            value="8-bit (INT8)",  # Default selection
        ),
    ],
    outputs=gr.File(label="Quantized Model Download"),
    title="🌌 AI Model Shrinker: Quantize Your Models!",
    description=(
        "Enter a Hugging Face Model ID to quantize it and reduce its size and memory footprint. "
        "This can significantly improve inference speed and lets larger models run on more modest hardware. "
        "<br><b>Important Notes:</b>"
        "<ul>"
        "<li><b>GPU Required:</b> 8-bit and 4-bit quantization (via <code>bitsandbytes</code>) require a <b>CUDA-enabled GPU</b>. Choose a GPU hardware tier for your Space.</li>"
        "<li><b>Compatibility:</b> Not all models work perfectly after quantization, especially at 4-bit; quality may vary.</li>"
        "<li><b>Downloading:</b> The output is a <code>.zip</code> file containing the quantized model's directory.</li>"
        "<li><b>Experimental:</b> Embrace the experimental spirit! This tool pushes boundaries in AI accessibility.</li>"
        "</ul>"
    ),
    live=False,  # Live updates are not ideal for long-running jobs
    allow_flagging="manual",  # Lets users flag inputs/outputs, useful for debugging
)
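
# Illustrative note (not executed by the app): once a user downloads and unzips the
# archive produced above, the quantized model can typically be reloaded with the
# standard transformers API. The directory name below is only an example, and
# reloading 8-bit/4-bit checkpoints assumes bitsandbytes and a CUDA GPU are available.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   path = "quantized_stabilityai__stablelm-zephyr-3b_8-bit_INT8"  # example output folder
#   model = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
#   tokenizer = AutoTokenizer.from_pretrained(path)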

# Launch the Gradio app
if __name__ == "__main__":
    # share=True creates a public URL when running locally;
    # on Hugging Face Spaces the public URL is handled automatically.
    iface.launch(share=True)
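
# Rough local-usage sketch (assumption: the packages below are the ones the code above
# relies on, and a CUDA GPU is present for the 8-bit/4-bit paths):
#
#   pip install gradio torch transformers accelerate bitsandbytes
#   python app.py
#
# The core function can also be exercised without the UI, e.g.
#   zip_path = quantize_model("facebook/opt-125m", "4-bit (INT4)")  # example model ID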