Donmill committed
Commit b814ffe · verified · 1 Parent(s): ac2d4fb

Created app.py

Files changed (1)
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
import gradio as gr
import shutil
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import traceback

# --- Configuration for output paths ---
# This directory will store the quantized models temporarily on the Space
OUTPUT_DIR = Path("quantized_models_output")
OUTPUT_DIR.mkdir(exist_ok=True)
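# Note: on Hugging Face Spaces without persistent storage, this directory lives on
# ephemeral disk, so saved quantized models do not survive a Space restart.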

# --- The core quantization function ---
def quantize_model(model_id_or_path: str, quantization_level: str) -> str:
    """
    Loads an AI model (from the Hugging Face Hub or a local path), quantizes it
    at the requested level, and saves the quantized model.
    The quantized model directory is then zipped for easier download.

    Args:
        model_id_or_path: The Hugging Face model ID (e.g., "stabilityai/stablelm-zephyr-3b")
                          or a local path to a model directory (less common on HF Spaces,
                          but useful if you pre-upload models to the Space itself).
        quantization_level: String indicating the desired quantization level
                            (e.g., '8-bit (INT8)', '4-bit (INT4)').

    Returns:
        The path to the zipped quantized model directory, which the gr.File output
        component serves as a downloadable file.
    """
    if not model_id_or_path:
        raise gr.Error("Please provide a Hugging Face Model ID or a path to a local model directory.")

    print(f"[{model_id_or_path}] Attempting to quantize model.")
    print(f"[{model_id_or_path}] Desired quantization level: {quantization_level}")

    # Create a unique name for the saved quantized model directory
    safe_model_name = model_id_or_path.replace('/', '__').replace('\\', '__').replace('.', '_')
    quantized_model_base_name = f"quantized_{safe_model_name}_{quantization_level.replace(' ', '_').replace('(', '').replace(')', '')}"
    quantized_model_save_path = OUTPUT_DIR / quantized_model_base_name

    try:
        # Determine quantization configuration based on selection
        bnb_config = None
        if "8-bit" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for 8-bit quantization (LLM.int8()).")
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,  # Uses the bitsandbytes LLM.int8() scheme
            )
        elif "4-bit" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for 4-bit quantization (NF4).")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,  # Extra memory savings
                bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
            )
        elif "FP16" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for FP16 (half-precision).")
            # For FP16 we rely on `torch_dtype` during `from_pretrained`;
            # no BitsAndBytesConfig is needed for loading.
            pass
        else:
            raise gr.Error(f"Unsupported quantization level: {quantization_level}")
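
        # Rough rule of thumb (approximate; varies with architecture and overhead):
        # a 7B-parameter model takes on the order of ~14 GB in FP16, ~8 GB in INT8,
        # and ~4-5 GB in NF4, plus working memory while the model is being loaded.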

        # --- Load Model and Tokenizer ---
        print(f"[{model_id_or_path}] Loading model and tokenizer from: {model_id_or_path}...")

        # Determine the torch_dtype based on GPU availability and quantization level
        load_torch_dtype = torch.float32  # Default
        if torch.cuda.is_available():
            if "FP16" in quantization_level:
                load_torch_dtype = torch.float16
            elif bnb_config is not None:
                load_torch_dtype = torch.bfloat16  # Keep non-quantized modules in bfloat16 for 4-/8-bit

        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            quantization_config=bnb_config,  # None for FP16, set for 4-/8-bit
            device_map="auto",  # Automatically assigns layers to available devices (CPU/GPU)
            torch_dtype=load_torch_dtype,
            # trust_remote_code=True,  # Uncomment ONLY if you trust the model and it ships custom code
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
        print(f"[{model_id_or_path}] Model and tokenizer loaded successfully.")

        # --- Save the Quantized Model ---
        # First, clean up any previous output for this model/quantization combination
        if quantized_model_save_path.exists():
            print(f"[{model_id_or_path}] Cleaning up previous output directory: {quantized_model_save_path}")
            shutil.rmtree(quantized_model_save_path)

        model.save_pretrained(quantized_model_save_path)
        tokenizer.save_pretrained(quantized_model_save_path)
        print(f"[{model_id_or_path}] Quantized model and tokenizer saved to: {quantized_model_save_path}")
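
        # Note: serializing bitsandbytes-quantized weights via save_pretrained requires
        # reasonably recent transformers and bitsandbytes releases (4-bit serialization in
        # particular); on older versions this step may fail for quantized models.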

        # Zip the directory for easy download
        # shutil.make_archive automatically adds a .zip extension
        zip_file_path = shutil.make_archive(
            base_name=str(quantized_model_save_path),
            format='zip',
            root_dir=str(quantized_model_save_path),
        )
        print(f"[{model_id_or_path}] Quantized model zipped to: {zip_file_path}")

        # Return the path to the zipped file; the gr.File output component makes it downloadable
        return zip_file_path

    except Exception as e:
        print(f"[{model_id_or_path}] An error occurred during quantization: {e}")
        traceback.print_exc()  # Print the full traceback for debugging in the Space logs
        raise gr.Error(f"Quantization failed! Error: {e}. Check the Hugging Face Space logs for details. "
                       "Ensure the Space has a CUDA-enabled GPU for 8-/4-bit quantization "
                       "and that the model is compatible.")

# --- Gradio Interface Definition ---
iface = gr.Interface(
    fn=quantize_model,
    inputs=[
        gr.Textbox(label="Hugging Face Model ID (e.g., stabilityai/stablelm-zephyr-3b)",
                   placeholder="Enter a model ID from the Hugging Face Hub (e.g., meta-llama/Llama-2-7b-hf)"),
        gr.Dropdown(
            choices=["8-bit (INT8)", "4-bit (INT4)", "FP16 (Half-Precision)"],
            label="Select Quantization Level",
            value="8-bit (INT8)",  # Default selection
        ),
    ],
    outputs=gr.File(label="Quantized Model Download"),
    title="🌌 AI Model Shrinker: Quantize Your Models!",
    description=(
        "Enter a Hugging Face Model ID to quantize it and reduce its size and memory footprint. "
        "This can significantly improve inference speed and allow larger models to run on more modest hardware. "
        "<br><b>Important Notes:</b>"
        "<ul>"
        "<li><b>GPU Required:</b> 8-bit and 4-bit quantization (using <code>bitsandbytes</code>) require a <b>CUDA-enabled GPU</b>. Choose a GPU hardware tier for your Space.</li>"
        "<li><b>Compatibility:</b> Not all models work well after quantization, especially at 4-bit; quality may vary.</li>"
        "<li><b>Downloading:</b> The output is a <code>.zip</code> file containing the quantized model's directory.</li>"
        "<li><b>Experimental:</b> Treat this as an experimental tool for making large models more accessible.</li>"
        "</ul>"
    ),
    live=False,  # Live updates are not a good fit for a long-running job like quantization
    allow_flagging="manual",  # Lets users flag inputs/outputs, which is useful for debugging
)
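
# Note: quantizing a multi-billion-parameter model can take several minutes. Enabling
# Gradio's request queue (e.g. calling iface.queue() before launch) is a common way to
# avoid request timeouts for long-running jobs.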

# Launch the Gradio app
if __name__ == "__main__":
    # When running locally, share=True creates a public URL for easy sharing.
    # On Hugging Face Spaces, this is handled automatically.
    iface.launch(share=True)
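
For reference, a minimal sketch of how the downloaded archive might be used once extracted (the ./quantized_model path is a placeholder, and loading the 8-bit/4-bit variants still requires a CUDA GPU with bitsandbytes installed):

# Sketch only: load the quantized checkpoint produced above after unzipping the download.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "./quantized_model"  # placeholder for wherever the zip was extracted

model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_dir)

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))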