oobabooga committed · Commit 6441476 · verified · 1 Parent(s): 8f3a4a9

Create app.py

Files changed (1): app.py (+432, -0)
app.py ADDED
@@ -0,0 +1,432 @@
import io
import re
import struct
from enum import IntEnum
from math import exp

import gradio as gr
import requests


class GGUFValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12


# struct format strings for the fixed-size GGUF value types (little-endian)
_simple_value_packing = {
    GGUFValueType.UINT8: "<B",
    GGUFValueType.INT8: "<b",
    GGUFValueType.UINT16: "<H",
    GGUFValueType.INT16: "<h",
    GGUFValueType.UINT32: "<I",
    GGUFValueType.INT32: "<i",
    GGUFValueType.FLOAT32: "<f",
    GGUFValueType.UINT64: "<Q",
    GGUFValueType.INT64: "<q",
    GGUFValueType.FLOAT64: "<d",
    GGUFValueType.BOOL: "?",
}

# size in bytes of each fixed-size GGUF value type
value_type_info = {
    GGUFValueType.UINT8: 1,
    GGUFValueType.INT8: 1,
    GGUFValueType.UINT16: 2,
    GGUFValueType.INT16: 2,
    GGUFValueType.UINT32: 4,
    GGUFValueType.INT32: 4,
    GGUFValueType.FLOAT32: 4,
    GGUFValueType.UINT64: 8,
    GGUFValueType.INT64: 8,
    GGUFValueType.FLOAT64: 8,
    GGUFValueType.BOOL: 1,
}


def get_single(value_type, file):
    if value_type == GGUFValueType.STRING:
        value_length = struct.unpack("<Q", file.read(8))[0]
        value = file.read(value_length)
        try:
            value = value.decode('utf-8')
        except UnicodeDecodeError:
            pass  # keep the raw bytes if the string is not valid UTF-8
    else:
        type_str = _simple_value_packing.get(value_type)
        bytes_length = value_type_info.get(value_type)
        value = struct.unpack(type_str, file.read(bytes_length))[0]

    return value


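# Example: get_single reads one value at the file's current offset, e.g.
#   buf = io.BytesIO(struct.pack("<I", 4096))
#   get_single(GGUFValueType.UINT32, buf)  # -> 4096

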
def load_metadata_from_file(file_obj):
    """Load metadata from a file-like object"""
    metadata = {}

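    # GGUF header layout (v2/v3, little-endian): uint32 magic (0x46554747,
    # "GGUF"), uint32 version, uint64 tensor count, uint64 key/value count.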
    GGUF_MAGIC = struct.unpack("<I", file_obj.read(4))[0]
    GGUF_VERSION = struct.unpack("<I", file_obj.read(4))[0]
    ti_data_count = struct.unpack("<Q", file_obj.read(8))[0]
    kv_data_count = struct.unpack("<Q", file_obj.read(8))[0]

    if GGUF_VERSION == 1:
        raise Exception('You are using an outdated GGUF, please download a new one.')

    for _ in range(kv_data_count):
        key_length = struct.unpack("<Q", file_obj.read(8))[0]
        key = file_obj.read(key_length)

        value_type = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
        if value_type == GGUFValueType.ARRAY:
            ltype = GGUFValueType(struct.unpack("<I", file_obj.read(4))[0])
            length = struct.unpack("<Q", file_obj.read(8))[0]

            arr = [get_single(ltype, file_obj) for _ in range(length)]
            metadata[key.decode()] = arr
        else:
            value = get_single(value_type, file_obj)
            metadata[key.decode()] = value

    # Extract specific fields needed for VRAM calculation
    extracted_fields = {}
    for key, value in metadata.items():
        if key.endswith('.block_count'):
            extracted_fields['n_layers'] = value
        elif key.endswith('.attention.head_count_kv'):
            extracted_fields['n_kv_heads'] = value
        elif key.endswith('.embedding_length'):
            extracted_fields['embedding_dim'] = value
        elif key.endswith('.context_length'):
            extracted_fields['context_length'] = value
        elif key.endswith('.feed_forward_length'):
            extracted_fields['feed_forward_dim'] = value

    # Add extracted fields to metadata for easy access
    metadata.update(extracted_fields)
    return metadata


def download_gguf_partial(url, max_bytes=25 * 1024 * 1024):
    """Download the first max_bytes from a GGUF URL"""
    try:
        # Set up headers for partial content request
        headers = {'Range': f'bytes=0-{max_bytes - 1}'}

        # Make the request (timeout added so a stalled server can't hang the app)
        response = requests.get(url, headers=headers, stream=True, timeout=30)
        response.raise_for_status()

        # Read the content
        content = response.content

        # Convert to BytesIO for file-like interface
        return io.BytesIO(content)

    except Exception as e:
        raise Exception(f"Failed to download GGUF file: {str(e)}")


def load_metadata(model_url, current_metadata):
    """Load metadata from model URL and return updated metadata dict"""
    if not model_url or model_url.strip() == "":
        # Three outputs are expected: metadata state, slider update, status text
        return {}, gr.update(), "Please enter a model URL"

    try:
        # Get model size first
        model_size_mb = get_model_size_mb_from_url(model_url)

        # Normalize URL for downloading
        normalized_url = normalize_huggingface_url(model_url)

        # Download the first 25MB of the file
        file_obj = download_gguf_partial(normalized_url)

        # Parse the metadata
        metadata = load_metadata_from_file(file_obj)

        # Extract model name from URL if it's a Hugging Face URL
        model_name = model_url
        if "huggingface.co/" in model_url:
            try:
                # Extract model name from URL like https://huggingface.co/user/model
                parts = model_url.split("huggingface.co/")[1].split("/")
                if len(parts) >= 2:
                    model_name = f"{parts[0]}/{parts[1]}"
            except IndexError:
                model_name = model_url

        # Add URL, model name, and size to metadata
        metadata['url'] = model_url
        metadata['model_name'] = model_name
        metadata['model_size_mb'] = model_size_mb
        metadata['loaded'] = True

        return metadata, gr.update(value=metadata["n_layers"], maximum=metadata["n_layers"]), f"Metadata loaded successfully for: {model_name}"

    except Exception as e:
        error_msg = f"Error loading metadata: {str(e)}"
        return {}, gr.update(), error_msg


def normalize_huggingface_url(url: str) -> str:
    """Normalize HuggingFace URL to resolve format for direct access"""
    if 'huggingface.co' not in url:
        return url

    # Remove query parameters first
    base_url = url.split('?')[0]

    # Convert blob URL to resolve URL
    if '/blob/' in base_url:
        base_url = base_url.replace('/blob/', '/resolve/')

    return base_url


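# Example: a "blob" page URL becomes a direct download URL, e.g.
#   normalize_huggingface_url("https://huggingface.co/user/model/blob/main/model.gguf")
#   -> "https://huggingface.co/user/model/resolve/main/model.gguf"

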
def get_model_size_mb_from_url(model_url: str) -> float:
    """Get model size in MB from URL without downloading, handling multi-part files"""
    try:
        # Normalize the URL for direct access
        normalized_url = normalize_huggingface_url(model_url)

        # Get size of the main file
        response = requests.head(normalized_url, allow_redirects=True, timeout=30)
        response.raise_for_status()
        main_file_size = int(response.headers.get('content-length', 0))

        # Extract filename from original URL
        filename = normalized_url.split('/')[-1]

        # Check for multipart pattern (e.g., model-00001-of-00002.gguf)
        match = re.match(r'(.+)-(\d+)-of-(\d+)\.gguf$', filename)

        if match:
            base_pattern = match.group(1)
            total_parts = int(match.group(3))

            total_size = 0
            base_url = '/'.join(normalized_url.split('/')[:-1]) + '/'

            # Get size of all parts
            for part_num in range(1, total_parts + 1):
                part_filename = f"{base_pattern}-{part_num:05d}-of-{total_parts:05d}.gguf"
                part_url = base_url + part_filename

                try:
                    part_response = requests.head(part_url, allow_redirects=True, timeout=30)
                    part_response.raise_for_status()
                    part_size = int(part_response.headers.get('content-length', 0))
                    total_size += part_size
                except requests.RequestException:
                    print(f"Warning: Could not get size of {part_filename}, estimating...")
                    # If we can't get some parts, estimate based on what we have
                    if total_size > 0:
                        avg_size = total_size / (part_num - 1)
                        remaining_parts = total_parts - (part_num - 1)
                        total_size += avg_size * remaining_parts
                    else:
                        # Fall back to main file size * total parts
                        total_size = main_file_size * total_parts
                    break

            return total_size / (1024 ** 2)
        else:
            # Single part file
            return main_file_size / (1024 ** 2)

    except Exception as e:
        print(f"Error getting model size: {e}")
        return 0.0


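# The constants in estimate_vram() below come from a formula fitted to
# measured VRAM usage; see the blog post linked in the UI for the derivation:
# https://oobabooga.github.io/blog/posts/gguf-vram-formula/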
def estimate_vram(metadata, gpu_layers, ctx_size, cache_type):
    """Calculate VRAM usage using the fitted formula"""
    try:
        # Extract required values from metadata
        n_layers = metadata.get('n_layers')
        n_kv_heads = metadata.get('n_kv_heads')
        embedding_dim = metadata.get('embedding_dim')
        context_length = metadata.get('context_length')
        feed_forward_dim = metadata.get('feed_forward_dim')
        size_in_mb = metadata.get('model_size_mb', 0)

        # Check if we have all required fields
        required_fields = [n_layers, n_kv_heads, embedding_dim, context_length, feed_forward_dim]
        if any(field is None for field in required_fields):
            missing = [name for name, field in zip(
                ['n_layers', 'n_kv_heads', 'embedding_dim', 'context_length', 'feed_forward_dim'],
                required_fields) if field is None]
            raise ValueError(f"Missing required metadata fields: {missing}")

        # Ensure gpu_layers doesn't exceed total layers
        if gpu_layers > n_layers:
            gpu_layers = n_layers

        # Convert cache_type to bits per element
        cache_type_map = {'fp16': 16, 'q8_0': 8, 'q4_0': 4}
        cache_type_numeric = cache_type_map.get(cache_type, 16)

        # Derived features
        size_per_layer = size_in_mb / max(n_layers, 1e-6)
        context_per_layer = context_length / max(n_layers, 1e-6)
        ffn_per_embedding = feed_forward_dim / max(embedding_dim, 1e-6)
        kv_cache_factor = n_kv_heads * cache_type_numeric * ctx_size

        # Indicator: 1 if x < y, else 0
        def smaller(x, y):
            return 1 if x < y else 0

        # Calculate VRAM using the fitted model
        vram = (
            (size_per_layer - 21.19195204848197)
            * exp(0.0001047328491557063 * size_in_mb * smaller(ffn_per_embedding, 2.671096993407845))
            + 0.0006621544775632052 * context_per_layer
            + 3.34664386576376e-05 * kv_cache_factor
        ) * (1.363306170123392 + gpu_layers) + 1255.163594536052

        return max(0, vram)  # Ensure non-negative result

    except Exception as e:
        print(f"Error in VRAM calculation: {e}")
        raise


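# Quick sanity check for estimate_vram() (hypothetical metadata values, not
# taken from any real model):
#   meta = {'n_layers': 32, 'n_kv_heads': 8, 'embedding_dim': 4096,
#           'context_length': 131072, 'feed_forward_dim': 14336,
#           'model_size_mb': 4500}
#   estimate_vram(meta, gpu_layers=32, ctx_size=8192, cache_type='fp16')

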
def estimate_vram_wrapper(model_metadata, gpu_layers, ctx_size, cache_type):
    """Wrapper function to estimate VRAM usage"""
    if not model_metadata or 'model_name' not in model_metadata:
        return "<div id=\"vram-info\">Estimated VRAM to load the model:</div>"

    # Use cache_type directly (it's already a string from the radio button)
    try:
        result = estimate_vram(model_metadata, gpu_layers, ctx_size, cache_type)
        conservative = result + 906  # margin for the 95% upper-bound estimate
        return f"""<div id="vram-info">
<div>Expected VRAM usage: <span class="value">{result:.0f} MiB</span></div>
<div>Safe estimate: <span class="value">{conservative:.0f} MiB</span> - 95% chance the VRAM is at most this.</div>
</div>"""
    except Exception as e:
        return f"<div id=\"vram-info\">Estimated VRAM to load the model: <span class=\"value\">Error: {str(e)}</span></div>"


def create_ui():
    """Create the simplified UI"""
    # Custom CSS to limit max width and center the content
    css = """
    body {
        max-width: 810px !important;
        margin: 0 auto !important;
    }

    #vram-info {
        padding: 10px;
        border-radius: 4px;
        background-color: var(--background-fill-secondary);
    }

    #vram-info .value {
        font-weight: bold;
        color: var(--primary-500);
    }
    """

    with gr.Blocks(css=css) as demo:
        # State to hold model metadata
        model_metadata = gr.State(value={})

        gr.Markdown("# Accurate GGUF VRAM Calculator\n\nCalculate VRAM for GGUF models from GPU layers and context length using an accurate formula.\n\nFor an explanation about how this works, consult this blog post: https://oobabooga.github.io/blog/posts/gguf-vram-formula/")
        with gr.Row():
            with gr.Column():
                # Model URL input
                model_url = gr.Textbox(
                    label="GGUF Model URL",
                    placeholder="https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/blob/main/Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf",
                    value=""
                )

                # Load metadata button
                load_metadata_btn = gr.Button("Load metadata", elem_classes='refresh-button')

                # GPU layers slider
                gpu_layers = gr.Slider(
                    label="GPU Layers",
                    minimum=0,
                    maximum=256,
                    value=256,
                    info='`--gpu-layers` in llama.cpp.'
                )

                # Context size slider
                ctx_size = gr.Slider(
                    label='Context Length',
                    minimum=512,
                    maximum=131072,
                    step=256,
                    value=8192,
                    info='`--ctx-size` in llama.cpp.'
                )

                # Cache type radio buttons
                cache_type = gr.Radio(
                    choices=['fp16', 'q8_0', 'q4_0'],
                    value='fp16',
                    label="Cache Type",
                    info='Cache quantization.'
                )

                # VRAM info display
                vram_info = gr.HTML(
                    value="<div id=\"vram-info\">Estimated VRAM to load the model:</div>"
                )

                # Status display
                status = gr.Textbox(
                    label="Status",
                    value="No model loaded",
                    interactive=False
                )

        # Event handlers
        load_metadata_btn.click(
            load_metadata,
            inputs=[model_url, model_metadata],
            outputs=[model_metadata, gpu_layers, status],
            show_progress=True
        ).then(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

        # Update VRAM estimate when any parameter changes
        for component in [gpu_layers, ctx_size, cache_type]:
            component.change(
                estimate_vram_wrapper,
                inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
                outputs=[vram_info],
                show_progress=False
            )

        # Also update when model_metadata state changes
        model_metadata.change(
            estimate_vram_wrapper,
            inputs=[model_metadata, gpu_layers, ctx_size, cache_type],
            outputs=[vram_info],
            show_progress=False
        )

    return demo


if __name__ == "__main__":
    # Create and launch the app
    demo = create_ui()
    demo.launch()