Spaces:
Sleeping
Sleeping
import logging | |
import json | |
from contextlib import asynccontextmanager | |
from typing import Any, List, Tuple | |
import random | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from FlagEmbedding import BGEM3FlagModel, FlagReranker | |
from starlette.requests import Request | |
import torch | |
# Set up logging: default root handler plus an INFO-level module logger, so
# the retriever scores logged by get_final_answer are emitted.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Fixed seed: the canned phrases picked with random.sample follow the same
# sequence after every restart.
random.seed(42)
def get_data(model):
    """Load the English FAQ entries and pre-compute their ColBERT vectors.

    Args:
        model: Indexable container whose first element is the embedding
            model (``load_models`` passes a one-element tuple). That element
            must provide ``encode(...)`` returning a mapping with a
            ``'colbert_vecs'`` entry.

    Returns:
        A 3-tuple ``(colbert_vecs, questions, answers)``. ``questions`` are
        the ``label`` fields and ``answers`` the ``body`` fields of the
        entries whose ``lang`` is ``'en'``, in file order.
    """
    # Read as UTF-8 explicitly: open()'s default encoding depends on the
    # platform locale, which can corrupt or reject non-ASCII FAQ text.
    with open("data/paris-2024-faq.json", encoding="utf-8") as f:
        entries = json.load(f)

    english = [entry for entry in entries if entry['lang'] == 'en']
    questions = [entry['label'] for entry in english]
    answers = [entry['body'] for entry in english]

    # Only the multi-vector (ColBERT) representation is used for retrieval.
    q_embeddings = model[0].encode(
        questions,
        return_dense=False,
        return_sparse=False,
        return_colbert_vecs=True,
    )
    return q_embeddings['colbert_vecs'], questions, answers
class InputLoad(BaseModel):
    """Payload accepted by ``receive``."""

    # Free-text question to look up in the FAQ.
    question: str
class ResponseLoad(BaseModel):
    """Payload returned by ``receive``."""

    # Final reply text produced by get_final_answer.
    answer: str
class ML(BaseModel):
    """Bundle of the loaded models and the pre-computed FAQ data."""

    # Embedding model holder; callers index it with [0] before use.
    retriever: Any
    # Reranker object; rerank_candidates calls its compute_score(...).
    ranker: Any
    # (FAQ ColBERT vectors, FAQ questions, FAQ answers) as built by get_data.
    data: Tuple[List[Any], List[str], List[str]]
def load_models(app: FastAPI) -> FastAPI:
    """Load the retriever and reranker, index the FAQ, and attach them to *app*.

    The resulting ``ML`` bundle is stored as ``app.ml`` and the same ``app``
    object is returned.
    """
    # NOTE: the retriever is kept as a ONE-ELEMENT TUPLE on purpose.
    # get_data and get_candidates both access it as ``retriever[0]``, so
    # unwrapping it here would break them.
    bge_m3 = (BGEM3FlagModel('BAAI/bge-m3', use_fp16=True),)
    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)

    # Embed the FAQ questions once, up front.
    faq_data = get_data(bge_m3)

    app.ml = ML(retriever=bge_m3, ranker=reranker, data=faq_data)
    return app
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the models at startup, nothing at shutdown.

    Decorated with ``asynccontextmanager`` because the ``lifespan=``
    argument expects a callable returning an async context manager, not a
    bare async generator function.
    """
    # Everything before ``yield`` runs once before requests are served;
    # load_models attaches the models to ``app.ml``.
    load_models(app=app)
    yield
# ASGI application object; ``lifespan`` calls load_models during startup.
app = FastAPI(lifespan=lifespan)
def health_check():
    """Return a static payload indicating that the server is running."""
    # NOTE(review): no route decorator is visible on this function in this
    # copy of the file; confirm it is actually registered on ``app``.
    status = {"server": "running"}
    return status
async def receive(input_load: InputLoad, request: Request) -> ResponseLoad:
    """Answer a question: retrieve FAQ candidates, rerank them, build the reply.

    Args:
        input_load: Payload carrying the user's question.
        request: Incoming request; the loaded models are read from
            ``request.app.ml``.

    Returns:
        A ``ResponseLoad`` whose ``answer`` is the text produced by
        ``get_final_answer``.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # copy of the file; confirm it is actually registered on ``app``.
    models: ML = request.app.ml
    question = input_load.question

    indices, retriever_scores = get_candidates(question, models)
    # The reranker score only selects the candidate; the reply wording is
    # chosen from that candidate's retriever score.
    best_answer, _rank_score, best_retriever_score = rerank_candidates(
        question, indices, retriever_scores, models
    )
    return ResponseLoad(answer=get_final_answer(best_answer, best_retriever_score))
def get_candidates(question, ml, topk=5):
    """Return the FAQ entries whose questions best match *question*.

    Args:
        question: The user's question text.
        ml: Object exposing ``retriever`` (indexable; element 0 is the
            embedding model) and ``data`` (element 0 holds the FAQ ColBERT
            vectors).
        topk: Maximum number of candidates to return. Fewer are returned
            when the FAQ has fewer than ``topk`` entries.

    Returns:
        ``(indices, scores)``: two parallel lists, best match first. The
        indices are positions in the FAQ lists held in ``ml.data``.
    """
    encoder = ml.retriever[0]
    encoded = encoder.encode(
        [question],
        return_dense=False,
        return_sparse=False,
        return_colbert_vecs=True,
    )
    question_vecs = encoded['colbert_vecs'][0]

    # One late-interaction score per FAQ question, stacked into a 1-D tensor.
    scores = torch.stack(
        [encoder.colbert_score(question_vecs, faq_vecs) for faq_vecs in ml.data[0]]
    )

    # torch.topk raises if k exceeds the number of elements, so clamp it.
    k = min(topk, scores.shape[0])
    top_values, top_indices = torch.topk(scores, k)
    return top_indices.tolist(), top_values.tolist()
def rerank_candidates(question, indices, values, ml):
    """Pick the best candidate answer according to the reranker.

    Args:
        question: The user's question text.
        indices: FAQ positions of the candidates (as from ``get_candidates``).
        values: Retriever scores parallel to ``indices``.
        ml: Object exposing ``ranker`` (with ``compute_score``) and ``data``
            (element 2 holds the FAQ answers).

    Returns:
        ``(answer, rank_score, retriever_score)`` for the candidate with the
        highest reranker score. Ties go to the earliest candidate.

    Raises:
        ValueError: If ``indices`` is empty.
    """
    if not indices:
        raise ValueError("no candidates to rerank")

    candidate_answers = [ml.data[2][_ind] for _ind in indices]
    scores = ml.ranker.compute_score([[question, it] for it in candidate_answers])

    # The reranker may return a bare number instead of a list when given a
    # single pair (some FlagEmbedding releases do); normalise to a list.
    if not hasattr(scores, "__len__"):
        scores = [scores]

    # Single-pass argmax; max() keeps the first of equal scores.
    rank_ind = max(range(len(scores)), key=lambda i: scores[i])
    return candidate_answers[rank_ind], scores[rank_ind], values[rank_ind]
def get_final_answer(answer, retriever_score):
    """Wrap *answer* in wording that reflects how confident the match is.

    Args:
        answer: FAQ answer text selected by the reranker.
        retriever_score: Retriever score of that answer.

    Returns:
        A random apology when the score is below 0.65 (``answer`` is
        dropped); a cautious intro followed by ``answer`` when it is below
        0.8; otherwise an intro, ``answer`` and a closing line.
    """
    logger.info(f"Retriever score: {retriever_score}")

    # Nothing relevant found: reply with an apology only.
    if retriever_score < 0.65:
        return random.sample(NOT_FOUND_ANSWERS, k=1)[0]

    # Might be relevant, but let's be careful.
    if retriever_score < 0.8:
        cautious_intro = random.sample(ROUGH_MATCH_INTROS, k=1)[0]
        return f"{cautious_intro}\n{answer}"

    # Good match: intro first, then the closing line (same draw order).
    intro = random.sample(GOOD_MATCH_INTROS, k=1)[0]
    outro = random.sample(GOOD_MATCH_ENDS, k=1)[0]
    return f"{intro}\n{answer}\n{outro}"
# Apologies used by get_final_answer when no FAQ entry is relevant enough;
# one is drawn at random per reply.
NOT_FOUND_ANSWERS = [
    "I'm sorry, but I couldn't find any information related to your question in my knowledge base.",
    "Apologies, but I don't have the information you're looking for at the moment.",
    "I’m sorry, I couldn’t locate any relevant details in my current data.",
    "Unfortunately, I wasn't able to find an answer to your query. Can I help with something else?",
    "I'm afraid I don't have the information you need right now. Please feel free to ask another question.",
    "Sorry, I couldn't find anything that matches your question in my knowledge base.",
    "I apologize, but I wasn't able to retrieve information related to your query.",
    "I'm sorry, but it looks like I don't have an answer for that. Is there anything else I can assist with?",
    "Regrettably, I couldn't find the information you requested. Can I help you with anything else?",
    "I’m sorry, but I don't have the details you're seeking in my knowledge database.",
]
# Phrases wrapped around a confident answer by get_final_answer.
GOOD_MATCH_INTROS = ["Super!"]
# Fixed user-facing typo: was "Hopes this helps!".
GOOD_MATCH_ENDS = ["Hope this helps!"]
# Intro used when the match is plausible but not confident.
ROUGH_MATCH_INTROS = ["Not sure if that answers your question!"]