adrienbrdne committed on
Commit
19491ad
·
verified ·
1 Parent(s): 8df620d

Update ki_gen/utils.py

Browse files
Files changed (1) hide show
  1. ki_gen/utils.py +21 -7
ki_gen/utils.py CHANGED
@@ -7,7 +7,10 @@ from typing import Annotated, Union
7
  from typing_extensions import TypedDict
8
 
9
  from langchain_community.graphs import Neo4jGraph
10
- from langchain_groq import ChatGroq
 
 
 
11
  from langchain_openai import ChatOpenAI
12
 
13
  from langgraph.checkpoint.sqlite import SqliteSaver
@@ -72,13 +75,17 @@ def _set_env(var: str, value: str = None):
72
  os.environ[var] = getpass.getpass(f"{var}: ")
73
 
74
 
75
- def init_app(openai_key : str = None, groq_key : str = None, langsmith_key : str = None):
 
76
  """
77
  Initialize app with user api keys and sets up proxy settings
78
  """
79
- _set_env("GROQ_API_KEY", value=os.getenv("groq_api_key"))
 
80
  _set_env("LANGSMITH_API_KEY", value=os.getenv("langsmith_api_key"))
81
  _set_env("OPENAI_API_KEY", value=os.getenv("openai_api_key"))
 
 
82
  os.environ["LANGSMITH_TRACING_V2"] = "true"
83
  os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
84
 
@@ -93,14 +100,22 @@ def clear_memory(memory, thread_id: str = "") -> None:
93
  #checkpoint = base.empty_checkpoint()
94
  #memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
95
 
96
- def get_model(model : str = "deepseek-r1-distill-llama-70b"):
 
97
  """
98
  Wrapper to return the correct llm object depending on the 'model' param
99
  """
100
  if model == "gpt-4o":
101
  llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
 
 
 
 
102
  else:
103
- llm = ChatGroq(model=model)
 
 
 
104
  return llm
105
 
106
 
@@ -148,5 +163,4 @@ class DocProcessorState(TypedDict):
148
  valid_docs : list[Union[str, dict]]
149
  docs_in_processing : list
150
  process_steps : list[Union[str,dict]]
151
- current_process_step : int
152
-
 
7
  from typing_extensions import TypedDict
8
 
9
  from langchain_community.graphs import Neo4jGraph
10
+ # Remove ChatGroq import
11
+ # from langchain_groq import ChatGroq
12
+ # Add ChatGoogleGenerativeAI import
13
+ from langchain_google_genai import ChatGoogleGenerativeAI
14
  from langchain_openai import ChatOpenAI
15
 
16
  from langgraph.checkpoint.sqlite import SqliteSaver
 
75
  os.environ[var] = getpass.getpass(f"{var}: ")
76
 
77
 
78
+ # Remove groq_key parameter
79
+ def init_app(openai_key : str = None, langsmith_key : str = None):
80
  """
81
  Initialize app with user api keys and sets up proxy settings
82
  """
83
+ # Remove setting GROQ_API_KEY
84
+ # _set_env("GROQ_API_KEY", value=os.getenv("groq_api_key"))
85
  _set_env("LANGSMITH_API_KEY", value=os.getenv("langsmith_api_key"))
86
  _set_env("OPENAI_API_KEY", value=os.getenv("openai_api_key"))
87
+ # Make sure GEMINI_API_KEY is set if needed elsewhere, though ChatGoogleGenerativeAI reads it automatically
88
+ _set_env("GEMINI_API_KEY", value=os.getenv("gemini_api_key"))
89
  os.environ["LANGSMITH_TRACING_V2"] = "true"
90
  os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
91
 
 
100
  #checkpoint = base.empty_checkpoint()
101
  #memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
102
 
103
# Update get_model to use ChatGoogleGenerativeAI
def get_model(model: str = "gemini-2.0-flash"):
    """
    Wrapper to return the correct llm object depending on the 'model' param.

    Args:
        model: Model identifier. "gpt-4o" routes to the Thales-hosted OpenAI
            endpoint; names starting with "gemini" route to Google Generative
            AI; any other name falls back to the default Gemini model with a
            printed warning.

    Returns:
        A LangChain chat-model instance (ChatOpenAI or ChatGoogleGenerativeAI).
    """
    # Original code read only the lowercase "gemini_api_key" env var, but
    # init_app() guarantees the uppercase "GEMINI_API_KEY" via _set_env().
    # Keep the lowercase lookup first (backward compatible) and fall back
    # to the uppercase variant so the key is found in either configuration.
    gemini_key = os.getenv("gemini_api_key") or os.getenv("GEMINI_API_KEY")

    if model == "gpt-4o":
        llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
    # Check for gemini models
    elif model.startswith("gemini"):
        # Pass the API key explicitly rather than relying on the library's
        # own environment lookup, so behavior is deterministic.
        llm = ChatGoogleGenerativeAI(model=model, google_api_key=gemini_key)
    else:
        # Unknown model names deliberately fall back to the default Gemini
        # model instead of raising, keeping the app usable with stale config.
        print(f"Warning: Model '{model}' not explicitly handled. Defaulting to Gemini.")
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=gemini_key)
    return llm
120
 
121
 
 
163
  valid_docs : list[Union[str, dict]]
164
  docs_in_processing : list
165
  process_steps : list[Union[str,dict]]
166
+ current_process_step : int