|
import asyncio
import os

import nest_asyncio
from llama_index.core import Settings
from llama_index.core.evaluation import RetrieverEvaluator, generate_question_context_pairs
from llama_index.llms.ollama import Ollama

from utils.recursive_retrieve import (
    get_bm25_recursive_retriever,
    get_hybrid_recursive_retriever,
    get_recursive_retriever,
    prepare_nodes,
)
|
|
|
# Patch asyncio so an already-running event loop can be re-entered.
# NOTE(review): presumably required because llama_index helpers used later
# drive their own event loop while this script also calls asyncio.run() — confirm.
nest_asyncio.apply()

# Local Ollama model. It is passed explicitly to the QA-pair generator later
# in the script and is also installed as llama_index's global default LLM.
llm = Ollama(model="pornchat", request_timeout=120)
Settings.llm = llm
|
|
|
# Corpus root and index storage location. The defaults are the original
# machine-specific paths. Set CHATBOT_DATA_DIR / CHATBOT_INDEX_DIR to run the
# evaluation against another location without editing the script.
data_dir = os.environ.get(
    "CHATBOT_DATA_DIR", "/data1/home/purui/projects/chatbot/data/txt"
)
index_dir = os.environ.get(
    "CHATBOT_INDEX_DIR", "/data1/home/purui/projects/chatbot/kb"
)

# Three recursive-retriever variants built over the same data/index paths.
# NOTE(review): only bm25_retriever is evaluated by this script. The hybrid and
# basic retrievers are built but unused here; either evaluate them too or stop
# building them.
bm25_retriever = get_bm25_recursive_retriever(data_dir, index_dir)
hybrid_retriever = get_hybrid_recursive_retriever(data_dir, index_dir)
basic_retriever = get_recursive_retriever(data_dir, index_dir)
|
|
|
# Every immediate sub-directory of data_dir is treated as one data type.
# The list is sorted so node order (and hence the generated QA dataset) does not
# depend on os.listdir(), whose order is arbitrary. The variable name `entry`
# avoids shadowing the builtin dir().
data_types = sorted(
    entry
    for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)
if not data_types:
    # Without this check the script would only fail later, with an obscure
    # NameError on `nodes`.
    raise RuntimeError(f"No data-type sub-directories found in {data_dir!r}")

all_nodes = []
for data_type in data_types:
    # `nodes` is rebound on every iteration. After the loop it holds only the
    # LAST data type's nodes; all_nodes holds the nodes of every data type.
    nodes = prepare_nodes(
        os.path.join(data_dir, data_type), index_dir, data_type=data_type
    )
    all_nodes.extend(nodes)
|
|
|
|
|
|
|
|
|
|
|
# Evaluator that scores the BM25 recursive retriever with mean reciprocal rank
# ("mrr") and hit rate ("hit_rate").
retriever_evaluator = RetrieverEvaluator.from_metric_names(
    ["mrr", "hit_rate"], retriever=bm25_retriever
)

# Synthetic evaluation set: the LLM writes questions for each chunk, and that
# chunk is the expected retrieval target. This uses all_nodes (every data type).
# The previous code used `nodes`, which only holds the last data type the
# collection loop processed. The retriever is built from the whole data_dir
# (presumably indexing every data type — confirm), so questions must cover the
# same corpus.
qa_dataset = generate_question_context_pairs(
    all_nodes, llm=llm, num_questions_per_chunk=2
)
|
|
|
async def eval(retriever_evaluator: RetrieverEvaluator, qa_dataset):
    """Evaluate a retriever over every query in a QA dataset.

    Args:
        retriever_evaluator: Evaluator wrapping the retriever and its metrics.
        qa_dataset: Question/context pairs to evaluate against.

    Returns:
        Whatever ``retriever_evaluator.aevaluate_dataset`` returns, awaited,
        with a progress bar enabled.
    """
    # NOTE(review): this name shadows the builtin eval(); kept unchanged so
    # existing references keep working.
    return await retriever_evaluator.aevaluate_dataset(
        qa_dataset,
        show_progress=True,
    )
|
|
|
# Run the async evaluation to completion and dump the raw per-query results.
result = asyncio.run(eval(retriever_evaluator, qa_dataset))
print(result)

# The raw list contains one result object per query. MRR and hit rate are
# defined as means over the dataset, so also report each metric's average.
# The metric names are taken from the results so they stay in sync with the
# evaluator's configuration.
if result:
    for metric_name in result[0].metric_vals_dict:
        mean_score = sum(r.metric_vals_dict[metric_name] for r in result) / len(result)
        print(f"{metric_name}: {mean_score:.4f}")
else:
    print("No evaluation results (empty QA dataset).")