Spaces:
Sleeping
Sleeping
| import json | |
| import numpy as np | |
| from sentence_transformers import util | |
| from typing import Tuple, List, Dict, Any, Optional | |
| from app.services.model_service import get_model, get_embeddings | |
# Norms are clamped to this floor before dividing, matching the epsilon used by
# torch.nn.functional.normalize inside sentence_transformers.util.cos_sim.
_EPS = 1e-12

# Replacement for "\n" in returned answers. Presumably meant to force Markdown
# hard line breaks, which CommonMark only recognizes with two or more trailing
# spaces (a single space, as previously used, is a no-op).
_MD_LINE_BREAK = "  \n"

# Fallback reply when nothing clears either threshold. In English: "Sorry, we
# could not find an answer to your question. Could you describe it in a little
# more detail?" and "No match".
_NO_MATCH_ANSWER = (
    "申し訳ありませんが、ご質問の答えを見つけることができませんでした。"
    "もう少し詳しく説明していただけますか？"
)
_NO_MATCH_SCORE = "一致なし"


def _best_match(
    query: np.ndarray, candidates: np.ndarray
) -> Optional[Tuple[int, float]]:
    """Return ``(index, score)`` of the candidate row most cosine-similar to *query*.

    Both the query and every candidate row are L2-normalized (norms clamped at
    ``_EPS``) before the dot product, so stored embeddings need not be
    pre-normalized. Returns ``None`` when *candidates* is empty.
    """
    matrix = np.asarray(candidates)
    if matrix.size == 0:
        return None
    if matrix.ndim == 1:
        # A single embedding vector is treated as a one-row matrix.
        matrix = matrix.reshape(1, -1)

    query_vec = np.asarray(query).reshape(-1)
    query_vec = query_vec / max(float(np.linalg.norm(query_vec)), _EPS)
    row_norms = np.maximum(np.linalg.norm(matrix, axis=1), _EPS)
    scores = (matrix @ query_vec) / row_norms

    best = int(np.argmax(scores))
    return best, float(scores[best])


def search_answer(
    user_input: str,
    model,
    question_embeddings: np.ndarray,
    answer_embeddings: np.ndarray,
    threshold_q: float,
    threshold_a: float,
    answers: List[str],
) -> Tuple[str, str]:
    """Find the stored answer that best matches *user_input* by cosine similarity.

    The input is first compared against the stored question embeddings. Only if
    the best question score is below *threshold_q* is it compared against the
    stored answer embeddings, using *threshold_a*. Row ``i`` of either embedding
    matrix must correspond to ``answers[i]``.

    Args:
        user_input: The user's query text.
        model: Object exposing a SentenceTransformer-style ``encode`` method.
        question_embeddings: One embedding row per stored question.
        answer_embeddings: One embedding row per stored answer.
        threshold_q: Minimum question-similarity score for a match.
        threshold_a: Minimum answer-similarity score for a match.
        answers: Answer texts, index-aligned with the embedding rows.

    Returns:
        ``(answer_text, score)``. ``score`` is the similarity formatted with two
        decimals. If neither stage clears its threshold, or both embedding sets
        are empty, a fixed Japanese "no answer found" message and the label
        "一致なし" ("no match") are returned instead.
    """
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        normalize_embeddings=True,
    )

    # Questions first, then answers; the second stage only runs if the first misses.
    stages = (
        (question_embeddings, threshold_q),
        (answer_embeddings, threshold_a),
    )
    for embeddings, threshold in stages:
        match = _best_match(user_embedding, embeddings)
        if match is None:
            continue
        idx, score = match
        if score >= threshold:
            return answers[idx].replace("\n", _MD_LINE_BREAK), f"{score:.2f}"

    return _NO_MATCH_ANSWER, _NO_MATCH_SCORE
def predict_answer(
    user_input: str, threshold_q: float = 0.5, threshold_a: float = 0.5
) -> Dict[str, Any]:
    """Answer *user_input* from the globally loaded Q&A knowledge base.

    Args:
        user_input: The user's question text.
        threshold_q: Minimum similarity against stored questions for a match.
        threshold_a: Minimum similarity against stored answers, consulted only
            when no stored question matches.

    Returns:
        On success: ``{"status": "success", "answer": str, "score": str}``.
        On failure: ``{"status": "error", "message": str}``. Failure covers
        embeddings or Q&A data not being available, as well as any exception
        raised while loading, encoding, or searching. Exceptions are reported
        through this dict rather than propagated.
    """
    try:
        model = get_model()
        question_embeddings, answer_embeddings, qa_data = get_embeddings()
        if question_embeddings is None or answer_embeddings is None or qa_data is None:
            return {
                "status": "error",
                "message": "Embeddings not found. Please create embeddings first.",
            }

        # Answer texts are index-aligned with the embedding rows.
        answers = [item["answer"] for item in qa_data]
        answer, score = search_answer(
            user_input,
            model,
            question_embeddings,
            answer_embeddings,
            threshold_q,
            threshold_a,
            answers,
        )
        return {"status": "success", "answer": answer, "score": score}
    except Exception as exc:  # service boundary: report, never propagate
        import logging

        # Keep the traceback in server logs; the caller only gets the message.
        logging.getLogger(__name__).exception("predict_answer failed")
        # str(exc) is "" for exceptions raised without arguments; fall back to
        # the class name so the client never receives an empty error message.
        return {"status": "error", "message": str(exc) or type(exc).__name__}