ollama embedding for chromadb chunks
- poetry.lock +17 -1
- pyproject.toml +2 -1
- src/know_lang_bot/code_parser/summarizer.py +16 -0
- src/know_lang_bot/config.py +8 -0
- tests/test_summarizer.py +53 -2
poetry.lock
CHANGED
@@ -1493,6 +1493,22 @@ rsa = ["cryptography (>=3.0.0)"]
 signals = ["blinker (>=1.4.0)"]
 signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
 
+[[package]]
+name = "ollama"
+version = "0.4.7"
+description = "The official Python client for Ollama."
+optional = false
+python-versions = "<4.0,>=3.8"
+groups = ["main"]
+files = [
+    {file = "ollama-0.4.7-py3-none-any.whl", hash = "sha256:85505663cca67a83707be5fb3aeff0ea72e67846cea5985529d8eca4366564a1"},
+    {file = "ollama-0.4.7.tar.gz", hash = "sha256:891dcbe54f55397d82d289c459de0ea897e103b86a3f1fad0fdb1895922a75ff"},
+]
+
+[package.dependencies]
+httpx = ">=0.27,<0.29"
+pydantic = ">=2.9.0,<3.0.0"
+
 [[package]]
 name = "onnxruntime"
 version = "1.20.1"
@@ -3140,4 +3156,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10, <4.0"
-content-hash = "
+content-hash = "2f776d5c10f8354dd6e7916b1453b1a5de7c28b93b8bbc287dbd874d1f9f1cee"
pyproject.toml
CHANGED
@@ -17,7 +17,8 @@ dependencies = [
     "tree-sitter (>=0.24.0,<0.25.0)",
     "tree-sitter-python (>=0.23.6,<0.24.0)",
     "pydantic-settings (>=2.7.1,<3.0.0)",
-    "chromadb (>=0.6.3,<0.7.0)"
+    "chromadb (>=0.6.3,<0.7.0)",
+    "ollama (>=0.4.7,<0.5.0)"
 ]
 
 [tool.poetry]
src/know_lang_bot/code_parser/summarizer.py
CHANGED
@@ -3,6 +3,7 @@ import chromadb
 from chromadb.errors import InvalidCollectionException
 from pydantic_ai import Agent
 from pydantic import BaseModel, Field
+import ollama
 
 from know_lang_bot.config import AppConfig
 from know_lang_bot.code_parser.parser import CodeChunk
@@ -62,6 +63,17 @@ class CodeSummarizer:
             metadata={"hnsw:space": "cosine"}
         )
 
+    def _get_embedding(self, text: str) -> List[float]:
+        """Get embedding for text using configured provider"""
+        if self.config.llm.embedding_provider == "ollama":
+            response = ollama.embed(
+                model=self.config.llm.embedding_model,
+                input=text
+            )
+            return response['embedding']
+        else:
+            raise ValueError(f"Unsupported embedding provider: {self.config.llm.embedding_provider}")
+
     async def summarize_chunk(self, chunk: CodeChunk) -> str:
         """Summarize a single code chunk using the LLM"""
         prompt = f"""
@@ -96,9 +108,13 @@ class CodeSummarizer:
             docstring=chunk.docstring if chunk.docstring else ''
         )
 
+        # Get embedding for the summary
+        embedding = self._get_embedding(summary)
+
         # Store in ChromaDB
        self.collection.add(
            documents=[summary],
+           embeddings=[embedding],
            metadatas=[metadata.model_dump()],
            ids=[chunk_id]
        )
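Since summaries are now stored with Ollama vectors, lookups against this collection need query embeddings produced by the same model. Below is a minimal retrieval sketch, not part of this commit, assuming AppConfig() loads usable defaults, the summarizer persists to config.db.persist_directory, and the configured embedding model has been pulled into a local Ollama server. Note that the ollama 0.4 Python client returns the result of embed() under a plural embeddings field (a list of vectors), so the singular response['embedding'] lookup in _get_embedding above is worth verifying against a live server rather than only the mocked tests.

import chromadb
import ollama

from know_lang_bot.config import AppConfig

config = AppConfig()

# Re-open the same persistent store and collection the summarizer writes to.
client = chromadb.PersistentClient(path=str(config.db.persist_directory))
collection = client.get_collection(config.db.collection_name)

# Embed the query with the same model used at indexing time so cosine
# distances in the collection's "hnsw:space" are comparable.
question = "Where are code chunks summarized?"
response = ollama.embed(model=config.llm.embedding_model, input=question)

results = collection.query(
    query_embeddings=[list(response.embeddings[0])],  # 0.4 client: list of vectors
    n_results=5,
)
print(results["documents"][0])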
src/know_lang_bot/config.py
CHANGED
@@ -20,6 +20,14 @@ class LLMConfig(BaseSettings):
         default_factory=dict,
         description="Additional model settings"
     )
+    embedding_model: str = Field(
+        default="mxbai-embed-large",
+        description="Name of the embedding model to use"
+    )
+    embedding_provider: str = Field(
+        default="ollama",
+        description="Provider for embeddings (ollama, openai, etc)"
+    )
 
 class DBConfig(BaseSettings):
     persist_directory: Path = Field(
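The two new fields follow the same Field pattern as the rest of LLMConfig, so they can be overridden programmatically or through whatever settings sources pydantic-settings is wired to. A small sketch, assuming AppConfig accepts an llm keyword the way summarizer.py reads config.llm and that the remaining LLMConfig fields have defaults; the model name is only an example.

from know_lang_bot.config import AppConfig, LLMConfig

# Hypothetical override: swap the embedding model while keeping the Ollama provider.
config = AppConfig(
    llm=LLMConfig(
        embedding_provider="ollama",
        embedding_model="nomic-embed-text",  # example model name, not a project default
    )
)

assert config.llm.embedding_model == "nomic-embed-text"
assert config.llm.embedding_provider == "ollama"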
tests/test_summarizer.py
CHANGED
@@ -2,7 +2,6 @@ import pytest
 import tempfile
 from unittest.mock import Mock, patch, AsyncMock
 from pathlib import Path
-from pydantic_ai import Agent
 from know_lang_bot.code_parser.summarizer import CodeSummarizer
 from know_lang_bot.code_parser.parser import CodeChunk, ChunkType
 from know_lang_bot.config import AppConfig
@@ -83,4 +82,56 @@ def test_chromadb_initialization(mock_agent_class, config: AppConfig):
     # Verify we can create a new collection
     summarizer.db_client.delete_collection(config.db.collection_name)
     new_summarizer = CodeSummarizer(config)
-    assert new_summarizer.collection is not None
+    assert new_summarizer.collection is not None
+
+@pytest.mark.asyncio
+@patch('know_lang_bot.code_parser.summarizer.ollama')
+@patch('know_lang_bot.code_parser.summarizer.Agent')
+async def test_process_and_store_chunk_with_embedding(
+    mock_agent_class,
+    mock_ollama,
+    config: AppConfig,
+    sample_chunks: list[CodeChunk],
+    mock_run_result: Mock
+):
+    """Test processing and storing a chunk with embedding"""
+    # Setup the mock agent instance
+    mock_agent = mock_agent_class.return_value
+    mock_agent.run = AsyncMock(return_value=mock_run_result)
+
+    # Setup mock embedding response
+    mock_embedding = {'embedding': [0.1, 0.2, 0.3]}  # Sample embedding vector
+    mock_ollama.embed = Mock(return_value=mock_embedding)
+
+    summarizer = CodeSummarizer(config)
+
+    # Mock the collection's add method
+    summarizer.collection.add = Mock()
+
+    # Process the chunk
+    await summarizer.process_and_store_chunk(sample_chunks[0])
+
+    # Verify ollama.embed was called with correct parameters
+    mock_ollama.embed.assert_called_once_with(
+        model=config.llm.embedding_model,
+        input=mock_run_result.data
+    )
+
+    # Verify collection.add was called with correct parameters
+    add_call = summarizer.collection.add.call_args
+    assert add_call is not None
+
+    kwargs = add_call[1]
+    assert len(kwargs['embeddings']) == 1
+    assert kwargs['embeddings'][0] == mock_embedding['embedding']
+    assert kwargs['documents'][0] == mock_run_result.data
+    assert kwargs['ids'][0] == f"{sample_chunks[0].file_path}:{sample_chunks[0].start_line}-{sample_chunks[0].end_line}"
+
+    # Verify metadata
+    metadata = kwargs['metadatas'][0]
+    assert metadata['file_path'] == sample_chunks[0].file_path
+    assert metadata['start_line'] == sample_chunks[0].start_line
+    assert metadata['end_line'] == sample_chunks[0].end_line
+    assert metadata['type'] == sample_chunks[0].type.value
+    assert metadata['name'] == sample_chunks[0].name
+    assert metadata['docstring'] == sample_chunks[0].docstring