Spaces:
Build error
Build error
Wisdom Chen
committed on
Update model.py
Browse files
model.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
# Standard libraries
|
2 |
-
import streamlit as st
|
3 |
import os
|
4 |
import io
|
5 |
import json
|
@@ -25,7 +24,7 @@ from transformers import (
|
|
25 |
PreTrainedModel,
|
26 |
PreTrainedTokenizer
|
27 |
)
|
28 |
-
from huggingface_hub import hf_hub_download
|
29 |
from langchain.prompts import PromptTemplate
|
30 |
|
31 |
# Vector database
|
@@ -48,6 +47,12 @@ text_faiss: Optional[object] = None
|
|
48 |
image_faiss: Optional[object] = None
|
49 |
|
50 |
def initialize_models() -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
|
52 |
|
53 |
try:
|
@@ -58,8 +63,6 @@ def initialize_models() -> bool:
|
|
58 |
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
|
59 |
'hf-hub:Marqo/marqo-fashionCLIP'
|
60 |
)
|
61 |
-
# Use to_empty() first, then move to device
|
62 |
-
clip_model = clip_model.to_empty(device=device)
|
63 |
clip_model = clip_model.to(device)
|
64 |
clip_model.eval()
|
65 |
clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
|
@@ -77,14 +80,10 @@ def initialize_models() -> bool:
|
|
77 |
bnb_4bit_quant_type="nf4"
|
78 |
)
|
79 |
|
80 |
-
# Get token from Streamlit secrets
|
81 |
-
hf_token = st.secrets["HUGGINGFACE_TOKEN"]
|
82 |
-
|
83 |
llm_tokenizer = AutoTokenizer.from_pretrained(
|
84 |
model_name,
|
85 |
padding_side="left",
|
86 |
-
truncation_side="left"
|
87 |
-
token=hf_token # Add token here
|
88 |
)
|
89 |
llm_tokenizer.pad_token = llm_tokenizer.eos_token
|
90 |
|
@@ -92,8 +91,7 @@ def initialize_models() -> bool:
|
|
92 |
model_name,
|
93 |
quantization_config=quantization_config,
|
94 |
device_map="auto",
|
95 |
-
torch_dtype=torch.float16
|
96 |
-
token=hf_token # Add token here
|
97 |
)
|
98 |
llm_model.eval()
|
99 |
print("LLM initialized successfully")
|
@@ -104,45 +102,6 @@ def initialize_models() -> bool:
|
|
104 |
|
105 |
except Exception as e:
|
106 |
raise RuntimeError(f"Model initialization failed: {str(e)}")
|
107 |
-
|
108 |
-
def load_embeddings_from_huggingface(repo_id: str) -> Tuple[Dict, Dict]:
|
109 |
-
"""
|
110 |
-
Load embeddings from Hugging Face repository with enhanced error handling.
|
111 |
-
|
112 |
-
Args:
|
113 |
-
repo_id (str): Hugging Face repository ID
|
114 |
-
|
115 |
-
Returns:
|
116 |
-
Tuple[Dict, Dict]: Dictionaries containing text and image embeddings
|
117 |
-
"""
|
118 |
-
print("Loading embeddings from Hugging Face...")
|
119 |
-
try:
|
120 |
-
file_path = hf_hub_download(
|
121 |
-
repo_id=repo_id,
|
122 |
-
filename="embeddings.parquet",
|
123 |
-
repo_type="dataset"
|
124 |
-
)
|
125 |
-
df = pd.read_parquet(file_path)
|
126 |
-
|
127 |
-
# Extract embedding columns
|
128 |
-
text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
|
129 |
-
image_cols = [col for col in df.columns if col.startswith('image_embedding_')]
|
130 |
-
|
131 |
-
# Create embedding dictionaries
|
132 |
-
text_embeddings_dict = {
|
133 |
-
row['Uniq_Id']: row[text_cols].values.astype(np.float32)
|
134 |
-
for _, row in df.iterrows()
|
135 |
-
}
|
136 |
-
image_embeddings_dict = {
|
137 |
-
row['Uniq_Id']: row[image_cols].values.astype(np.float32)
|
138 |
-
for _, row in df.iterrows()
|
139 |
-
}
|
140 |
-
|
141 |
-
print(f"Successfully loaded {len(text_embeddings_dict)} embeddings")
|
142 |
-
return text_embeddings_dict, image_embeddings_dict
|
143 |
-
|
144 |
-
except Exception as e:
|
145 |
-
raise RuntimeError(f"Failed to load embeddings from Hugging Face: {str(e)}")
|
146 |
|
147 |
# Data loading
|
148 |
def load_data() -> bool:
|
@@ -301,6 +260,45 @@ def load_data() -> bool:
|
|
301 |
image_faiss = None
|
302 |
raise RuntimeError(f"Data loading failed: {str(e)}")
|
303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
# FAISS index creation
|
305 |
class MultiModalFAISSIndex:
|
306 |
def __init__(self, dimension, index_type='L2'):
|
@@ -598,7 +596,7 @@ def classify_query(query):
|
|
598 |
return 'image_search'
|
599 |
else:
|
600 |
return 'product_info'
|
601 |
-
|
602 |
def boost_category_relevance(query, product, similarity_score):
|
603 |
query_terms = set(query.lower().split())
|
604 |
category_terms = set(product['Category'].lower().split())
|
|
|
1 |
# Standard libraries
|
|
|
2 |
import os
|
3 |
import io
|
4 |
import json
|
|
|
24 |
PreTrainedModel,
|
25 |
PreTrainedTokenizer
|
26 |
)
|
27 |
+
from huggingface_hub import hf_hub_download
|
28 |
from langchain.prompts import PromptTemplate
|
29 |
|
30 |
# Vector database
|
|
|
47 |
image_faiss: Optional[object] = None
|
48 |
|
49 |
def initialize_models() -> bool:
|
50 |
+
"""
|
51 |
+
Initialize CLIP and LLM models with proper error handling and GPU optimization.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
bool: True if initialization successful, raises RuntimeError otherwise
|
55 |
+
"""
|
56 |
global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
|
57 |
|
58 |
try:
|
|
|
63 |
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
|
64 |
'hf-hub:Marqo/marqo-fashionCLIP'
|
65 |
)
|
|
|
|
|
66 |
clip_model = clip_model.to(device)
|
67 |
clip_model.eval()
|
68 |
clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
|
|
|
80 |
bnb_4bit_quant_type="nf4"
|
81 |
)
|
82 |
|
|
|
|
|
|
|
83 |
llm_tokenizer = AutoTokenizer.from_pretrained(
|
84 |
model_name,
|
85 |
padding_side="left",
|
86 |
+
truncation_side="left"
|
|
|
87 |
)
|
88 |
llm_tokenizer.pad_token = llm_tokenizer.eos_token
|
89 |
|
|
|
91 |
model_name,
|
92 |
quantization_config=quantization_config,
|
93 |
device_map="auto",
|
94 |
+
torch_dtype=torch.float16
|
|
|
95 |
)
|
96 |
llm_model.eval()
|
97 |
print("LLM initialized successfully")
|
|
|
102 |
|
103 |
except Exception as e:
|
104 |
raise RuntimeError(f"Model initialization failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
# Data loading
|
107 |
def load_data() -> bool:
|
|
|
260 |
image_faiss = None
|
261 |
raise RuntimeError(f"Data loading failed: {str(e)}")
|
262 |
|
263 |
+
def load_embeddings_from_huggingface(repo_id: str) -> Tuple[Dict, Dict]:
|
264 |
+
"""
|
265 |
+
Load embeddings from Hugging Face repository with enhanced error handling.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
repo_id (str): Hugging Face repository ID
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
Tuple[Dict, Dict]: Dictionaries containing text and image embeddings
|
272 |
+
"""
|
273 |
+
print("Loading embeddings from Hugging Face...")
|
274 |
+
try:
|
275 |
+
file_path = hf_hub_download(
|
276 |
+
repo_id=repo_id,
|
277 |
+
filename="embeddings.parquet",
|
278 |
+
repo_type="dataset"
|
279 |
+
)
|
280 |
+
df = pd.read_parquet(file_path)
|
281 |
+
|
282 |
+
# Extract embedding columns
|
283 |
+
text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
|
284 |
+
image_cols = [col for col in df.columns if col.startswith('image_embedding_')]
|
285 |
+
|
286 |
+
# Create embedding dictionaries
|
287 |
+
text_embeddings_dict = {
|
288 |
+
row['Uniq_Id']: row[text_cols].values.astype(np.float32)
|
289 |
+
for _, row in df.iterrows()
|
290 |
+
}
|
291 |
+
image_embeddings_dict = {
|
292 |
+
row['Uniq_Id']: row[image_cols].values.astype(np.float32)
|
293 |
+
for _, row in df.iterrows()
|
294 |
+
}
|
295 |
+
|
296 |
+
print(f"Successfully loaded {len(text_embeddings_dict)} embeddings")
|
297 |
+
return text_embeddings_dict, image_embeddings_dict
|
298 |
+
|
299 |
+
except Exception as e:
|
300 |
+
raise RuntimeError(f"Failed to load embeddings from Hugging Face: {str(e)}")
|
301 |
+
|
302 |
# FAISS index creation
|
303 |
class MultiModalFAISSIndex:
|
304 |
def __init__(self, dimension, index_type='L2'):
|
|
|
596 |
return 'image_search'
|
597 |
else:
|
598 |
return 'product_info'
|
599 |
+
|
600 |
def boost_category_relevance(query, product, similarity_score):
|
601 |
query_terms = set(query.lower().split())
|
602 |
category_terms = set(product['Category'].lower().split())
|