File size: 1,797 Bytes
db17bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_ollama import ChatOllama

def create_query_rewriter(llm):
    """
    Build a chain that rephrases a question for vector-store retrieval.

    The chain is ``prompt | llm | StrOutputParser()``:
    - The prompt pairs a fixed system instruction with a human turn that embeds the user's question.
    - ``llm`` answers that prompt.
    - The parser reduces the model's reply to plain text.

    Args:
        llm: Chat model placed between the prompt template and the string
            output parser.

    Returns:
        The composed ``prompt | llm | parser`` chain. Call
        ``.invoke({"question": ...})`` on it to get the rewritten question as
        a string.
    """
    # Kept as explicit pieces so the exact whitespace the model sees is visible:
    # a trailing space after "optimized" and a 4-space indent on the next line.
    system_instructions = (
        "You are a question re-writer that converts an input question to a better version that is optimized \n"
        "    for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."
    )
    # {question} is the only template variable; callers must supply it on invoke.
    human_template = "Here is the initial question: \n\n {question} \n Formulate an improved question."

    message_specs = [
        ("system", system_instructions),
        ("human", human_template),
    ]
    prompt = ChatPromptTemplate.from_messages(message_specs)
    to_text = StrOutputParser()
    return prompt | llm | to_text

def rewrite_query(question: str, llm) -> str:
    """
    Rewrite a given query to optimize retrieval.

    This is best-effort. If running the rewriter chain raises, or the model
    returns only whitespace, the original ``question`` is returned unchanged.
    The caller therefore always gets a usable retrieval query.

    Args:
        question (str): Original user question.
        llm: Chat model passed to ``create_query_rewriter`` to build the
            rewrite chain.

    Returns:
        str: The rewritten query with surrounding whitespace stripped, or
            ``question`` itself when rewriting fails or yields empty output.

    Raises:
        Errors raised while *building* the chain are not caught. Only errors
        raised while invoking it fall back to ``question``.
    """
    query_rewriter = create_query_rewriter(llm)
    try:
        rewritten_query = query_rewriter.invoke({"question": question})
    except Exception as e:
        # Deliberate best-effort fallback: a failed rewrite should degrade
        # retrieval quality, not abort the caller. Log instead of print so
        # the application decides where the message goes.
        logging.getLogger(__name__).warning("Query rewriting error: %s", e)
        return question

    # An empty or whitespace-only completion is useless as a retrieval query
    # and would silently replace a perfectly good question, so fall back.
    rewritten_query = rewritten_query.strip()
    return rewritten_query or question

if __name__ == "__main__":
    # Demo: rewrite a few sample questions with a local Ollama chat model.
    # Each original is printed next to its rewrite. This presumably needs a
    # reachable Ollama server with the "llama3.2" model; confirm locally.
    demo_llm = ChatOllama(
        model="llama3.2",
        temperature=0.1,
        num_predict=256,
        top_p=0.5,
    )
    sample_questions = [
        "Tell me about AI agents",
        "What do we know about memory in AI systems?",
        "Bears draft strategy",
    ]

    for sample in sample_questions:
        improved = rewrite_query(sample, demo_llm)
        # One print emits both lines plus the blank separator line.
        print(f"Original: {sample}\nRewritten: {improved}\n")