Wisdom Chen committed on
Commit
4f3be3c
·
unverified ·
1 Parent(s): 838a59a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +49 -51
model.py CHANGED
@@ -1,5 +1,4 @@
1
  # Standard libraries
2
- import streamlit as st
3
  import os
4
  import io
5
  import json
@@ -25,7 +24,7 @@ from transformers import (
25
  PreTrainedModel,
26
  PreTrainedTokenizer
27
  )
28
- from huggingface_hub import hf_hub_download, login
29
  from langchain.prompts import PromptTemplate
30
 
31
  # Vector database
@@ -48,6 +47,12 @@ text_faiss: Optional[object] = None
48
  image_faiss: Optional[object] = None
49
 
50
  def initialize_models() -> bool:
 
 
 
 
 
 
51
  global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
52
 
53
  try:
@@ -58,8 +63,6 @@ def initialize_models() -> bool:
58
  clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
59
  'hf-hub:Marqo/marqo-fashionCLIP'
60
  )
61
- # Use to_empty() first, then move to device
62
- clip_model = clip_model.to_empty(device=device)
63
  clip_model = clip_model.to(device)
64
  clip_model.eval()
65
  clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
@@ -77,14 +80,10 @@ def initialize_models() -> bool:
77
  bnb_4bit_quant_type="nf4"
78
  )
79
 
80
- # Get token from Streamlit secrets
81
- hf_token = st.secrets["HUGGINGFACE_TOKEN"]
82
-
83
  llm_tokenizer = AutoTokenizer.from_pretrained(
84
  model_name,
85
  padding_side="left",
86
- truncation_side="left",
87
- token=hf_token # Add token here
88
  )
89
  llm_tokenizer.pad_token = llm_tokenizer.eos_token
90
 
@@ -92,8 +91,7 @@ def initialize_models() -> bool:
92
  model_name,
93
  quantization_config=quantization_config,
94
  device_map="auto",
95
- torch_dtype=torch.float16,
96
- token=hf_token # Add token here
97
  )
98
  llm_model.eval()
99
  print("LLM initialized successfully")
@@ -104,45 +102,6 @@ def initialize_models() -> bool:
104
 
105
  except Exception as e:
106
  raise RuntimeError(f"Model initialization failed: {str(e)}")
107
-
108
- def load_embeddings_from_huggingface(repo_id: str) -> Tuple[Dict, Dict]:
109
- """
110
- Load embeddings from Hugging Face repository with enhanced error handling.
111
-
112
- Args:
113
- repo_id (str): Hugging Face repository ID
114
-
115
- Returns:
116
- Tuple[Dict, Dict]: Dictionaries containing text and image embeddings
117
- """
118
- print("Loading embeddings from Hugging Face...")
119
- try:
120
- file_path = hf_hub_download(
121
- repo_id=repo_id,
122
- filename="embeddings.parquet",
123
- repo_type="dataset"
124
- )
125
- df = pd.read_parquet(file_path)
126
-
127
- # Extract embedding columns
128
- text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
129
- image_cols = [col for col in df.columns if col.startswith('image_embedding_')]
130
-
131
- # Create embedding dictionaries
132
- text_embeddings_dict = {
133
- row['Uniq_Id']: row[text_cols].values.astype(np.float32)
134
- for _, row in df.iterrows()
135
- }
136
- image_embeddings_dict = {
137
- row['Uniq_Id']: row[image_cols].values.astype(np.float32)
138
- for _, row in df.iterrows()
139
- }
140
-
141
- print(f"Successfully loaded {len(text_embeddings_dict)} embeddings")
142
- return text_embeddings_dict, image_embeddings_dict
143
-
144
- except Exception as e:
145
- raise RuntimeError(f"Failed to load embeddings from Hugging Face: {str(e)}")
146
 
147
  # Data loading
148
  def load_data() -> bool:
@@ -301,6 +260,45 @@ def load_data() -> bool:
301
  image_faiss = None
302
  raise RuntimeError(f"Data loading failed: {str(e)}")
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  # FAISS index creation
305
  class MultiModalFAISSIndex:
306
  def __init__(self, dimension, index_type='L2'):
@@ -598,7 +596,7 @@ def classify_query(query):
598
  return 'image_search'
599
  else:
600
  return 'product_info'
601
-
602
  def boost_category_relevance(query, product, similarity_score):
603
  query_terms = set(query.lower().split())
604
  category_terms = set(product['Category'].lower().split())
 
1
  # Standard libraries
 
2
  import os
3
  import io
4
  import json
 
24
  PreTrainedModel,
25
  PreTrainedTokenizer
26
  )
27
+ from huggingface_hub import hf_hub_download
28
  from langchain.prompts import PromptTemplate
29
 
30
  # Vector database
 
47
  image_faiss: Optional[object] = None
48
 
49
  def initialize_models() -> bool:
50
+ """
51
+ Initialize CLIP and LLM models with proper error handling and GPU optimization.
52
+
53
+ Returns:
54
+ bool: True if initialization successful, raises RuntimeError otherwise
55
+ """
56
  global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
57
 
58
  try:
 
63
  clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
64
  'hf-hub:Marqo/marqo-fashionCLIP'
65
  )
 
 
66
  clip_model = clip_model.to(device)
67
  clip_model.eval()
68
  clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
 
80
  bnb_4bit_quant_type="nf4"
81
  )
82
 
 
 
 
83
  llm_tokenizer = AutoTokenizer.from_pretrained(
84
  model_name,
85
  padding_side="left",
86
+ truncation_side="left"
 
87
  )
88
  llm_tokenizer.pad_token = llm_tokenizer.eos_token
89
 
 
91
  model_name,
92
  quantization_config=quantization_config,
93
  device_map="auto",
94
+ torch_dtype=torch.float16
 
95
  )
96
  llm_model.eval()
97
  print("LLM initialized successfully")
 
102
 
103
  except Exception as e:
104
  raise RuntimeError(f"Model initialization failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  # Data loading
107
  def load_data() -> bool:
 
260
  image_faiss = None
261
  raise RuntimeError(f"Data loading failed: {str(e)}")
262
 
263
+ def load_embeddings_from_huggingface(repo_id: str) -> Tuple[Dict, Dict]:
264
+ """
265
+ Load embeddings from Hugging Face repository with enhanced error handling.
266
+
267
+ Args:
268
+ repo_id (str): Hugging Face repository ID
269
+
270
+ Returns:
271
+ Tuple[Dict, Dict]: Dictionaries containing text and image embeddings
272
+ """
273
+ print("Loading embeddings from Hugging Face...")
274
+ try:
275
+ file_path = hf_hub_download(
276
+ repo_id=repo_id,
277
+ filename="embeddings.parquet",
278
+ repo_type="dataset"
279
+ )
280
+ df = pd.read_parquet(file_path)
281
+
282
+ # Extract embedding columns
283
+ text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
284
+ image_cols = [col for col in df.columns if col.startswith('image_embedding_')]
285
+
286
+ # Create embedding dictionaries
287
+ text_embeddings_dict = {
288
+ row['Uniq_Id']: row[text_cols].values.astype(np.float32)
289
+ for _, row in df.iterrows()
290
+ }
291
+ image_embeddings_dict = {
292
+ row['Uniq_Id']: row[image_cols].values.astype(np.float32)
293
+ for _, row in df.iterrows()
294
+ }
295
+
296
+ print(f"Successfully loaded {len(text_embeddings_dict)} embeddings")
297
+ return text_embeddings_dict, image_embeddings_dict
298
+
299
+ except Exception as e:
300
+ raise RuntimeError(f"Failed to load embeddings from Hugging Face: {str(e)}")
301
+
302
  # FAISS index creation
303
  class MultiModalFAISSIndex:
304
  def __init__(self, dimension, index_type='L2'):
 
596
  return 'image_search'
597
  else:
598
  return 'product_info'
599
+
600
  def boost_category_relevance(query, product, similarity_score):
601
  query_terms = set(query.lower().split())
602
  category_terms = set(product['Category'].lower().split())