import logging
import json
from contextlib import asynccontextmanager
from typing import Any, List, Tuple
import random

from fastapi import FastAPI
from pydantic import BaseModel
from FlagEmbedding import BGEM3FlagModel, FlagReranker
from starlette.requests import Request
import torch

random.seed(42)

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def get_data(model):
    # Load the FAQ file, keep only English entries, and pre-compute
    # ColBERT vectors for the FAQ questions with the retriever model.
    with open("data/paris-2024-faq.json") as f:
        data = json.load(f)
    data = [it for it in data if it['lang'] == 'en']
    questions = [it['label'] for it in data]
    q_embeddings = model.encode(questions, return_dense=False, return_sparse=False, return_colbert_vecs=True)
    return q_embeddings['colbert_vecs'], questions, [it['body'] for it in data]
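# A minimal sketch of the expected shape of data/paris-2024-faq.json, inferred
# from the fields accessed above ('lang', 'label', 'body'); the values below are
# purely illustrative, not taken from the real file:
#
#   [
#     {"lang": "en", "label": "Question text", "body": "Answer text"},
#     ...
#   ]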
class InputLoad(BaseModel):
    question: str

class ResponseLoad(BaseModel):
    answer: str

class ML(BaseModel):
    # Container for the loaded models and the pre-computed FAQ data
    # (question ColBERT vectors, question texts, answer texts).
    retriever: Any
    ranker: Any
    data: Tuple[List[Any], List[str], List[str]]
def load_models(app: FastAPI) -> FastAPI:
    retriever = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
    ranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
    ml = ML(
        retriever=retriever,
        ranker=ranker,
        data=get_data(retriever),
    )
    app.ml = ml
    return app
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load models once at startup and attach them to the app instance,
    # so request handlers can reach them via request.app.
    app = load_models(app=app)
    yield
app = FastAPI(lifespan=lifespan)

@app.get("/health")  # route path assumed for illustration
def health_check():
    return {"server": "running"}
@app.post("/ask")  # route path assumed for illustration
async def receive(input_load: InputLoad, request: Request) -> ResponseLoad:
    # Retrieve candidate FAQ entries, rerank them, and phrase the final answer
    # according to how confident the retriever was.
    ml: ML = request.app.ml
    candidate_indices, candidate_scores = get_candidates(input_load.question, ml)
    answer_candidate, rank_score, retriever_score = rerank_candidates(input_load.question, candidate_indices, candidate_scores, ml)
    answer = get_final_answer(answer_candidate, retriever_score)
    return ResponseLoad(answer=answer)
def get_candidates(question, ml, topk=5):
    # Embed the incoming question and score it against every FAQ question
    # with the ColBERT late-interaction score, keeping the top-k matches.
    question_emb = ml.retriever.encode([question], return_dense=False, return_sparse=False, return_colbert_vecs=True)
    question_emb = question_emb['colbert_vecs'][0]
    scores = [ml.retriever.colbert_score(question_emb, faq_emb) for faq_emb in ml.data[0]]
    scores_tensor = torch.stack(scores)
    top_values, top_indices = torch.topk(scores_tensor, topk)
    return top_indices.tolist(), top_values.tolist()
def rerank_candidates(question, indices, values, ml):
    # Re-score the retrieved answers with the cross-encoder reranker and return
    # the best answer together with its reranker and retriever scores.
    candidate_answers = [ml.data[2][_ind] for _ind in indices]
    scores = ml.ranker.compute_score([[question, it] for it in candidate_answers])
    rank_score = max(scores)
    rank_ind = scores.index(rank_score)
    retriever_score = values[rank_ind]
    return candidate_answers[rank_ind], rank_score, retriever_score
def get_final_answer(answer, retriever_score):
    logger.info(f"Retriever score: {retriever_score}")
    if retriever_score < 0.65:
        # nothing relevant found!
        return random.sample(NOT_FOUND_ANSWERS, k=1)[0]
    elif retriever_score < 0.8:
        # might be relevant, but let's be careful
        return f"{random.sample(ROUGH_MATCH_INTROS, k=1)[0]}\n{answer}"
    else:
        # good match
        return f"{random.sample(GOOD_MATCH_INTROS, k=1)[0]}\n{answer}\n{random.sample(GOOD_MATCH_ENDS, k=1)[0]}"
NOT_FOUND_ANSWERS = [
    "I'm sorry, but I couldn't find any information related to your question in my knowledge base.",
    "Apologies, but I don't have the information you're looking for at the moment.",
    "I’m sorry, I couldn’t locate any relevant details in my current data.",
    "Unfortunately, I wasn't able to find an answer to your query. Can I help with something else?",
    "I'm afraid I don't have the information you need right now. Please feel free to ask another question.",
    "Sorry, I couldn't find anything that matches your question in my knowledge base.",
    "I apologize, but I wasn't able to retrieve information related to your query.",
    "I'm sorry, but it looks like I don't have an answer for that. Is there anything else I can assist with?",
    "Regrettably, I couldn't find the information you requested. Can I help you with anything else?",
    "I’m sorry, but I don't have the details you're seeking in my knowledge database.",
]
| GOOD_MATCH_INTROS = ["Super!"] | |
| GOOD_MATCH_ENDS = ["Hopes this helps!"] | |
| ROUGH_MATCH_INTROS = ["Not sure if that answers your question!"] | |