Wisdom Chen committed on
Commit
4784493
·
unverified ·
1 Parent(s): 3beffd8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +16 -9
model.py CHANGED
@@ -78,25 +78,32 @@ def initialize_models() -> bool:
78
  except Exception as e:
79
  raise RuntimeError(f"Failed to initialize CLIP model: {str(e)}")
80
 
81
- # Initialize LLM with CPU settings
82
  try:
83
  model_name = "mistralai/Mistral-7B-v0.1"
84
-
85
- # Initialize tokenizer
 
 
 
 
 
 
86
  llm_tokenizer = AutoTokenizer.from_pretrained(
87
  model_name,
88
- use_auth_token=hf_token,
89
  trust_remote_code=True
90
  )
91
  llm_tokenizer.pad_token = llm_tokenizer.eos_token
92
 
93
- # Initialize model for CPU
94
  llm_model = AutoModelForCausalLM.from_pretrained(
95
  model_name,
96
- use_auth_token=hf_token,
97
- torch_dtype=torch.float32, # Use float32 for CPU
98
- trust_remote_code=True,
99
- low_cpu_mem_usage=True
 
 
100
  )
101
  llm_model.eval()
102
  print("LLM initialized successfully")
 
78
  except Exception as e:
79
  raise RuntimeError(f"Failed to initialize CLIP model: {str(e)}")
80
 
81
+ # Initialize LLM with optimized settings
82
  try:
83
  model_name = "mistralai/Mistral-7B-v0.1"
84
+ quantization_config = BitsAndBytesConfig(
85
+ load_in_4bit=True,
86
+ bnb_4bit_compute_dtype=torch.float16,
87
+ bnb_4bit_use_double_quant=True,
88
+ bnb_4bit_quant_type="nf4"
89
+ )
90
+
91
+ # Initialize tokenizer with specific version requirements
92
  llm_tokenizer = AutoTokenizer.from_pretrained(
93
  model_name,
94
+ use_auth_token=hf_token, # Changed from token to use_auth_token
95
  trust_remote_code=True
96
  )
97
  llm_tokenizer.pad_token = llm_tokenizer.eos_token
98
 
 
99
  llm_model = AutoModelForCausalLM.from_pretrained(
100
  model_name,
101
+ use_auth_token=hf_token, # Changed from token to use_auth_token
102
+ quantization_config=quantization_config,
103
+ device_map='cpu', # Force CPU usage
104
+ torch_dtype=torch.float16,
105
+ low_cpu_mem_usage=True,
106
+ trust_remote_code=True
107
  )
108
  llm_model.eval()
109
  print("LLM initialized successfully")