Phoenix21 commited on
Commit
9fde2fd
·
verified ·
1 Parent(s): 6c4ab66

used chain in classification chain.py

Browse files
Files changed (1) hide show
  1. classification_chain.py +5 -6
classification_chain.py CHANGED
@@ -2,10 +2,11 @@
2
  import os
3
  from langchain.chains import LLMChain
4
  from langchain_groq import ChatGroq
5
-
6
- # We'll import the classification_prompt from prompts.py
7
  from prompts import classification_prompt
8
 
 
9
  def get_classification_chain() -> LLMChain:
10
  """
11
  Builds the classification chain (LLMChain) using ChatGroq and the classification prompt.
@@ -15,10 +16,8 @@ def get_classification_chain() -> LLMChain:
15
  model="Gemma2-9b-It",
16
  groq_api_key=os.environ["GROQ_API_KEY"] # must be set in environment
17
  )
 
18
 
19
  # Build an LLMChain
20
- classification_chain = LLMChain(
21
- llm=chat_groq_model,
22
- prompt=classification_prompt
23
- )
24
  return classification_chain
 
2
  import os
3
  from langchain.chains import LLMChain
4
  from langchain_groq import ChatGroq
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import ChatPromptTemplate
7
  from prompts import classification_prompt
8
 
9
+
10
  def get_classification_chain() -> LLMChain:
11
  """
12
  Builds the classification chain (LLMChain) using ChatGroq and the classification prompt.
 
16
  model="Gemma2-9b-It",
17
  groq_api_key=os.environ["GROQ_API_KEY"] # must be set in environment
18
  )
19
+ prompt=
20
 
21
  # Build an LLMChain
22
+ classification_chain = classification_prompt|chat_groq_model|output_parser
 
 
 
23
  return classification_chain