Spaces:
Runtime error
Runtime error
File size: 2,411 Bytes
dc2b56f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
import importlib.util
from transformers import AutoTokenizer, AutoModelForCausalLM
def load_model(model_path, hf_token=None, *, max_seq_length=2048):
    """
    Load a pre-trained model and tokenizer, using unsloth if available,
    falling back to standard transformers if necessary.

    The tokenizer is always loaded through ``AutoTokenizer``; the tokenizer
    returned by unsloth (second element of its result) is discarded.

    Args:
        model_path (str): Path or identifier of the pre-trained model.
        hf_token (str | None): Hugging Face API token for accessing gated
            models. Defaults to None and is forwarded as-is to the loaders.
        max_seq_length (int): Keyword-only. Maximum sequence length handed
            to unsloth's loader. Defaults to 2048. It is not forwarded to
            the standard transformers fallback.

    Returns:
        tuple: Loaded model and tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)

    # Check if CUDA is available
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
        device = "cuda"
    else:
        print("CUDA is not available. Using CPU.")
        device = "cpu"

    # Try to use unsloth if it's available
    if importlib.util.find_spec("unsloth") is not None:
        try:
            # Imported lazily so the module still works when unsloth is
            # installed but fails to import on this machine.
            from unsloth import FastLanguageModel

            print("Using unsloth for model loading.")
            model, _ = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=max_seq_length,
                dtype=None,  # Automatically choose between float16 and bfloat16
                load_in_4bit=cuda_available,  # Only use 4-bit quantization if CUDA is available
                token=hf_token,
            )
        except Exception as e:
            # Deliberate best-effort boundary: any unsloth failure (import
            # or load) is reported and the standard loader is used instead.
            print(f"Error loading model with unsloth: {e}")
            print("Falling back to standard transformers.")
            model = load_with_transformers(model_path, hf_token, device)
    else:
        print("unsloth not found. Using standard transformers.")
        model = load_with_transformers(model_path, hf_token, device)

    # Do not use .to(device) for quantized models
    # The device placement is handled automatically by unsloth or transformers
    return model, tokenizer
def load_with_transformers(model_path, hf_token, device):
    """Helper function to load model with standard transformers library.

    Args:
        model_path (str): Path or identifier of the pre-trained model.
        hf_token (str): Hugging Face API token forwarded to ``from_pretrained``.
        device (str): ``"cuda"`` selects float16 weights; any other value
            selects float32.

    Returns:
        The model object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    if device == "cuda":
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32

    return AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",  # This will handle device placement automatically
        token=hf_token,
        torch_dtype=weight_dtype,
    )