"""realtime-rag-pipeline: generator/compute_rmse_auc_roc_metrics.py

Compute per-question RMSE and overall AUC-ROC for the RAG evaluation
metrics (context relevance, context utilization) predicted by the pipeline.
"""
# root_mean_squared_error requires scikit-learn >= 1.4
from sklearn.metrics import roc_auc_score, root_mean_squared_error

from generator.compute_metrics import get_metrics
from generator.extract_attributes import extract_attributes
from generator.generate_response import generate_response
from retriever.retrieve_documents import retrieve_top_k_documents


def compute_rmse_auc_roc_metrics(llm, dataset, vector_store):
    # Lists to accumulate ground truths and predictions for AUC-ROC computation.
    # The adherence lists back the adherence branch, which is currently
    # commented out throughout this function.
    all_ground_truth_relevance = []
    all_predicted_relevance = []
    all_ground_truth_utilization = []
    all_predicted_utilization = []
    all_ground_truth_adherence = []
    all_predicted_adherence = []

    # Per-question RMSE scores
    relevance_scores = []
    utilization_scores = []
    adherence_scores = []
    for i, sample in enumerate(dataset):
        print(sample)
        sample_question = sample['question']

        # Extract ground-truth metrics from the dataset row
        ground_truth_relevance = sample['relevance_score']
        ground_truth_utilization = sample['utilization_score']
        ground_truth_completeness = sample['completeness_score']
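        # Note: completeness is read here but never scored; only relevance and
        # utilization feed the RMSE / AUC-ROC computations below.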

        # Step 1: Retrieve relevant documents
        relevant_docs = retrieve_top_k_documents(vector_store, sample_question, top_k=5)

        # Step 2: Generate a response using the LLM
        response, source_docs = generate_response(llm, vector_store, sample_question, relevant_docs)

        # Step 3: Extract attributes and compute the predicted metrics from them
        attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
        metrics = get_metrics(attributes, total_sentences)

        # Predicted metrics (continuous scores where possible)
        predicted_relevance = metrics['Context Relevance']
        predicted_utilization = metrics['Context Utilization']
        predicted_completeness = metrics['Completeness Score']

        # === RMSE on the continuous scores ===
        relevance_rmse = root_mean_squared_error([ground_truth_relevance], [predicted_relevance])
        utilization_rmse = root_mean_squared_error([ground_truth_utilization], [predicted_utilization])
        # adherence_rmse = root_mean_squared_error([ground_truth_adherence], [predicted_adherence])
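        # With a single (truth, prediction) pair per question, RMSE reduces to
        # the absolute error |ground_truth - predicted| for that question.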

        # === Binarize ground truths for AUC-ROC (threshold at 0.5) ===
        # The predictions are deliberately left continuous; see the note below.
        binary_ground_truth_relevance = 1 if ground_truth_relevance > 0.5 else 0
        binary_ground_truth_utilization = 1 if ground_truth_utilization > 0.5 else 0
        # binary_ground_truth_adherence = 1 if ground_truth_adherence > 0.5 else 0
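        # roc_auc_score expects binary labels but continuous scores: AUC-ROC
        # measures how well the raw scores rank positives above negatives, so
        # thresholding the predictions here would only discard information.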

        # === Accumulate data for overall AUC-ROC computation ===
        all_ground_truth_relevance.append(binary_ground_truth_relevance)
        all_predicted_relevance.append(predicted_relevance)  # continuous score, used for ranking
        all_ground_truth_utilization.append(binary_ground_truth_utilization)
        all_predicted_utilization.append(predicted_utilization)
        # all_ground_truth_adherence.append(binary_ground_truth_adherence)
        # all_predicted_adherence.append(predicted_adherence)

        # Store the per-question RMSE scores
        relevance_scores.append(relevance_rmse)
        utilization_scores.append(utilization_rmse)
        # adherence_scores.append(adherence_rmse)

        if i == 9:  # Stop after processing the first 10 rows
            break
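
    # With at most 10 evaluated rows, the binarized labels can easily end up
    # all in one class, in which case roc_auc_score below raises ValueError.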
    # === Compute AUC-ROC over the entire evaluated subset ===
    try:
        print(f"All Ground Truth Relevance: {all_ground_truth_relevance}")
        print(f"All Predicted Relevance: {all_predicted_relevance}")
        relevance_auc = roc_auc_score(all_ground_truth_relevance, all_predicted_relevance)
    except ValueError:
        # Only one class present in the labels
        relevance_auc = None
    try:
        print(f"All Ground Truth Utilization: {all_ground_truth_utilization}")
        print(f"All Predicted Utilization: {all_predicted_utilization}")
        utilization_auc = roc_auc_score(all_ground_truth_utilization, all_predicted_utilization)
    except ValueError:
        utilization_auc = None
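
    # Worked example: roc_auc_score([0, 1, 1], [0.2, 0.8, 0.6]) == 1.0, because
    # every positive sample is scored above the single negative one.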
print(f"Relevance RMSE (per question): {relevance_scores}")
print(f"Utilization RMSE (per question): {utilization_scores}")
#print(f"Adherence RMSE (per question): {adherence_scores}")
print(f"\nOverall Relevance AUC-ROC: {relevance_auc}")
print(f"Overall Utilization AUC-ROC: {utilization_auc}")