# SexBot / utils/eval.py
# Uploaded by Pew404 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 318db6e (verified).
from llama_index.core.evaluation import RetrieverEvaluator
from llama_index.llms.ollama import Ollama
from llama_index.core import Settings
from utils.recursive_retrieve import get_bm25_recursive_retriever, get_hybrid_recursive_retriever, get_recursive_retriever, prepare_nodes
import nest_asyncio, os, asyncio
# Patch asyncio so event loops may nest; this lets asyncio.run() at the bottom
# of the script work even if a loop is already running (e.g. in a notebook).
nest_asyncio.apply()

# Ollama-served model. It is registered as llama_index's global default LLM and
# is also passed explicitly to the question generator further down.
llm = Ollama(model="pornchat", request_timeout=120)
Settings.llm = llm
# Corpus root: each immediate subdirectory is treated as one "data type".
# Both paths can be overridden via environment variables; the defaults are the
# original hard-coded locations, so existing behavior is unchanged.
data_dir = os.environ.get("EVAL_DATA_DIR", "/data1/home/purui/projects/chatbot/data/txt")
# Index directory handed to the utils.recursive_retrieve helpers (how they use
# it - load vs. persist - is defined there, not here).
index_dir = os.environ.get("EVAL_INDEX_DIR", "/data1/home/purui/projects/chatbot/kb")

# Build all three retriever variants. Only bm25_retriever is evaluated below;
# the other two are kept so they can be swapped into the evaluator.
bm25_retriever = get_bm25_recursive_retriever(data_dir, index_dir)
hybrid_retriever = get_hybrid_recursive_retriever(data_dir, index_dir)
basic_retriever = get_recursive_retriever(data_dir, index_dir)

# Discover data types (immediate subdirectories of data_dir). os.listdir()
# returns entries in arbitrary order, so sort to make the node order - and
# therefore the generated QA dataset - reproducible across runs.
data_types = sorted(
    entry
    for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)

# Collect the prepared nodes of every data type into one flat list.
all_nodes = []
for data_type in data_types:
    nodes = prepare_nodes(os.path.join(data_dir, data_type), index_dir, data_type=data_type)
    all_nodes.extend(nodes)
# Score the BM25 retriever with mean reciprocal rank ("mrr") and hit rate.
# To evaluate another variant, pass hybrid_retriever or basic_retriever instead.
retriever_evaluator = RetrieverEvaluator.from_metric_names(
    metric_names=["mrr", "hit_rate"],
    retriever=bm25_retriever,
)
from llama_index.core.evaluation import generate_question_context_pairs

# Fail fast with an explicit message rather than a NameError or an empty dataset.
if not all_nodes:
    raise RuntimeError(f"No nodes were prepared from the subdirectories of {data_dir!r}")

# Build a synthetic evaluation set by having the LLM write questions for the
# nodes (2 per chunk); per the llama_index docs, each question is paired with
# the node it came from as its expected context.
# Use all_nodes (every data type): the loop variable `nodes` only holds the
# last data type's nodes, so passing it silently evaluated a single data type.
qa_dataset = generate_question_context_pairs(
    all_nodes, llm=llm, num_questions_per_chunk=2
)
async def eval(retriever_evaluator: RetrieverEvaluator, qa_dataset):
    """Run ``retriever_evaluator`` over every query in ``qa_dataset``.

    Thin async wrapper around ``RetrieverEvaluator.aevaluate_dataset`` with a
    progress bar enabled. Note: this name shadows the builtin ``eval`` within
    this module.

    Args:
        retriever_evaluator: Evaluator wrapping the retriever under test.
        qa_dataset: Question/context dataset, presumably the one produced by
            ``generate_question_context_pairs``.

    Returns:
        Whatever ``aevaluate_dataset`` returns (per the llama_index docs, one
        result object per query).
    """
    return await retriever_evaluator.aevaluate_dataset(
        qa_dataset,
        show_progress=True,
    )
# Drive the async evaluation to completion and dump the raw per-query results.
result = asyncio.run(eval(retriever_evaluator, qa_dataset))
print(result)

# Summarize: the raw dump above is one entry per query, so also report the mean
# of each metric. Metric keys come from the first result, so the summary tracks
# whichever metrics the evaluator was built with.
if result:
    for metric_name in result[0].metric_vals_dict:
        mean_score = sum(r.metric_vals_dict[metric_name] for r in result) / len(result)
        print(f"mean {metric_name}: {mean_score:.4f}")