# open_llm_leaderboard/backend/app/utils/model_validation.py
import json
import logging
import asyncio
import aiohttp
from typing import Tuple, Optional, Dict, Any
from datasets import load_dataset
from huggingface_hub import HfApi, ModelCard, hf_hub_download
from huggingface_hub import hf_api
from transformers import AutoConfig, AutoTokenizer
from app.config.base import HF_TOKEN
from app.config.hf_config import OFFICIAL_PROVIDERS_REPO
from app.core.formatting import LogFormatter
logger = logging.getLogger(__name__)


class ModelValidator:
def __init__(self):
self.token = HF_TOKEN
self.api = HfApi(token=self.token)
self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
self.logger = logger
        self.config_cache = {}  # caches parsed config.json per (model_id, revision)

    async def check_model_card(self, model_id: str) -> Tuple[bool, str, Optional[Dict[str, Any]]]:
"""Check if model has a valid model card"""
try:
logger.info(LogFormatter.info(f"Checking model card for {model_id}"))
# Get model card content using ModelCard.load
try:
model_card = await asyncio.to_thread(
ModelCard.load,
model_id
)
logger.info(LogFormatter.success("Model card found"))
except Exception as e:
error_msg = "Please add a model card to your model to explain how you trained/fine-tuned it."
logger.error(LogFormatter.error(error_msg, e))
return False, error_msg, None
# Check license in model card data
            if model_card.data.license is None and not (
                "license_name" in model_card.data and "license_link" in model_card.data
            ):
error_msg = "License not found. Please add a license to your model card using the `license` metadata or a `license_name`/`license_link` pair."
logger.warning(LogFormatter.warning(error_msg))
return False, error_msg, None
# Enforce card content length
if len(model_card.text) < 200:
error_msg = "Please add a description to your model card, it is too short."
logger.warning(LogFormatter.warning(error_msg))
return False, error_msg, None
logger.info(LogFormatter.success("Model card validation passed"))
return True, "", model_card
except Exception as e:
error_msg = "Failed to validate model card"
logger.error(LogFormatter.error(error_msg, e))
            return False, str(e), None

    async def get_safetensors_metadata(self, model_id: str, is_adapter: bool = False, revision: str = "main") -> Optional[Dict]:
        """Get safetensors metadata, including per-dtype parameter counts, for a model repo or a single adapter file"""
try:
if is_adapter:
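                # Adapters: read metadata for the single adapter_model.safetensors
                # file instead of aggregating every safetensors file in the repo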
metadata = await asyncio.to_thread(
hf_api.parse_safetensors_file_metadata,
model_id,
"adapter_model.safetensors",
token=self.token,
revision=revision,
)
else:
metadata = await asyncio.to_thread(
hf_api.get_safetensors_metadata,
repo_id=model_id,
token=self.token,
revision=revision,
)
return metadata
except Exception as e:
logger.error(f"Failed to get safetensors metadata: {str(e)}")
            return None

    async def get_model_size(
self,
model_info: Any,
precision: str,
base_model: str,
revision: str
) -> Tuple[Optional[float], Optional[str]]:
try:
self.logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}"))
# Check if model is adapter
is_adapter = any(
s.rfilename == "adapter_config.json"
for s in model_info.siblings
if hasattr(s, 'rfilename')
)
# Get model size from safetensors
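            # metadata.parameter_count maps dtype name to element count
            # (e.g. {"F32": 6738415616}); summing its values yields the total
            # number of stored parameters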
model_size = None
if is_adapter and base_model:
# For adapters, combine adapter and base model sizes
adapter_meta = await self.get_safetensors_metadata(
model_info.id,
is_adapter=True,
revision=revision
)
base_meta = await self.get_safetensors_metadata(
base_model,
revision="main"
)
if adapter_meta and base_meta:
adapter_size = sum(adapter_meta.parameter_count.values())
base_size = sum(base_meta.parameter_count.values())
model_size = adapter_size + base_size
else:
# For regular models
meta = await self.get_safetensors_metadata(
model_info.id,
revision=revision
)
if meta:
model_size = sum(meta.parameter_count.values())
if model_size is None:
return None, "Model size could not be determined"
if model_size <= 0:
return None, "Invalid model size: must be positive"
# Only proceed with GPTQ adjustments if necessary
if precision == "GPTQ" or "gptq" in model_info.id.lower():
precision_bits = await self._get_precision_bits(
model_info.id,
revision
)
if precision_bits is None:
return None, "Failed to determine precision bits"
                # Scale the raw count by bits/32: quantized weights occupy a
                # fraction of the storage of full 32-bit weights
                size_factor = precision_bits / 32  # e.g. 4-bit -> 4/32 = 0.125; 2-bit -> 1/16
self.logger.info(LogFormatter.info(
f"Applying quantization factor: {size_factor}x (bits={precision_bits})"
))
else:
size_factor = 1
# Convert to billions and apply quantization factor
model_size = model_size / 1e9 # Convert to billions
model_size = round(size_factor * model_size, 3)
self.logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
return model_size, None
except Exception as e:
self.logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
            return None, str(e)

    async def _get_precision_bits(
self,
model_id: str,
revision: str
) -> Optional[int]:
"""Get the precision bits from config.json, with caching."""
# Check cache first
cache_key = f"{model_id}_{revision}"
if cache_key in self.config_cache:
config_data = self.config_cache[cache_key]
else:
# Fetch config.json
config_url = f"https://huggingface.co/{model_id}/raw/{revision}/config.json"
try:
async with aiohttp.ClientSession() as session:
async with session.get(config_url, headers=self.headers) as response:
if response.status != 200:
self.logger.warning(LogFormatter.warning(
f"Failed to fetch config.json from {config_url}. Defaulting to 4 bits for GPTQ."
))
return 4
# Try to parse response as JSON regardless of content type
try:
text = await response.text()
config_data = json.loads(text)
self.config_cache[cache_key] = config_data
except json.JSONDecodeError:
self.logger.warning(LogFormatter.warning(
f"Failed to parse config.json from {config_url}. Defaulting to 4 bits for GPTQ."
))
return 4
except Exception as e:
self.logger.error(LogFormatter.error(
f"Error fetching config.json: {e}. Defaulting to 4 bits."
))
return 4
# Get precision bits from config
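        # For reference, a GPTQ checkpoint's config.json typically stores the
        # bit width under "quantization_config", e.g.:
        #   {"quantization_config": {"bits": 4, "group_size": 128, ...}}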
try:
precision_bits = config_data.get("quantization_config", {}).get("bits", 4)
# Validate precision bits
if precision_bits not in [2, 3, 4, 8]:
self.logger.error(LogFormatter.error(
f"Unsupported precision_bits: {precision_bits}"
))
return None
return precision_bits
except Exception as e:
self.logger.error(LogFormatter.error(
f"Error extracting precision bits from config: {e}. Defaulting to 4 bits."
))
            return 4

    async def check_chat_template(
self,
model_id: str,
revision: str
) -> Tuple[bool, Optional[str]]:
"""Check if model has a valid chat template"""
try:
logger.info(LogFormatter.info(f"Checking chat template for {model_id}"))
try:
config_file = await asyncio.to_thread(
hf_hub_download,
repo_id=model_id,
filename="tokenizer_config.json",
revision=revision,
repo_type="model"
)
with open(config_file, 'r') as f:
tokenizer_config = json.load(f)
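                # The template lives under the "chat_template" key as a Jinja
                # string, e.g.:
                #   {"chat_template": "{% for message in messages %}...{% endfor %}"}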
if 'chat_template' not in tokenizer_config:
error_msg = f"The model {model_id} doesn't have a chat_template in its tokenizer_config.json. Please add a chat_template before submitting or submit without it."
logger.error(LogFormatter.error(error_msg))
return False, error_msg
logger.info(LogFormatter.success("Valid chat template found"))
return True, None
except Exception as e:
error_msg = f"Error checking chat_template: {str(e)}"
logger.error(LogFormatter.error(error_msg))
return False, error_msg
except Exception as e:
error_msg = "Failed to check chat template"
logger.error(LogFormatter.error(error_msg, e))
            return False, str(e)

    async def is_model_on_hub(
self,
model_name: str,
revision: str,
test_tokenizer: bool = False,
trust_remote_code: bool = False
) -> Tuple[bool, Optional[str], Optional[Any]]:
"""Check if model exists and is properly configured on the Hub"""
try:
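            # force_download=True bypasses any locally cached config so the
            # check reflects the current state of the repo on the Hub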
config = await asyncio.to_thread(
AutoConfig.from_pretrained,
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
token=self.token,
force_download=True
)
if test_tokenizer:
try:
await asyncio.to_thread(
AutoTokenizer.from_pretrained,
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
token=self.token
)
except ValueError as e:
return False, f"The tokenizer is not available in an official Transformers release: {e}", None
except Exception:
return False, "The tokenizer cannot be loaded. Ensure the tokenizer class is part of a stable Transformers release and correctly configured.", None
return True, None, config
except ValueError:
return False, "The model requires `trust_remote_code=True` to launch, and for safety reasons, we don't accept such models automatically.", None
except Exception as e:
if "You are trying to access a gated repo." in str(e):
return True, "The model is gated and requires special access permissions.", None
return False, f"The model was not found or is misconfigured on the Hub. Error: {e.args[0]}", None
async def check_official_provider_status(
self,
model_id: str,
existing_models: Dict[str, list]
) -> Tuple[bool, Optional[str]]:
"""
Check if model is from official provider and has finished submission.
Args:
model_id: The model identifier (org/model-name)
existing_models: Dictionary of models by status from get_models()
Returns:
Tuple[bool, Optional[str]]: (is_valid, error_message)
"""
try:
logger.info(LogFormatter.info(f"Checking official provider status for {model_id}"))
# Get model organization
model_org = model_id.split('/')[0] if '/' in model_id else None
if not model_org:
return True, None
# Load official providers dataset
            dataset = await asyncio.to_thread(load_dataset, OFFICIAL_PROVIDERS_REPO)
official_providers = dataset["train"][0]["CURATED_SET"]
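            # CURATED_SET holds the curated organization names
            # (e.g. "meta-llama"), so membership is checked on the org prefix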
# Check if model org is in official providers
is_official = model_org in official_providers
if is_official:
logger.info(LogFormatter.info(f"Model organization '{model_org}' is an official provider"))
# Check for finished submissions
if "finished" in existing_models:
for model in existing_models["finished"]:
if model["name"] == model_id:
error_msg = (
f"Model {model_id} is an official provider model "
f"with a completed evaluation. "
f"To re-evaluate, please open a discussion."
)
logger.error(LogFormatter.error("Validation failed", error_msg))
return False, error_msg
logger.info(LogFormatter.success("No finished submission found for this official provider model"))
else:
logger.info(LogFormatter.info(f"Model organization '{model_org}' is not an official provider"))
return True, None
except Exception as e:
error_msg = f"Failed to check official provider status: {str(e)}"
logger.error(LogFormatter.error(error_msg))
return False, error_msg
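

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the production entrypoint. The model id
# below is a hypothetical placeholder; HF_TOKEN must be configured via
# app.config.base for authenticated checks to succeed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        validator = ModelValidator()

        # Validate the model card (license metadata + description length)
        ok, error, _card = await validator.check_model_card("org/model-name")
        print(f"model card ok: {ok} {error}")

        # Verify the model loads from the Hub with a released config/tokenizer
        on_hub, hub_error, _config = await validator.is_model_on_hub(
            "org/model-name", revision="main", test_tokenizer=True
        )
        print(f"on hub: {on_hub} {hub_error}")

    asyncio.run(_demo())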