import gradio as gr
import os
import shutil
import zipfile
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import traceback
# --- Configuration for output paths ---
# This directory will store the quantized models temporarily on the Space
OUTPUT_DIR = Path("quantized_models_output")
OUTPUT_DIR.mkdir(exist_ok=True)
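# Note: on a standard Hugging Face Space this directory lives on ephemeral storage,
# so generated ZIP files are lost when the Space restarts unless persistent storage is enabled.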
# --- The core quantization function ---
def quantize_model(model_id_or_path: str, quantization_level: str) -> str:
"""
Loads an AI model (from Hugging Face Hub or local path), quantizes it
based on the specified level, and saves the quantized model.
The quantized model directory is then zipped for easier download.
Args:
model_id_or_path: The Hugging Face model ID (e.g., "stabilityai/stablelm-zephyr-3b")
or a local path to a model directory (less common for HF Spaces,
but useful if you pre-upload models to the Space itself).
quantization_level: String indicating the desired quantization (e.g., '8-bit (INT8)', '4-bit (INT4)').
Returns:
        The filesystem path (str) to a ZIP archive of the saved quantized model directory.
"""
if not model_id_or_path:
raise gr.Error("Please provide a Hugging Face Model ID or a path to a local model directory.")
print(f"[{model_id_or_path}] Attempting to quantize model.")
print(f"[{model_id_or_path}] Desired quantization level: {quantization_level}")
# Create a unique name for the saved quantized model directory
safe_model_name = model_id_or_path.replace('/', '__').replace('\\', '__').replace('.', '_')
quantized_model_base_name = f"quantized_{safe_model_name}_{quantization_level.replace(' ', '_').replace('(', '').replace(')', '')}"
quantized_model_save_path = OUTPUT_DIR / quantized_model_base_name
try:
# Determine quantization configuration based on selection
bnb_config = None
if "8-bit" in quantization_level:
print(f"[{model_id_or_path}] Configuring for 8-bit quantization (NF8).")
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
bnb_8bit_quant_type="nf8", # Default for 8-bit
bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
)
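            # Optional knobs such as llm_int8_threshold exist on BitsAndBytesConfig,
            # but the defaults are reasonable for most models.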
elif "4-bit" in quantization_level:
print(f"[{model_id_or_path}] Configuring for 4-bit quantization (NF4).")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True, # More memory savings
bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
)
elif "FP16" in quantization_level:
print(f"[{model_id_or_path}] Configuring for FP16 (Half-Precision).")
# For FP16, we mainly rely on `torch_dtype` during `from_pretrained`
# and no BitsAndBytesConfig is directly needed for loading
pass # No bnb_config needed for direct FP16 load
else:
raise gr.Error(f"Unsupported quantization level: {quantization_level}")
# --- Load Model and Tokenizer ---
print(f"[{model_id_or_path}] Loading model and tokenizer from: {model_id_or_path}...")
# Determine the torch_dtype based on GPU availability and quantization level
load_torch_dtype = torch.float32 # Default
if torch.cuda.is_available():
if "FP16" in quantization_level:
load_torch_dtype = torch.float16
            elif bnb_config is not None:
                # Keep non-quantized modules in bfloat16 when loading 4-bit/8-bit models
                load_torch_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
model_id_or_path,
quantization_config=bnb_config, # Will be None for FP16, used for 4/8-bit
device_map="auto", # Automatically assigns layers to available devices (CPU/GPU)
torch_dtype=load_torch_dtype,
# trust_remote_code=True # Uncomment ONLY if you trust the model and it has custom code
)
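        # device_map="auto" requires the `accelerate` package to be installed in the Space environment.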
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
print(f"[{model_id_or_path}] Model and Tokenizer loaded successfully.")
# --- Save the Quantized Model ---
# First, clean up any previous runs of this specific model's quantized output
if quantized_model_save_path.exists():
print(f"[{model_id_or_path}] Cleaning up previous output directory: {quantized_model_save_path}")
shutil.rmtree(quantized_model_save_path)
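        # Note: serializing bitsandbytes-quantized (4-bit/8-bit) weights with save_pretrained
        # requires reasonably recent transformers/bitsandbytes releases; older versions raise an error here.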
model.save_pretrained(quantized_model_save_path)
tokenizer.save_pretrained(quantized_model_save_path)
print(f"[{model_id_or_path}] Quantized model and tokenizer saved to: {quantized_model_save_path}")
# Zip the directory for easy download
# shutil.make_archive automatically adds a .zip extension
zip_file_path = shutil.make_archive(
base_name=str(quantized_model_save_path),
format='zip',
root_dir=str(quantized_model_save_path)
)
print(f"[{model_id_or_path}] Quantized model zipped to: {zip_file_path}")
        # Return the path to the ZIP file; the gr.File output component serves it for download
        return zip_file_path
except Exception as e:
print(f"[{model_id_or_path}] An error occurred during quantization: {e}")
traceback.print_exc() # Print full traceback for debugging in the Space logs
raise gr.Error(f"Quantization failed! Error: {e}. Check the Hugging Face Space logs for details. "
"Ensure you have a CUDA-enabled GPU for 8/4-bit quantization, "
"and that the model is compatible.")
# --- Gradio Interface Definition ---
iface = gr.Interface(
fn=quantize_model,
inputs=[
gr.Textbox(label="Hugging Face Model ID (e.g., stabilityai/stablelm-zephyr-3b)",
placeholder="Enter a model ID from Hugging Face Hub (e.g., meta-llama/Llama-2-7b-hf)"),
gr.Dropdown(
choices=["8-bit (INT8)", "4-bit (INT4)", "FP16 (Half-Precision)"],
label="Select Quantization Level",
value="8-bit (INT8)" # Default selection
)
],
outputs=gr.File(label="Quantized Model Download"),
title="🌌 AI Model Shrinker: Quantize Your Models!",
description=(
"Enter a Hugging Face Model ID to effortlessly quantize it and reduce its size and memory footprint. "
"This can significantly improve inference speed and allow larger models to run on more modest hardware. "
"
Important Notes:"
"