Spaces:
Sleeping
Sleeping
api key required for openai embeddings
Browse files
.vscode/launch.json
CHANGED
@@ -2,13 +2,19 @@
|
|
2 |
"version": "0.2.0",
|
3 |
"configurations": [
|
4 |
{
|
5 |
-
"name": "Know Lang
|
6 |
"type": "debugpy",
|
7 |
"request": "launch",
|
8 |
"module": "know_lang_bot",
|
9 |
"args": [
|
10 |
"--v"
|
11 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
}
|
13 |
]
|
14 |
}
|
|
|
2 |
"version": "0.2.0",
|
3 |
"configurations": [
|
4 |
{
|
5 |
+
"name": "Know Lang",
|
6 |
"type": "debugpy",
|
7 |
"request": "launch",
|
8 |
"module": "know_lang_bot",
|
9 |
"args": [
|
10 |
"--v"
|
11 |
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"name": "Know Lang ChatBot ",
|
15 |
+
"type": "debugpy",
|
16 |
+
"request": "launch",
|
17 |
+
"program": "${workspaceFolder}/src/know_lang_bot/chat_bot/gradio_demo.py",
|
18 |
}
|
19 |
]
|
20 |
}
|
src/know_lang_bot/chat_bot/chat_graph.py
CHANGED
@@ -154,7 +154,7 @@ class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
|
|
154 |
try:
|
155 |
question_embedding = generate_embedding(
|
156 |
input=ctx.state.polished_question or ctx.state.original_question,
|
157 |
-
|
158 |
)
|
159 |
|
160 |
results = ctx.deps.collection.query(
|
|
|
154 |
try:
|
155 |
question_embedding = generate_embedding(
|
156 |
input=ctx.state.polished_question or ctx.state.original_question,
|
157 |
+
config=ctx.deps.config.embedding
|
158 |
)
|
159 |
|
160 |
results = ctx.deps.collection.query(
|
src/know_lang_bot/config.py
CHANGED
@@ -77,6 +77,18 @@ class EmbeddingConfig(BaseSettings):
|
|
77 |
default_factory=dict,
|
78 |
description="Provider-specific settings"
|
79 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
class LLMConfig(BaseSettings):
|
82 |
model_name: str = Field(
|
|
|
77 |
default_factory=dict,
|
78 |
description="Provider-specific settings"
|
79 |
)
|
80 |
+
api_key: Optional[str] = Field(
|
81 |
+
default=None,
|
82 |
+
description="API key for the model provider"
|
83 |
+
)
|
84 |
+
|
85 |
+
@field_validator('api_key', mode='after')
@classmethod
def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
    """Ensure an API key is supplied for providers that need one.

    Args:
        v: The candidate ``api_key`` value.
        info: Validation context; ``info.data`` is read for the ``provider``
            field.

    Returns:
        ``v`` unchanged when the check passes.

    Raises:
        ValueError: If the provider is OpenAI and ``v`` is empty or ``None``.
    """
    # ``info.data`` only contains fields that already validated successfully,
    # so ``provider`` is missing when its own validation failed. Reading it
    # with ``.get`` lets that case surface the provider's validation error
    # instead of an unrelated KeyError raised from this validator.
    provider = info.data.get('provider')
    # NOTE(review): pydantic skips field validators for default values unless
    # ``validate_default=True`` is set on the field -- confirm that an omitted
    # ``api_key`` actually reaches this check.
    if provider in [ModelProvider.OPENAI] and not v:
        raise ValueError(f"API key required for {provider}")
    return v
|
92 |
|
93 |
class LLMConfig(BaseSettings):
|
94 |
model_name: str = Field(
|
src/know_lang_bot/models/embeddings.py
CHANGED
@@ -6,11 +6,6 @@ from typing import Union, List, overload
|
|
6 |
# Type definitions
|
7 |
EmbeddingVector = List[float]
|
8 |
|
9 |
-
class EmbeddingConfig:
|
10 |
-
def __init__(self, provider: ModelProvider, model_name: str):
|
11 |
-
self.provider = provider
|
12 |
-
self.model_name = model_name
|
13 |
-
|
14 |
def _process_ollama_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
|
15 |
"""Helper function to process Ollama embeddings in batch."""
|
16 |
return [
|
@@ -60,6 +55,7 @@ def generate_embedding(
|
|
60 |
if config.provider == ModelProvider.OLLAMA:
|
61 |
embeddings = _process_ollama_batch(inputs, config.model_name)
|
62 |
elif config.provider == ModelProvider.OPENAI:
|
|
|
63 |
embeddings = _process_openai_batch(inputs, config.model_name)
|
64 |
else:
|
65 |
raise ValueError(f"Unsupported provider: {config.provider}")
|
|
|
6 |
# Type definitions
|
7 |
EmbeddingVector = List[float]
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
def _process_ollama_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
|
10 |
"""Helper function to process Ollama embeddings in batch."""
|
11 |
return [
|
|
|
55 |
if config.provider == ModelProvider.OLLAMA:
|
56 |
embeddings = _process_ollama_batch(inputs, config.model_name)
|
57 |
elif config.provider == ModelProvider.OPENAI:
|
58 |
+
openai.api_key = config.api_key
|
59 |
embeddings = _process_openai_batch(inputs, config.model_name)
|
60 |
else:
|
61 |
raise ValueError(f"Unsupported provider: {config.provider}")
|