from typing import List

from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field


class DocumentRelevance(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


class HallucinationCheck(BaseModel):
    """Binary score for whether the generated answer is grounded in the retrieved facts."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


class AnswerQuality(BaseModel):
    """Binary score for whether the answer addresses the question."""

    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )
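
# Each schema above is bound to the LLM via `with_structured_output`, so a grader
# call is parsed into a single 'yes'/'no' `binary_score` field instead of free-form text.

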
def create_llm_grader(grader_type: str, llm):
    """
    Create an LLM grader chain for the specified grader type.

    Args:
        grader_type (str): Type of grader to create: 'document_relevance',
            'hallucination', or 'answer_quality'
        llm: Chat model used for grading (must support structured output)

    Returns:
        Runnable: Grading chain (prompt piped into the structured-output LLM)
    """
    if grader_type == "document_relevance":
        structured_llm_grader = llm.with_structured_output(DocumentRelevance)
        system = """You are a grader assessing relevance of a retrieved document to a user question.
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ])

    elif grader_type == "hallucination":
        structured_llm_grader = llm.with_structured_output(HallucinationCheck)
        system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
        Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
        ])

    elif grader_type == "answer_quality":
        structured_llm_grader = llm.with_structured_output(AnswerQuality)
        system = """You are a grader assessing whether an answer addresses / resolves a question.
        Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
        ])

    else:
        raise ValueError(f"Unknown grader type: {grader_type}")

    return prompt | structured_llm_grader
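
# Illustrative usage sketch (assumes a locally served Ollama model; values are placeholders):
#
#     grader = create_llm_grader("document_relevance", ChatOllama(model="llama3.2"))
#     result = grader.invoke({"question": "...", "document": "..."})
#     result.binary_score  # -> 'yes' or 'no'

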
def grade_document_relevance(question: str, document: str, llm):
    """
    Grade the relevance of a document to a given question.

    Args:
        question (str): User's question
        document (str): Retrieved document content
        llm: Chat model used for grading

    Returns:
        str: Binary score ('yes' or 'no')
    """
    grader = create_llm_grader("document_relevance", llm)
    result = grader.invoke({"question": question, "document": document})
    return result.binary_score
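
# Illustrative (hypothetical `docs`, `question`, `llm` variables): retrieved chunks
# can be filtered with this grader before generation, e.g.
#     kept = [d for d in docs if grade_document_relevance(question, d, llm) == "yes"]

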
def check_hallucination(documents: List[str], generation: str, llm):
    """
    Check whether the generation is grounded in the provided documents.

    Args:
        documents (List[str]): List of source documents
        generation (str): LLM-generated answer
        llm: Chat model used for grading

    Returns:
        str: Binary score ('yes' or 'no')
    """
    grader = create_llm_grader("hallucination", llm)
    result = grader.invoke({"documents": documents, "generation": generation})
    return result.binary_score


def grade_answer_quality(question: str, generation: str, llm):
    """
    Grade how well the answer addresses the question.

    Args:
        question (str): User's original question
        generation (str): LLM-generated answer
        llm: Chat model used for grading

    Returns:
        str: Binary score ('yes' or 'no')
    """
    grader = create_llm_grader("answer_quality", llm)
    result = grader.invoke({"question": question, "generation": generation})
    return result.binary_score


if __name__ == "__main__":
    test_question = "What are the types of agent memory?"
    test_document = "Agent memory can be classified into different types such as episodic, semantic, and working memory."
    test_generation = "Agent memory includes episodic memory for storing experiences, semantic memory for general knowledge, and working memory for immediate processing."
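    # Requires a locally running Ollama server with the llama3.2 model pulled
    # (e.g. `ollama pull llama3.2`); any locally available chat model can be swapped in.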
    llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)

    print("Document Relevance:", grade_document_relevance(test_question, test_document, llm))
    print("Hallucination Check:", check_hallucination([test_document], test_generation, llm))
    print("Answer Quality:", grade_answer_quality(test_question, test_generation, llm))