# classification_chain.py
import os

from langchain.chains import LLMChain
from langchain_groq import ChatGroq

from prompts import classification_prompt
def get_classification_chain() -> LLMChain:
    """
    Builds the classification chain (LLMChain) using ChatGroq and the
    classification prompt.
    """
    # Initialize the ChatGroq model (Gemma2-9b-It); GROQ_API_KEY must be
    # set in the environment.
    chat_groq_model = ChatGroq(
        model="Gemma2-9b-It",
        groq_api_key=os.environ["GROQ_API_KEY"],
    )

    # Wrap the model and the classification prompt in an LLMChain.
    classification_chain = LLMChain(
        llm=chat_groq_model,
        prompt=classification_prompt,
    )
    return classification_chain
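
# Note: classify_with_history() below rebuilds the chain on every call. If
# that overhead matters, the chain can be memoized. A minimal sketch using
# functools.lru_cache; the name get_classification_chain_cached is an
# illustrative addition, not part of the original module:
from functools import lru_cache

@lru_cache(maxsize=1)
def get_classification_chain_cached() -> LLMChain:
    # Build the chain once and reuse it on subsequent calls.
    return get_classification_chain()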

def classify_with_history(query: str, chat_history: list) -> str:
    """
    Classifies a user query using the context of the previous conversation.

    Each entry in chat_history is expected to be a dict with a "content" key.
    """
    # Prepend the prior user messages so the model sees the new query in
    # context, then pass the combined text through the prompt's "query" slot.
    # Building the lines as a list avoids a stray leading newline when
    # chat_history is empty.
    lines = [f"User: {msg['content']}" for msg in chat_history]
    lines.append(f"User: {query}")
    context = "\n".join(lines)

    classification_result = get_classification_chain().run({"query": context})
    return classification_result
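
# Usage sketch (assumes GROQ_API_KEY is exported and that
# prompts.classification_prompt takes a single "query" input variable; the
# sample messages below are illustrative only):
if __name__ == "__main__":
    history = [
        {"content": "I need help with my order"},
        {"content": "It was placed last Monday"},
    ]
    print(classify_with_history("Where is my refund?", history))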