Wisdom Chen committed
Commit fbbacc5 · unverified · 1 Parent(s): 9e65556

Update model.py

update the initialize_models() function

Files changed (1)
model.py +6 -5
model.py CHANGED
@@ -80,20 +80,23 @@ def initialize_models() -> bool:
     if not hf_token:
         raise ValueError("HUGGINGFACE_TOKEN not found in Streamlit secrets")
 
+    # Initialize tokenizer with trust_remote_code=True
     llm_tokenizer = AutoTokenizer.from_pretrained(
         model_name,
         token=hf_token,
-        padding_side="left",
-        truncation_side="left"
+        trust_remote_code=True,
+        use_fast=True
     )
     llm_tokenizer.pad_token = llm_tokenizer.eos_token
 
+    # Initialize model with trust_remote_code=True
     llm_model = AutoModelForCausalLM.from_pretrained(
         model_name,
         token=hf_token,
         quantization_config=quantization_config,
         device_map="auto",
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
+        trust_remote_code=True
     )
     llm_model.eval()
     print("LLM initialized successfully")
@@ -636,7 +639,6 @@ def hybrid_retrieval(query, top_k=5):
 
     return results, query_type
 
-
 def fallback_text_search(query, top_k=10):
     relevant_products = filter_by_metadata(query, metadata)
     if not relevant_products:
@@ -757,7 +759,6 @@ def chatbot(query, image_input=None):
         print(f"Error processing query: {str(e)}")
         return f"Error processing request: {str(e)}"
 
-
 def cleanup_resources():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
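
In context, the first hunk amounts to the loading sequence below. This is a minimal sketch, assuming model_name, hf_token, and quantization_config are defined earlier in initialize_models() as the hunk's context lines suggest; the load_llm wrapper and its signature are hypothetical, added only to make the fragment self-contained.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_llm(model_name, hf_token, quantization_config=None):  # hypothetical wrapper
    # Tokenizer now opts into repo-provided code and the fast (Rust-backed) implementation.
    llm_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hf_token,
        trust_remote_code=True,
        use_fast=True
    )
    # Most causal-LM vocabularies lack a dedicated pad token; reuse EOS.
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

    # Model likewise opts into repo-provided code; fp16 for non-quantized modules.
    llm_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_token,
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    llm_model.eval()  # inference mode: disables dropout and similar training behavior
    return llm_tokenizer, llm_model

One note on the removed arguments: dropping padding_side="left" reverts the tokenizer to its default (usually right padding), so any batched generation code that relied on left padding would need to set tokenizer.padding_side = "left" at call time.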