Gourisankar Padihary
commited on
Commit
·
c433668
1
Parent(s):
d93e32b
Model name update
Browse files- config.py +2 -0
- generator/initialize_llm.py +4 -2
config.py
CHANGED
@@ -4,6 +4,8 @@ class ConfigConstants:
|
|
4 |
DATA_SET_NAMES = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
|
5 |
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
|
6 |
RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
|
|
|
|
|
7 |
DEFAULT_CHUNK_SIZE = 1000
|
8 |
CHUNK_OVERLAP = 200
|
9 |
|
|
|
4 |
DATA_SET_NAMES = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
|
5 |
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
|
6 |
RE_RANKER_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
|
7 |
+
GENERATION_MODEL_NAME = 'mixtral-8x7b-32768'
|
8 |
+
VALIDATION_MODEL_NAME = 'llama3-70b-8192'
|
9 |
DEFAULT_CHUNK_SIZE = 1000
|
10 |
CHUNK_OVERLAP = 200
|
11 |
|
generator/initialize_llm.py
CHANGED
@@ -2,9 +2,11 @@ import logging
|
|
2 |
import os
|
3 |
from langchain_groq import ChatGroq
|
4 |
|
|
|
|
|
5 |
def initialize_generation_llm():
|
6 |
os.environ["GROQ_API_KEY"] = "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"
|
7 |
-
model_name =
|
8 |
llm = ChatGroq(model=model_name, temperature=0.7)
|
9 |
llm.name = model_name
|
10 |
logging.info(f'Generation LLM {model_name} initialized')
|
@@ -12,7 +14,7 @@ def initialize_generation_llm():
|
|
12 |
|
13 |
def initialize_validation_llm():
|
14 |
os.environ["GROQ_API_KEY"] = "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"
|
15 |
-
model_name =
|
16 |
llm = ChatGroq(model=model_name, temperature=0.7)
|
17 |
llm.name = model_name
|
18 |
logging.info(f'Validation LLM {model_name} initialized')
|
|
|
2 |
import os
|
3 |
from langchain_groq import ChatGroq
|
4 |
|
5 |
+
from config import ConfigConstants
|
6 |
+
|
7 |
def initialize_generation_llm():
|
8 |
os.environ["GROQ_API_KEY"] = "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"
|
9 |
+
model_name = ConfigConstants.GENERATION_MODEL_NAME
|
10 |
llm = ChatGroq(model=model_name, temperature=0.7)
|
11 |
llm.name = model_name
|
12 |
logging.info(f'Generation LLM {model_name} initialized')
|
|
|
14 |
|
15 |
def initialize_validation_llm():
|
16 |
os.environ["GROQ_API_KEY"] = "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"
|
17 |
+
model_name = ConfigConstants.VALIDATION_MODEL_NAME
|
18 |
llm = ChatGroq(model=model_name, temperature=0.7)
|
19 |
llm.name = model_name
|
20 |
logging.info(f'Validation LLM {model_name} initialized')
|