# hf_extractor/model_utils.py
import os
import subprocess

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    GroundingDinoForObjectDetection,
    Idefics2ForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
    NomicBertModel,
    Owlv2ForObjectDetection,
    PaliGemmaForConditionalGeneration,
    SamModel,
)

import spaces

# Use the hf_transfer backend for faster downloads from the Hugging Face Hub.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


def install_flash_attn():
    # Install flash-attn at runtime. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
    # skips compiling the CUDA kernels from source so the prebuilt wheel is
    # used instead.
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,  # raise if the install fails instead of continuing silently
    )
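
# Usage sketch (assumption: nothing in this module calls install_flash_attn
# itself; a caller is expected to run it once at startup, before the first
# model load):
#
#     install_flash_attn()
#     summary, error = get_model_summary("gpt2")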


# Map the architecture name reported in a checkpoint's config to the concrete
# transformers class used to load it. Anything not listed here falls back to
# AutoModelForCausalLM (see get_model_summary below).
ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "Owlv2ForObjectDetection": Owlv2ForObjectDetection,
    "GroundingDinoForObjectDetection": GroundingDinoForObjectDetection,
    "SamModel": SamModel,
    "AutoModelForCausalLM": AutoModelForCausalLM,
    "NomicBertModel": NomicBertModel,
}


@spaces.GPU
def get_model_summary(model_name):
    """Load `model_name` and return (markdown_summary, error_message)."""
    try:
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        architecture = config.architectures[0]

        # If the checkpoint ships quantization settings, rebuild them as a
        # BitsAndBytesConfig so the model is loaded the way it was saved.
        quantization_config = getattr(config, "quantization_config", None)
        if quantization_config:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=quantization_config.get("load_in_4bit", False),
                load_in_8bit=quantization_config.get("load_in_8bit", False),
                bnb_4bit_compute_dtype=quantization_config.get("bnb_4bit_compute_dtype", torch.float16),
                bnb_4bit_quant_type=quantization_config.get("bnb_4bit_quant_type", "nf4"),
                bnb_4bit_use_double_quant=quantization_config.get("bnb_4bit_use_double_quant", False),
                llm_int8_enable_fp32_cpu_offload=quantization_config.get("llm_int8_enable_fp32_cpu_offload", False),
                llm_int8_has_fp16_weight=quantization_config.get("llm_int8_has_fp16_weight", False),
                llm_int8_skip_modules=quantization_config.get("llm_int8_skip_modules", None),
                llm_int8_threshold=quantization_config.get("llm_int8_threshold", 6.0),
            )
        else:
            bnb_config = None

        # Resolve the loader class; unknown architectures fall back to
        # AutoModelForCausalLM.
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)
        # The BitsAndBytesConfig must be passed as `quantization_config`;
        # from_pretrained's `config` argument expects a PretrainedConfig and
        # would reject it.
        model = model_class.from_pretrained(
            model_name, quantization_config=bnb_config, trust_remote_code=True
        )

        # bitsandbytes places quantized weights on device during loading, so
        # only unquantized models are moved explicitly.
        if model and not quantization_config:
            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        model_summary = str(model) if model else "Model architecture not found."
        config_content = config.to_json_string() if config else "Configuration not found."
        return f"## Model Architecture\n\n{model_summary}\n\n## Configuration\n\n{config_content}", ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        return "", str(e)