gabykim commited on
Commit
3536fc0
·
1 Parent(s): d88e3a5

multiple embedding provider implemented

Browse files
src/know_lang_bot/chat_bot/chat_graph.py CHANGED
@@ -5,15 +5,15 @@ from typing import AsyncGenerator, List, Dict, Any, Optional
5
  import chromadb
6
  from pydantic import BaseModel
7
  from pydantic_graph import BaseNode, EndStep, Graph, GraphRunContext, End, HistoryStep
8
- import ollama
9
  from know_lang_bot.config import AppConfig
10
  from know_lang_bot.utils.fancy_log import FancyLogger
11
  from pydantic_ai import Agent
12
  import logfire
13
- from pprint import pformat
14
  from enum import Enum
15
  from rich.console import Console
16
  from know_lang_bot.utils.model_provider import create_pydantic_model
 
17
 
18
  LOG = FancyLogger(__name__)
19
  console = Console()
@@ -152,17 +152,17 @@ class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
152
 
153
  async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
154
  try:
155
- embedded_question = ollama.embed(
156
- model=ctx.deps.config.embedding.model_name,
157
- input=ctx.state.polished_question or ctx.state.original_question
158
  )
159
 
160
  results = ctx.deps.collection.query(
161
- query_embeddings=embedded_question['embeddings'],
162
  n_results=ctx.deps.config.chat.max_context_chunks,
163
  include=['metadatas', 'documents', 'distances']
164
  )
165
- logfire.debug('query result: {result}', result=pformat(results))
166
 
167
  relevant_chunks = []
168
  relevant_metadatas = []
 
5
  import chromadb
6
  from pydantic import BaseModel
7
  from pydantic_graph import BaseNode, EndStep, Graph, GraphRunContext, End, HistoryStep
 
8
  from know_lang_bot.config import AppConfig
9
  from know_lang_bot.utils.fancy_log import FancyLogger
10
  from pydantic_ai import Agent
11
  import logfire
12
+ from rich.pretty import Pretty
13
  from enum import Enum
14
  from rich.console import Console
15
  from know_lang_bot.utils.model_provider import create_pydantic_model
16
+ from know_lang_bot.models.embeddings import generate_embedding
17
 
18
  LOG = FancyLogger(__name__)
19
  console = Console()
 
152
 
153
  async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
154
  try:
155
+ question_embedding = generate_embedding(
156
+ input=ctx.state.polished_question or ctx.state.original_question,
157
+ model=ctx.deps.config.embedding
158
  )
159
 
160
  results = ctx.deps.collection.query(
161
+ query_embeddings=question_embedding,
162
  n_results=ctx.deps.config.chat.max_context_chunks,
163
  include=['metadatas', 'documents', 'distances']
164
  )
165
+ logfire.debug('query result: {result}', result=Pretty(results))
166
 
167
  relevant_chunks = []
168
  relevant_metadatas = []
src/know_lang_bot/core/types.py CHANGED
@@ -23,4 +23,5 @@ class ModelProvider(str, Enum):
23
  OPENAI = "openai"
24
  ANTHROPIC = "anthropic"
25
  OLLAMA = "ollama"
26
- HUGGINGFACE = "huggingface"
 
 
23
  OPENAI = "openai"
24
  ANTHROPIC = "anthropic"
25
  OLLAMA = "ollama"
26
+ HUGGINGFACE = "huggingface"
27
+ TESTING = "testing"
src/know_lang_bot/models/embeddings.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ollama
2
+ import openai
3
+ from know_lang_bot.config import EmbeddingConfig, ModelProvider
4
+ from typing import Union, List, overload
5
+
6
+ # Type definitions
7
+ EmbeddingVector = List[float]
8
+
9
+ class EmbeddingConfig:
10
+ def __init__(self, provider: ModelProvider, model_name: str):
11
+ self.provider = provider
12
+ self.model_name = model_name
13
+
14
+ def _process_ollama_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
15
+ """Helper function to process Ollama embeddings in batch."""
16
+ return [
17
+ ollama.embed(model=model_name, input=inputs)['embeddings']
18
+ ]
19
+
20
+ def _process_openai_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
21
+ """Helper function to process OpenAI embeddings in batch."""
22
+ response = openai.embeddings.create(
23
+ input=inputs,
24
+ model=model_name
25
+ )
26
+ return [item.embedding for item in response.data]
27
+
28
+ @overload
29
+ def generate_embedding(input: str, config: EmbeddingConfig) -> EmbeddingVector: ...
30
+
31
+ @overload
32
+ def generate_embedding(input: List[str], config: EmbeddingConfig) -> List[EmbeddingVector]: ...
33
+
34
+ def generate_embedding(
35
+ input: Union[str, List[str]],
36
+ config: EmbeddingConfig
37
+ ) -> Union[EmbeddingVector, List[EmbeddingVector]]:
38
+ """
39
+ Generate embeddings for single text input or batch of texts.
40
+
41
+ Args:
42
+ input: Single string or list of strings to embed
43
+ config: Configuration object containing provider and model information
44
+
45
+ Returns:
46
+ Single embedding vector for single input, or list of embedding vectors for batch input
47
+
48
+ Raises:
49
+ ValueError: If input type is invalid or provider is not supported
50
+ RuntimeError: If embedding generation fails
51
+ """
52
+ if not input:
53
+ raise ValueError("Input cannot be empty")
54
+
55
+ # Convert single string to list for batch processing
56
+ is_single_input = isinstance(input, str)
57
+ inputs = [input] if is_single_input else input
58
+
59
+ try:
60
+ if config.provider == ModelProvider.OLLAMA:
61
+ embeddings = _process_ollama_batch(inputs, config.model_name)
62
+ elif config.provider == ModelProvider.OPENAI:
63
+ embeddings = _process_openai_batch(inputs, config.model_name)
64
+ else:
65
+ raise ValueError(f"Unsupported provider: {config.provider}")
66
+
67
+ # Return single embedding for single input
68
+ return embeddings[0] if is_single_input else embeddings
69
+
70
+ except Exception as e:
71
+ raise RuntimeError(f"Failed to generate embeddings: {str(e)}") from e
src/know_lang_bot/summarizer/summarizer.py CHANGED
@@ -3,7 +3,6 @@ import chromadb
3
  from chromadb.errors import InvalidCollectionException
4
  from pydantic_ai import Agent
5
  from pydantic import BaseModel, Field
6
- import ollama
7
  from pprint import pformat
8
  from rich.progress import Progress
9
 
@@ -11,6 +10,7 @@ from know_lang_bot.config import AppConfig
11
  from know_lang_bot.core.types import CodeChunk, ModelProvider
12
  from know_lang_bot.utils.fancy_log import FancyLogger
13
  from know_lang_bot.utils.model_provider import create_pydantic_model
 
14
 
15
  LOG = FancyLogger(__name__)
16
 
@@ -76,17 +76,6 @@ Provide a clean, concise and focused summary. Don't include unnecessary nor gene
76
  metadata={"hnsw:space": "cosine"}
77
  )
78
 
79
- def _get_embedding(self, text: str) -> List[float]:
80
- """Get embedding for text using configured provider"""
81
- if self.config.embedding.provider == ModelProvider.OLLAMA:
82
- response = ollama.embed(
83
- model=self.config.embedding.model_name,
84
- input=text
85
- )
86
- return response['embeddings']
87
- else:
88
- raise ValueError(f"Unsupported embedding provider: {self.config.embedding.provider}")
89
-
90
  async def summarize_chunk(self, chunk: CodeChunk) -> str:
91
  """Summarize a single code chunk using the LLM"""
92
  prompt = f"""
@@ -122,7 +111,7 @@ Provide a clean, concise and focused summary. Don't include unnecessary nor gene
122
  )
123
 
124
  # Get embedding for the summary
125
- embedding = self._get_embedding(summary)
126
 
127
  # Store in ChromaDB
128
  self.collection.add(
 
3
  from chromadb.errors import InvalidCollectionException
4
  from pydantic_ai import Agent
5
  from pydantic import BaseModel, Field
 
6
  from pprint import pformat
7
  from rich.progress import Progress
8
 
 
10
  from know_lang_bot.core.types import CodeChunk, ModelProvider
11
  from know_lang_bot.utils.fancy_log import FancyLogger
12
  from know_lang_bot.utils.model_provider import create_pydantic_model
13
+ from know_lang_bot.models.embeddings import generate_embedding
14
 
15
  LOG = FancyLogger(__name__)
16
 
 
76
  metadata={"hnsw:space": "cosine"}
77
  )
78
 
 
 
 
 
 
 
 
 
 
 
 
79
  async def summarize_chunk(self, chunk: CodeChunk) -> str:
80
  """Summarize a single code chunk using the LLM"""
81
  prompt = f"""
 
111
  )
112
 
113
  # Get embedding for the summary
114
+ embedding = generate_embedding(summary, self.config.embedding)
115
 
116
  # Store in ChromaDB
117
  self.collection.add(
src/know_lang_bot/utils/model_provider.py CHANGED
@@ -13,5 +13,8 @@ def create_pydantic_model(
13
  return model_str
14
  elif model_provider == ModelProvider.HUGGINGFACE:
15
  return HuggingFaceModel(model_name=model_name)
 
 
 
16
  else:
17
  raise NotImplementedError(f"Model {model_provider}:{model_name} is not supported")
 
13
  return model_str
14
  elif model_provider == ModelProvider.HUGGINGFACE:
15
  return HuggingFaceModel(model_name=model_name)
16
+ elif model_provider == ModelProvider.TESTING:
17
+ # should be used for testing purposes only
18
+ pass
19
  else:
20
  raise NotImplementedError(f"Model {model_provider}:{model_name} is not supported")
tests/test_summarizer.py CHANGED
@@ -11,7 +11,7 @@ def config():
11
  """Create a test configuration"""
12
  with tempfile.TemporaryDirectory() as temp_dir:
13
  yield AppConfig(
14
- llm={"model_name": "test-model", "model_provider": "test"},
15
  db={"persist_directory": Path(temp_dir), "collection_name": "test_collection"}
16
  )
17
 
@@ -85,11 +85,11 @@ def test_chromadb_initialization(mock_agent_class, config: AppConfig):
85
  assert new_summarizer.collection is not None
86
 
87
  @pytest.mark.asyncio
88
- @patch('know_lang_bot.summarizer.summarizer.ollama')
89
  @patch('know_lang_bot.summarizer.summarizer.Agent')
90
  async def test_process_and_store_chunk_with_embedding(
91
  mock_agent_class,
92
- mock_ollama,
93
  config: AppConfig,
94
  sample_chunks: list[CodeChunk],
95
  mock_run_result: Mock
@@ -100,8 +100,8 @@ async def test_process_and_store_chunk_with_embedding(
100
  mock_agent.run = AsyncMock(return_value=mock_run_result)
101
 
102
  # Setup mock embedding response
103
- mock_embedding = {'embeddings': [0.1, 0.2, 0.3]} # Sample embedding vector
104
- mock_ollama.embed = Mock(return_value=mock_embedding)
105
 
106
  summarizer = CodeSummarizer(config)
107
 
@@ -112,9 +112,9 @@ async def test_process_and_store_chunk_with_embedding(
112
  await summarizer.process_and_store_chunk(sample_chunks[0])
113
 
114
  # Verify ollama.embed was called with correct parameters
115
- mock_ollama.embed.assert_called_once_with(
116
- model=config.embedding.model_name,
117
- input=mock_run_result.data
118
  )
119
 
120
  # Verify collection.add was called with correct parameters
@@ -123,7 +123,7 @@ async def test_process_and_store_chunk_with_embedding(
123
 
124
  kwargs = add_call[1]
125
  assert len(kwargs['embeddings']) == 3
126
- assert kwargs['embeddings'] == mock_embedding['embeddings']
127
  assert kwargs['documents'][0] == mock_run_result.data
128
  assert kwargs['ids'][0] == f"{sample_chunks[0].file_path}:{sample_chunks[0].start_line}-{sample_chunks[0].end_line}"
129
 
 
11
  """Create a test configuration"""
12
  with tempfile.TemporaryDirectory() as temp_dir:
13
  yield AppConfig(
14
+ llm={"model_name": "testing", "model_provider": "testing"},
15
  db={"persist_directory": Path(temp_dir), "collection_name": "test_collection"}
16
  )
17
 
 
85
  assert new_summarizer.collection is not None
86
 
87
  @pytest.mark.asyncio
88
+ @patch('know_lang_bot.summarizer.summarizer.generate_embedding')
89
  @patch('know_lang_bot.summarizer.summarizer.Agent')
90
  async def test_process_and_store_chunk_with_embedding(
91
  mock_agent_class,
92
+ mock_embedding_generator,
93
  config: AppConfig,
94
  sample_chunks: list[CodeChunk],
95
  mock_run_result: Mock
 
100
  mock_agent.run = AsyncMock(return_value=mock_run_result)
101
 
102
  # Setup mock embedding response
103
+ mock_embedding = [0.1, 0.2, 0.3] # Sample embedding vector
104
+ mock_embedding_generator.return_value = mock_embedding
105
 
106
  summarizer = CodeSummarizer(config)
107
 
 
112
  await summarizer.process_and_store_chunk(sample_chunks[0])
113
 
114
  # Verify ollama.embed was called with correct parameters
115
+ mock_embedding_generator.assert_called_once_with(
116
+ mock_run_result.data,
117
+ config.embedding,
118
  )
119
 
120
  # Verify collection.add was called with correct parameters
 
123
 
124
  kwargs = add_call[1]
125
  assert len(kwargs['embeddings']) == 3
126
+ assert kwargs['embeddings'] == mock_embedding
127
  assert kwargs['documents'][0] == mock_run_result.data
128
  assert kwargs['ids'][0] == f"{sample_chunks[0].file_path}:{sample_chunks[0].start_line}-{sample_chunks[0].end_line}"
129