Gourisankar Padihary commited on
Commit
c433668
·
1 Parent(s): d93e32b

Model name update

Browse files
Files changed (2) hide show
  1. config.py +2 -0
  2. generator/initialize_llm.py +4 -2
config.py CHANGED
@@ -4,6 +4,8 @@ class ConfigConstants:
4
  DATA_SET_NAMES = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
5
  EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
6
  RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
 
 
7
  DEFAULT_CHUNK_SIZE = 1000
8
  CHUNK_OVERLAP = 200
9
 
 
4
  DATA_SET_NAMES = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
5
  EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
6
  RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
7
+ GENERATION_MODEL_NAME = 'mixtral-8x7b-32768'
8
+ VALIDATION_MODEL_NAME = 'llama3-70b-8192'
9
  DEFAULT_CHUNK_SIZE = 1000
10
  CHUNK_OVERLAP = 200
11
 
generator/initialize_llm.py CHANGED
@@ -2,9 +2,11 @@ import logging
2
  import os
3
  from langchain_groq import ChatGroq
4
 
 
 
5
  def initialize_generation_llm():
6
  os.environ["GROQ_API_KEY"] = "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"
7
- model_name = "mixtral-8x7b-32768"
8
  llm = ChatGroq(model=model_name, temperature=0.7)
9
  llm.name = model_name
10
  logging.info(f'Generation LLM {model_name} initialized')
@@ -12,7 +14,7 @@ def initialize_generation_llm():
12
 
13
  def initialize_validation_llm():
14
  os.environ["GROQ_API_KEY"] = "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"
15
- model_name = "llama3-70b-8192"
16
  llm = ChatGroq(model=model_name, temperature=0.7)
17
  llm.name = model_name
18
  logging.info(f'Validation LLM {model_name} initialized')
 
2
  import os
3
  from langchain_groq import ChatGroq
4
 
5
+ from config import ConfigConstants
6
+
7
  def initialize_generation_llm():
8
  os.environ["GROQ_API_KEY"] = "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"
9
+ model_name = ConfigConstants.GENERATION_MODEL_NAME
10
  llm = ChatGroq(model=model_name, temperature=0.7)
11
  llm.name = model_name
12
  logging.info(f'Generation LLM {model_name} initialized')
 
14
 
15
  def initialize_validation_llm():
16
  os.environ["GROQ_API_KEY"] = "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"
17
+ model_name = ConfigConstants.VALIDATION_MODEL_NAME
18
  llm = ChatGroq(model=model_name, temperature=0.7)
19
  llm.name = model_name
20
  logging.info(f'Validation LLM {model_name} initialized')