Spaces:
Sleeping
Sleeping
Update ki_gen/utils.py
Browse files- ki_gen/utils.py +21 -7
ki_gen/utils.py
CHANGED
@@ -7,7 +7,10 @@ from typing import Annotated, Union
|
|
7 |
from typing_extensions import TypedDict
|
8 |
|
9 |
from langchain_community.graphs import Neo4jGraph
|
10 |
-
|
|
|
|
|
|
|
11 |
from langchain_openai import ChatOpenAI
|
12 |
|
13 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
@@ -72,13 +75,17 @@ def _set_env(var: str, value: str = None):
|
|
72 |
os.environ[var] = getpass.getpass(f"{var}: ")
|
73 |
|
74 |
|
75 |
-
|
|
|
76 |
"""
|
77 |
Initialize app with user api keys and sets up proxy settings
|
78 |
"""
|
79 |
-
|
|
|
80 |
_set_env("LANGSMITH_API_KEY", value=os.getenv("langsmith_api_key"))
|
81 |
_set_env("OPENAI_API_KEY", value=os.getenv("openai_api_key"))
|
|
|
|
|
82 |
os.environ["LANGSMITH_TRACING_V2"] = "true"
|
83 |
os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
|
84 |
|
@@ -93,14 +100,22 @@ def clear_memory(memory, thread_id: str = "") -> None:
|
|
93 |
#checkpoint = base.empty_checkpoint()
|
94 |
#memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
|
95 |
|
96 |
-
|
|
|
97 |
"""
|
98 |
Wrapper to return the correct llm object depending on the 'model' param
|
99 |
"""
|
100 |
if model == "gpt-4o":
|
101 |
llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
|
|
|
|
|
|
|
|
102 |
else:
|
103 |
-
|
|
|
|
|
|
|
104 |
return llm
|
105 |
|
106 |
|
@@ -148,5 +163,4 @@ class DocProcessorState(TypedDict):
|
|
148 |
valid_docs : list[Union[str, dict]]
|
149 |
docs_in_processing : list
|
150 |
process_steps : list[Union[str,dict]]
|
151 |
-
current_process_step : int
|
152 |
-
|
|
|
7 |
from typing_extensions import TypedDict
|
8 |
|
9 |
from langchain_community.graphs import Neo4jGraph
|
10 |
+
# Remove ChatGroq import
|
11 |
+
# from langchain_groq import ChatGroq
|
12 |
+
# Add ChatGoogleGenerativeAI import
|
13 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
14 |
from langchain_openai import ChatOpenAI
|
15 |
|
16 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
|
75 |
os.environ[var] = getpass.getpass(f"{var}: ")
|
76 |
|
77 |
|
78 |
+
# Remove groq_key parameter
|
79 |
+
def init_app(openai_key : str = None, langsmith_key : str = None):
|
80 |
"""
|
81 |
Initialize app with user api keys and sets up proxy settings
|
82 |
"""
|
83 |
+
# Remove setting GROQ_API_KEY
|
84 |
+
# _set_env("GROQ_API_KEY", value=os.getenv("groq_api_key"))
|
85 |
_set_env("LANGSMITH_API_KEY", value=os.getenv("langsmith_api_key"))
|
86 |
_set_env("OPENAI_API_KEY", value=os.getenv("openai_api_key"))
|
87 |
+
# Make sure GEMINI_API_KEY is set if needed elsewhere, though ChatGoogleGenerativeAI reads it automatically
|
88 |
+
_set_env("GEMINI_API_KEY", value=os.getenv("gemini_api_key"))
|
89 |
os.environ["LANGSMITH_TRACING_V2"] = "true"
|
90 |
os.environ["LANGCHAIN_PROJECT"] = "3GPP Test"
|
91 |
|
|
|
100 |
#checkpoint = base.empty_checkpoint()
|
101 |
#memory.put(config={"configurable": {"thread_id": thread_id}}, checkpoint=checkpoint, metadata={})
|
102 |
|
103 |
+
# Update get_model to use ChatGoogleGenerativeAI
|
104 |
+
def get_model(model : str = "gemini-2.0-flash"):
|
105 |
"""
|
106 |
Wrapper to return the correct llm object depending on the 'model' param
|
107 |
"""
|
108 |
if model == "gpt-4o":
|
109 |
llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
110 |
+
# Check for gemini models
|
111 |
+
elif model.startswith("gemini"):
|
112 |
+
# Pass the API key explicitly, although it often reads from env var by default
|
113 |
+
llm = ChatGoogleGenerativeAI(model=model, google_api_key=os.getenv("gemini_api_key"))
|
114 |
else:
|
115 |
+
# Fallback or handle other models if necessary, maybe raise an error
|
116 |
+
# For now, defaulting to Gemini if model name doesn't match others
|
117 |
+
print(f"Warning: Model '{model}' not explicitly handled. Defaulting to Gemini.")
|
118 |
+
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.getenv("gemini_api_key"))
|
119 |
return llm
|
120 |
|
121 |
|
|
|
163 |
valid_docs : list[Union[str, dict]]
|
164 |
docs_in_processing : list
|
165 |
process_steps : list[Union[str,dict]]
|
166 |
+
current_process_step : int
|
|