gabykim commited on
Commit
8e1fac7
·
1 Parent(s): 3536fc0

api key required for openai embeddings

Browse files
.vscode/launch.json CHANGED
@@ -2,13 +2,19 @@
2
  "version": "0.2.0",
3
  "configurations": [
4
  {
5
- "name": "Know Lang Bot",
6
  "type": "debugpy",
7
  "request": "launch",
8
  "module": "know_lang_bot",
9
  "args": [
10
  "--v"
11
  ]
 
 
 
 
 
 
12
  }
13
  ]
14
  }
 
2
  "version": "0.2.0",
3
  "configurations": [
4
  {
5
+ "name": "Know Lang",
6
  "type": "debugpy",
7
  "request": "launch",
8
  "module": "know_lang_bot",
9
  "args": [
10
  "--v"
11
  ]
12
+ },
13
+ {
14
+ "name": "Know Lang ChatBot ",
15
+ "type": "debugpy",
16
+ "request": "launch",
17
+ "program": "${workspaceFolder}/src/know_lang_bot/chat_bot/gradio_demo.py",
18
  }
19
  ]
20
  }
src/know_lang_bot/chat_bot/chat_graph.py CHANGED
@@ -154,7 +154,7 @@ class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
154
  try:
155
  question_embedding = generate_embedding(
156
  input=ctx.state.polished_question or ctx.state.original_question,
157
- model=ctx.deps.config.embedding
158
  )
159
 
160
  results = ctx.deps.collection.query(
 
154
  try:
155
  question_embedding = generate_embedding(
156
  input=ctx.state.polished_question or ctx.state.original_question,
157
+ config=ctx.deps.config.embedding
158
  )
159
 
160
  results = ctx.deps.collection.query(
src/know_lang_bot/config.py CHANGED
@@ -77,6 +77,18 @@ class EmbeddingConfig(BaseSettings):
77
  default_factory=dict,
78
  description="Provider-specific settings"
79
  )
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  class LLMConfig(BaseSettings):
82
  model_name: str = Field(
 
77
  default_factory=dict,
78
  description="Provider-specific settings"
79
  )
80
+ api_key: Optional[str] = Field(
81
+ default=None,
82
+ description="API key for the model provider"
83
+ )
84
+
85
+ @field_validator('api_key', mode='after')
86
+ @classmethod
87
+ def validate_api_key(cls, v: Optional[str], info: ValidationInfo) -> Optional[str]:
88
+ """Validate API key is present when required"""
89
+ if info.data['provider'] in [ModelProvider.OPENAI] and not v:
90
+ raise ValueError(f"API key required for {info.data['provider']}")
91
+ return v
92
 
93
  class LLMConfig(BaseSettings):
94
  model_name: str = Field(
src/know_lang_bot/models/embeddings.py CHANGED
@@ -6,11 +6,6 @@ from typing import Union, List, overload
6
  # Type definitions
7
  EmbeddingVector = List[float]
8
 
9
- class EmbeddingConfig:
10
- def __init__(self, provider: ModelProvider, model_name: str):
11
- self.provider = provider
12
- self.model_name = model_name
13
-
14
  def _process_ollama_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
15
  """Helper function to process Ollama embeddings in batch."""
16
  return [
@@ -60,6 +55,7 @@ def generate_embedding(
60
  if config.provider == ModelProvider.OLLAMA:
61
  embeddings = _process_ollama_batch(inputs, config.model_name)
62
  elif config.provider == ModelProvider.OPENAI:
 
63
  embeddings = _process_openai_batch(inputs, config.model_name)
64
  else:
65
  raise ValueError(f"Unsupported provider: {config.provider}")
 
6
  # Type definitions
7
  EmbeddingVector = List[float]
8
 
 
 
 
 
 
9
  def _process_ollama_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
10
  """Helper function to process Ollama embeddings in batch."""
11
  return [
 
55
  if config.provider == ModelProvider.OLLAMA:
56
  embeddings = _process_ollama_batch(inputs, config.model_name)
57
  elif config.provider == ModelProvider.OPENAI:
58
+ openai.api_key = config.api_key
59
  embeddings = _process_openai_batch(inputs, config.model_name)
60
  else:
61
  raise ValueError(f"Unsupported provider: {config.provider}")