gabykim committed on
Commit
369993f
·
1 Parent(s): 27ad088

ollama embedding for chromadb chunks

Browse files
poetry.lock CHANGED
@@ -1493,6 +1493,22 @@ rsa = ["cryptography (>=3.0.0)"]
1493
  signals = ["blinker (>=1.4.0)"]
1494
  signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
1495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1496
  [[package]]
1497
  name = "onnxruntime"
1498
  version = "1.20.1"
@@ -3140,4 +3156,4 @@ type = ["pytest-mypy"]
3140
  [metadata]
3141
  lock-version = "2.1"
3142
  python-versions = ">=3.10, <4.0"
3143
- content-hash = "e832a3ea167213ca280f213201124d535205b11ddadd8f4affbbdf0431a78906"
 
1493
  signals = ["blinker (>=1.4.0)"]
1494
  signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
1495
 
1496
+ [[package]]
1497
+ name = "ollama"
1498
+ version = "0.4.7"
1499
+ description = "The official Python client for Ollama."
1500
+ optional = false
1501
+ python-versions = "<4.0,>=3.8"
1502
+ groups = ["main"]
1503
+ files = [
1504
+ {file = "ollama-0.4.7-py3-none-any.whl", hash = "sha256:85505663cca67a83707be5fb3aeff0ea72e67846cea5985529d8eca4366564a1"},
1505
+ {file = "ollama-0.4.7.tar.gz", hash = "sha256:891dcbe54f55397d82d289c459de0ea897e103b86a3f1fad0fdb1895922a75ff"},
1506
+ ]
1507
+
1508
+ [package.dependencies]
1509
+ httpx = ">=0.27,<0.29"
1510
+ pydantic = ">=2.9.0,<3.0.0"
1511
+
1512
  [[package]]
1513
  name = "onnxruntime"
1514
  version = "1.20.1"
 
3156
  [metadata]
3157
  lock-version = "2.1"
3158
  python-versions = ">=3.10, <4.0"
3159
+ content-hash = "2f776d5c10f8354dd6e7916b1453b1a5de7c28b93b8bbc287dbd874d1f9f1cee"
pyproject.toml CHANGED
@@ -17,7 +17,8 @@ dependencies = [
17
  "tree-sitter (>=0.24.0,<0.25.0)",
18
  "tree-sitter-python (>=0.23.6,<0.24.0)",
19
  "pydantic-settings (>=2.7.1,<3.0.0)",
20
- "chromadb (>=0.6.3,<0.7.0)"
 
21
  ]
22
 
23
  [tool.poetry]
 
17
  "tree-sitter (>=0.24.0,<0.25.0)",
18
  "tree-sitter-python (>=0.23.6,<0.24.0)",
19
  "pydantic-settings (>=2.7.1,<3.0.0)",
20
+ "chromadb (>=0.6.3,<0.7.0)",
21
+ "ollama (>=0.4.7,<0.5.0)"
22
  ]
23
 
24
  [tool.poetry]
src/know_lang_bot/code_parser/summarizer.py CHANGED
@@ -3,6 +3,7 @@ import chromadb
3
  from chromadb.errors import InvalidCollectionException
4
  from pydantic_ai import Agent
5
  from pydantic import BaseModel, Field
 
6
 
7
  from know_lang_bot.config import AppConfig
8
  from know_lang_bot.code_parser.parser import CodeChunk
@@ -62,6 +63,17 @@ class CodeSummarizer:
62
  metadata={"hnsw:space": "cosine"}
63
  )
64
 
 
 
 
 
 
 
 
 
 
 
 
65
  async def summarize_chunk(self, chunk: CodeChunk) -> str:
66
  """Summarize a single code chunk using the LLM"""
67
  prompt = f"""
@@ -96,9 +108,13 @@ class CodeSummarizer:
96
  docstring=chunk.docstring if chunk.docstring else ''
97
  )
98
 
 
 
 
99
  # Store in ChromaDB
100
  self.collection.add(
101
  documents=[summary],
 
102
  metadatas=[metadata.model_dump()],
103
  ids=[chunk_id]
104
  )
 
3
  from chromadb.errors import InvalidCollectionException
4
  from pydantic_ai import Agent
5
  from pydantic import BaseModel, Field
6
+ import ollama
7
 
8
  from know_lang_bot.config import AppConfig
9
  from know_lang_bot.code_parser.parser import CodeChunk
 
63
  metadata={"hnsw:space": "cosine"}
64
  )
65
 
66
+ def _get_embedding(self, text: str) -> List[float]:
67
+ """Get embedding for text using configured provider"""
68
+ if self.config.llm.embedding_provider == "ollama":
69
+ response = ollama.embed(
70
+ model=self.config.llm.embedding_model,
71
+ input=text
72
+ )
73
+ return response['embedding']
74
+ else:
75
+ raise ValueError(f"Unsupported embedding provider: {self.config.llm.embedding_provider}")
76
+
77
  async def summarize_chunk(self, chunk: CodeChunk) -> str:
78
  """Summarize a single code chunk using the LLM"""
79
  prompt = f"""
 
108
  docstring=chunk.docstring if chunk.docstring else ''
109
  )
110
 
111
+ # Get embedding for the summary
112
+ embedding = self._get_embedding(summary)
113
+
114
  # Store in ChromaDB
115
  self.collection.add(
116
  documents=[summary],
117
+ embeddings=[embedding],
118
  metadatas=[metadata.model_dump()],
119
  ids=[chunk_id]
120
  )
src/know_lang_bot/config.py CHANGED
@@ -20,6 +20,14 @@ class LLMConfig(BaseSettings):
20
  default_factory=dict,
21
  description="Additional model settings"
22
  )
 
 
 
 
 
 
 
 
23
 
24
  class DBConfig(BaseSettings):
25
  persist_directory: Path = Field(
 
20
  default_factory=dict,
21
  description="Additional model settings"
22
  )
23
+ embedding_model: str = Field(
24
+ default="mxbai-embed-large",
25
+ description="Name of the embedding model to use"
26
+ )
27
+ embedding_provider: str = Field(
28
+ default="ollama",
29
+ description="Provider for embeddings (ollama, openai, etc)"
30
+ )
31
 
32
  class DBConfig(BaseSettings):
33
  persist_directory: Path = Field(
tests/test_summarizer.py CHANGED
@@ -2,7 +2,6 @@ import pytest
2
  import tempfile
3
  from unittest.mock import Mock, patch, AsyncMock
4
  from pathlib import Path
5
- from pydantic_ai import Agent
6
  from know_lang_bot.code_parser.summarizer import CodeSummarizer
7
  from know_lang_bot.code_parser.parser import CodeChunk, ChunkType
8
  from know_lang_bot.config import AppConfig
@@ -83,4 +82,56 @@ def test_chromadb_initialization(mock_agent_class, config: AppConfig):
83
  # Verify we can create a new collection
84
  summarizer.db_client.delete_collection(config.db.collection_name)
85
  new_summarizer = CodeSummarizer(config)
86
- assert new_summarizer.collection is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import tempfile
3
  from unittest.mock import Mock, patch, AsyncMock
4
  from pathlib import Path
 
5
  from know_lang_bot.code_parser.summarizer import CodeSummarizer
6
  from know_lang_bot.code_parser.parser import CodeChunk, ChunkType
7
  from know_lang_bot.config import AppConfig
 
82
  # Verify we can create a new collection
83
  summarizer.db_client.delete_collection(config.db.collection_name)
84
  new_summarizer = CodeSummarizer(config)
85
+ assert new_summarizer.collection is not None
86
+
87
+ @pytest.mark.asyncio
88
+ @patch('know_lang_bot.code_parser.summarizer.ollama')
89
+ @patch('know_lang_bot.code_parser.summarizer.Agent')
90
+ async def test_process_and_store_chunk_with_embedding(
91
+ mock_agent_class,
92
+ mock_ollama,
93
+ config: AppConfig,
94
+ sample_chunks: list[CodeChunk],
95
+ mock_run_result: Mock
96
+ ):
97
+ """Test processing and storing a chunk with embedding"""
98
+ # Setup the mock agent instance
99
+ mock_agent = mock_agent_class.return_value
100
+ mock_agent.run = AsyncMock(return_value=mock_run_result)
101
+
102
+ # Setup mock embedding response
103
+ mock_embedding = {'embedding': [0.1, 0.2, 0.3]} # Sample embedding vector
104
+ mock_ollama.embed = Mock(return_value=mock_embedding)
105
+
106
+ summarizer = CodeSummarizer(config)
107
+
108
+ # Mock the collection's add method
109
+ summarizer.collection.add = Mock()
110
+
111
+ # Process the chunk
112
+ await summarizer.process_and_store_chunk(sample_chunks[0])
113
+
114
+ # Verify ollama.embed was called with correct parameters
115
+ mock_ollama.embed.assert_called_once_with(
116
+ model=config.llm.embedding_model,
117
+ input=mock_run_result.data
118
+ )
119
+
120
+ # Verify collection.add was called with correct parameters
121
+ add_call = summarizer.collection.add.call_args
122
+ assert add_call is not None
123
+
124
+ kwargs = add_call[1]
125
+ assert len(kwargs['embeddings']) == 1
126
+ assert kwargs['embeddings'][0] == mock_embedding['embedding']
127
+ assert kwargs['documents'][0] == mock_run_result.data
128
+ assert kwargs['ids'][0] == f"{sample_chunks[0].file_path}:{sample_chunks[0].start_line}-{sample_chunks[0].end_line}"
129
+
130
+ # Verify metadata
131
+ metadata = kwargs['metadatas'][0]
132
+ assert metadata['file_path'] == sample_chunks[0].file_path
133
+ assert metadata['start_line'] == sample_chunks[0].start_line
134
+ assert metadata['end_line'] == sample_chunks[0].end_line
135
+ assert metadata['type'] == sample_chunks[0].type.value
136
+ assert metadata['name'] == sample_chunks[0].name
137
+ assert metadata['docstring'] == sample_chunks[0].docstring