sibthinon committed on
Commit
6d417ec
·
verified ·
1 Parent(s): e14c099

use only E5 multilingual small

Browse files
Files changed (1) hide show
  1. app.py +53 -97
app.py CHANGED
@@ -14,64 +14,41 @@ import pickle
14
  import re
15
  import unicodedata
16
 
17
-
18
  qdrant_client = QdrantClient(
19
  url=os.environ.get("Qdrant_url"),
20
  api_key=os.environ.get("Qdrant_api"),
21
  )
22
 
 
23
  AIRTABLE_API_KEY = os.environ.get("airtable_api")
24
  BASE_ID = os.environ.get("airtable_baseid")
25
- TABLE_NAME = "Feedback_search" # หรือเปลี่ยนชื่อให้ชัดเช่น 'Feedback'
26
  api = Api(AIRTABLE_API_KEY)
27
  table = api.table(BASE_ID, TABLE_NAME)
28
 
29
- # โมเดลที่โหลดล่วงหน้า
30
- models = {
31
- "E5 (intfloat/multilingual-e5-small)": SentenceTransformer('intfloat/multilingual-e5-small'),
32
- "E5 large instruct (multilingual-e5-large-instruct)": SentenceTransformer("intfloat/multilingual-e5-large-instruct"),
33
- "Kalm (KaLM-embedding-multilingual-mini-v1)": SentenceTransformer('HIT-TMG/KaLM-embedding-multilingual-mini-v1')
34
- }
35
-
36
- model_config = {
37
- "E5 (intfloat/multilingual-e5-small)": {
38
- "func": lambda query: models["E5 (intfloat/multilingual-e5-small)"].encode("query: " + query),
39
- "collection": "product_E5",
40
- },
41
- "E5 large instruct (multilingual-e5-large-instruct)": {
42
- "func": lambda query: models["E5 large instruct (multilingual-e5-large-instruct)"].encode(
43
- "Instruct: Given a product search query, retrieve relevant product listings\nQuery: " + query, convert_to_tensor=False, normalize_embeddings=True),
44
- "collection": "product_E5_large_instruct",
45
- },
46
- "Kalm (KaLM-embedding-multilingual-mini-v1)": {
47
- "func": lambda query: models["Kalm (KaLM-embedding-multilingual-mini-v1)"].encode(query, normalize_embeddings=True),
48
- "collection": "product_kalm",
49
- }
50
- }
51
-
52
- # Global memory to hold feedback state
53
- latest_query_result = {"query": "", "result": "", "model": "", "raw_query": "", "time": ""}
54
 
 
55
  with open("keyword_whitelist.pkl", "rb") as f:
56
  keyword_whitelist = pickle.load(f)
57
 
 
58
  def normalize(text: str) -> str:
59
  text = unicodedata.normalize("NFC", text)
60
- text = text.replace("เแ", "แ").replace("เเ", "แ")
61
- return text.strip().lower()
62
 
63
  def smart_tokenize(text: str) -> list:
64
  tokens = word_tokenize(text.strip(), engine="newmm")
65
- if not tokens or len("".join(tokens)) < len(text.strip()) * 0.5:
66
- return [text.strip()]
67
- return tokens
68
 
69
  def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3):
70
  query_norm = normalize(query)
71
  tokens = smart_tokenize(query_norm)
72
  corrected = []
73
  i = 0
74
-
75
  while i < len(tokens):
76
  matched = False
77
  for n in range(min(max_ngram, len(tokens) - i), 0, -1):
@@ -85,22 +62,17 @@ def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3
85
  if not matched:
86
  corrected.append(tokens[i])
87
  i += 1
 
88
 
89
- # ตัดคำที่มีความยาว 1 ตัวอักษรและไม่ได้อยู่ใน whitelist
90
- cleaned = [word for word in corrected if len(word) > 1 or word in whitelist]
91
- return "".join(cleaned)
92
 
93
- # 🌟 Main search function
94
- def search_product(query, model_name):
95
  start_time = time.time()
96
- if model_name not in model_config:
97
- return "<p>❌ ไม่พบโมเดล</p>"
98
-
99
  latest_query_result["raw_query"] = query
100
- corrected_query = correct_query_merge_phrases(query,keyword_whitelist)
101
-
102
- query_embed = model_config[model_name]["func"](corrected_query)
103
- collection_name = model_config[model_name]["collection"]
104
 
105
  try:
106
  result = qdrant_client.query_points(
@@ -118,60 +90,52 @@ def search_product(query, model_name):
118
  if corrected_query != query:
119
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
120
 
121
- html_output += """
122
- <div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">
123
- """
124
 
125
- result_summary = ""
126
- found = False
127
  for res in result:
128
- if res.score > 0.8:
129
- found = True
130
- name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
131
- score = f"{res.score:.4f}"
132
- img_url = res.payload.get("imageUrl", "")
133
- price = res.payload.get("price", "ไม่ระบุ")
134
- brand = res.payload.get("brand", "")
135
-
136
- html_output += f"""
137
- <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
138
- <img src="{img_url}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
139
- <div style="margin-top: 10px;">
140
- <div style="font-weight: bold; font-size: 14px;">{name}</div>
141
- <div style="color: gray; font-size: 12px;">{brand}</div>
142
- <div style="color: green; margin: 4px 0;">฿{price}</div>
143
- <div style="font-size: 12px; color: #555;">score: {score}</div>
 
144
  </div>
145
- </div>
146
- """
147
- result_summary += f"{name} (score: {score}) | "
148
 
149
  html_output += "</div>"
150
 
151
  if not found:
152
- html_output += """
153
- <div style="text-align: center; font-size: 18px; color: #a00; padding: 30px;">
154
- ❌ ไม่พบสินค้าที่เกี่ยวข้องกับคำค้นนี้
155
- </div>
156
- """
157
- return html_output
158
-
159
- latest_query_result["query"] = corrected_query
160
- latest_query_result["result"] = result_summary.strip()
161
- latest_query_result["model"] = model_name
162
- latest_query_result["time"] = elapsed
163
 
164
- return html_output
 
 
 
 
165
 
 
166
 
167
- # 📝 Logging feedback
168
  def log_feedback(feedback):
169
  try:
170
  now = datetime.now().strftime("%Y-%m-%d")
171
  table.create({
172
  "timestamp": now,
173
  "raw_query": latest_query_result["raw_query"],
174
- "model": latest_query_result["model"],
175
  "query": latest_query_result["query"],
176
  "result": latest_query_result["result"],
177
  "time(second)": latest_query_result["time"],
@@ -181,20 +145,12 @@ def log_feedback(feedback):
181
  except Exception as e:
182
  return f"❌ Failed to save feedback: {str(e)}"
183
 
184
-
185
- # 🎨 Gradio UI
186
  with gr.Blocks() as demo:
187
  gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
188
 
189
- with gr.Row():
190
- model_selector = gr.Dropdown(
191
- choices=list(models.keys()),
192
- label="เลือกโมเดล",
193
- value="E5 (intfloat/multilingual-e5-small)"
194
- )
195
- query_input = gr.Textbox(label="พิมพ์คำค้นหา")
196
-
197
- result_output = gr.HTML(label="📋 ผลลัพธ์") # HTML แสดงผลลัพธ์พร้อมรูป
198
 
199
  with gr.Row():
200
  match_btn = gr.Button("✅ ตรง")
@@ -202,9 +158,9 @@ with gr.Blocks() as demo:
202
 
203
  feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
204
 
205
- query_input.submit(search_product, inputs=[query_input, model_selector], outputs=result_output)
206
  match_btn.click(lambda: log_feedback("match"), outputs=feedback_status)
207
  not_match_btn.click(lambda: log_feedback("not_match"), outputs=feedback_status)
208
 
209
- # Run app
210
- demo.launch(share=True)
 
14
  import re
15
  import unicodedata
16
 
17
+ # Setup Qdrant Client
18
  qdrant_client = QdrantClient(
19
  url=os.environ.get("Qdrant_url"),
20
  api_key=os.environ.get("Qdrant_api"),
21
  )
22
 
23
+ # Airtable Config
24
  AIRTABLE_API_KEY = os.environ.get("airtable_api")
25
  BASE_ID = os.environ.get("airtable_baseid")
26
+ TABLE_NAME = "Feedback_search"
27
  api = Api(AIRTABLE_API_KEY)
28
  table = api.table(BASE_ID, TABLE_NAME)
29
 
30
+ # Load model
31
+ model = SentenceTransformer('intfloat/multilingual-e5-small')
32
+ collection_name = "product_E5"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # Load whitelist
35
  with open("keyword_whitelist.pkl", "rb") as f:
36
  keyword_whitelist = pickle.load(f)
37
 
38
+ # Utils
39
  def normalize(text: str) -> str:
40
  text = unicodedata.normalize("NFC", text)
41
+ return text.replace("เแ", "แ").replace("เเ", "แ").strip().lower()
 
42
 
43
  def smart_tokenize(text: str) -> list:
44
  tokens = word_tokenize(text.strip(), engine="newmm")
45
+ return tokens if tokens and len("".join(tokens)) >= len(text.strip()) * 0.5 else [text.strip()]
 
 
46
 
47
  def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3):
48
  query_norm = normalize(query)
49
  tokens = smart_tokenize(query_norm)
50
  corrected = []
51
  i = 0
 
52
  while i < len(tokens):
53
  matched = False
54
  for n in range(min(max_ngram, len(tokens) - i), 0, -1):
 
62
  if not matched:
63
  corrected.append(tokens[i])
64
  i += 1
65
+ return "".join([word for word in corrected if len(word) > 1 or word in whitelist])
66
 
67
+ # Global state
68
+ latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
 
69
 
70
+ # Main Search
71
+ def search_product(query):
72
  start_time = time.time()
 
 
 
73
  latest_query_result["raw_query"] = query
74
+ corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
75
+ query_embed = model.encode("query: " + corrected_query)
 
 
76
 
77
  try:
78
  result = qdrant_client.query_points(
 
90
  if corrected_query != query:
91
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
92
 
93
+ html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
 
 
94
 
95
+ result_summary, found = "", False
 
96
  for res in result:
97
+ if res.score > 0.8:
98
+ found = True
99
+ name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
100
+ score = f"{res.score:.4f}"
101
+ img_url = res.payload.get("imageUrl", "")
102
+ price = res.payload.get("price", "ไม่ระบุ")
103
+ brand = res.payload.get("brand", "")
104
+
105
+ html_output += f"""
106
+ <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
107
+ <img src="{img_url}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
108
+ <div style="margin-top: 10px;">
109
+ <div style="font-weight: bold; font-size: 14px;">{name}</div>
110
+ <div style="color: gray; font-size: 12px;">{brand}</div>
111
+ <div style="color: green; margin: 4px 0;">฿{price}</div>
112
+ <div style="font-size: 12px; color: #555;">score: {score}</div>
113
+ </div>
114
  </div>
115
+ """
116
+ result_summary += f"{name} (score: {score}) | "
 
117
 
118
  html_output += "</div>"
119
 
120
  if not found:
121
+ html_output += '<div style="text-align: center; font-size: 18px; color: #a00; padding: 30px;">❌ ไม่พบสินค้าที่เกี่ยวข้องกับคำค้นนี้</div>'
122
+ return html_output
 
 
 
 
 
 
 
 
 
123
 
124
+ latest_query_result.update({
125
+ "query": corrected_query,
126
+ "result": result_summary.strip(),
127
+ "time": elapsed,
128
+ })
129
 
130
+ return html_output
131
 
132
+ # Feedback logging
133
  def log_feedback(feedback):
134
  try:
135
  now = datetime.now().strftime("%Y-%m-%d")
136
  table.create({
137
  "timestamp": now,
138
  "raw_query": latest_query_result["raw_query"],
 
139
  "query": latest_query_result["query"],
140
  "result": latest_query_result["result"],
141
  "time(second)": latest_query_result["time"],
 
145
  except Exception as e:
146
  return f"❌ Failed to save feedback: {str(e)}"
147
 
148
+ # Gradio UI
 
149
  with gr.Blocks() as demo:
150
  gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
151
 
152
+ query_input = gr.Textbox(label="พิมพ์คำค้นหา")
153
+ result_output = gr.HTML(label="📋 ผลลัพธ์")
 
 
 
 
 
 
 
154
 
155
  with gr.Row():
156
  match_btn = gr.Button("✅ ตรง")
 
158
 
159
  feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
160
 
161
+ query_input.submit(search_product, inputs=[query_input], outputs=result_output)
162
  match_btn.click(lambda: log_feedback("match"), outputs=feedback_status)
163
  not_match_btn.click(lambda: log_feedback("not_match"), outputs=feedback_status)
164
 
165
+ # Run
166
+ demo.launch(share=True)