Multiple embedding providers implemented
src/know_lang_bot/chat_bot/chat_graph.py
CHANGED
```diff
@@ -5,15 +5,15 @@ from typing import AsyncGenerator, List, Dict, Any, Optional
 import chromadb
 from pydantic import BaseModel
 from pydantic_graph import BaseNode, EndStep, Graph, GraphRunContext, End, HistoryStep
-import ollama
 from know_lang_bot.config import AppConfig
 from know_lang_bot.utils.fancy_log import FancyLogger
 from pydantic_ai import Agent
 import logfire
-from …
+from rich.pretty import Pretty
 from enum import Enum
 from rich.console import Console
 from know_lang_bot.utils.model_provider import create_pydantic_model
+from know_lang_bot.models.embeddings import generate_embedding

 LOG = FancyLogger(__name__)
 console = Console()
@@ -152,17 +152,17 @@ class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):

     async def run(self, ctx: GraphRunContext[ChatGraphState, ChatGraphDeps]) -> AnswerQuestionNode:
         try:
-            …
-            …
-            …
+            question_embedding = generate_embedding(
+                input=ctx.state.polished_question or ctx.state.original_question,
+                model=ctx.deps.config.embedding
             )

             results = ctx.deps.collection.query(
-                query_embeddings=…
+                query_embeddings=question_embedding,
                 n_results=ctx.deps.config.chat.max_context_chunks,
                 include=['metadatas', 'documents', 'distances']
             )
-            logfire.debug('query result: {result}', result=…
+            logfire.debug('query result: {result}', result=Pretty(results))

             relevant_chunks = []
             relevant_metadatas = []
```
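For orientation, here is a minimal, self-contained sketch of the retrieval flow this hunk wires into `RetrieveContextNode.run`. The client path, collection name, model name, and the explicit list around the query vector are illustrative assumptions, not code from the repository:

```python
# Hypothetical standalone version of the new retrieval step.
import chromadb

from know_lang_bot.core.types import ModelProvider
from know_lang_bot.models.embeddings import EmbeddingConfig, generate_embedding

client = chromadb.PersistentClient(path="./chroma_db")          # assumed path
collection = client.get_or_create_collection("code_summaries")  # assumed name

# Embed the (polished or original) question with whichever provider is configured.
embedding_config = EmbeddingConfig(ModelProvider.OLLAMA, "mxbai-embed-large")
question_embedding = generate_embedding("How are chunks summarized?", embedding_config)

# Chroma's query API takes a list of query vectors.
results = collection.query(
    query_embeddings=[question_embedding],
    n_results=5,
    include=["metadatas", "documents", "distances"],
)
```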
src/know_lang_bot/core/types.py
CHANGED
```diff
@@ -23,4 +23,5 @@ class ModelProvider(str, Enum):
     OPENAI = "openai"
     ANTHROPIC = "anthropic"
     OLLAMA = "ollama"
-    HUGGINGFACE = "huggingface"
+    HUGGINGFACE = "huggingface"
+    TESTING = "testing"
```
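Because `ModelProvider` subclasses `str`, the new member round-trips cleanly from plain strings in config files; a quick illustration:

```python
from know_lang_bot.core.types import ModelProvider

assert ModelProvider("testing") is ModelProvider.TESTING  # lookup by value
assert ModelProvider.TESTING == "testing"                 # str comparison works
```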
src/know_lang_bot/models/embeddings.py
ADDED
```diff
@@ -0,0 +1,71 @@
+import ollama
+import openai
+from know_lang_bot.config import EmbeddingConfig, ModelProvider
+from typing import Union, List, overload
+
+# Type definitions
+EmbeddingVector = List[float]
+
+class EmbeddingConfig:
+    def __init__(self, provider: ModelProvider, model_name: str):
+        self.provider = provider
+        self.model_name = model_name
+
+def _process_ollama_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
+    """Helper function to process Ollama embeddings in batch."""
+    return [
+        ollama.embed(model=model_name, input=inputs)['embeddings']
+    ]
+
+def _process_openai_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
+    """Helper function to process OpenAI embeddings in batch."""
+    response = openai.embeddings.create(
+        input=inputs,
+        model=model_name
+    )
+    return [item.embedding for item in response.data]
+
+@overload
+def generate_embedding(input: str, config: EmbeddingConfig) -> EmbeddingVector: ...
+
+@overload
+def generate_embedding(input: List[str], config: EmbeddingConfig) -> List[EmbeddingVector]: ...
+
+def generate_embedding(
+    input: Union[str, List[str]],
+    config: EmbeddingConfig
+) -> Union[EmbeddingVector, List[EmbeddingVector]]:
+    """
+    Generate embeddings for single text input or batch of texts.
+
+    Args:
+        input: Single string or list of strings to embed
+        config: Configuration object containing provider and model information
+
+    Returns:
+        Single embedding vector for single input, or list of embedding vectors for batch input
+
+    Raises:
+        ValueError: If input type is invalid or provider is not supported
+        RuntimeError: If embedding generation fails
+    """
+    if not input:
+        raise ValueError("Input cannot be empty")
+
+    # Convert single string to list for batch processing
+    is_single_input = isinstance(input, str)
+    inputs = [input] if is_single_input else input
+
+    try:
+        if config.provider == ModelProvider.OLLAMA:
+            embeddings = _process_ollama_batch(inputs, config.model_name)
+        elif config.provider == ModelProvider.OPENAI:
+            embeddings = _process_openai_batch(inputs, config.model_name)
+        else:
+            raise ValueError(f"Unsupported provider: {config.provider}")
+
+        # Return single embedding for single input
+        return embeddings[0] if is_single_input else embeddings
+
+    except Exception as e:
+        raise RuntimeError(f"Failed to generate embeddings: {str(e)}") from e
```
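A hedged usage sketch for the new helper: the `@overload`s mean callers get a single vector back for a single string and a list of vectors for a batch. The model name below is an example rather than a repository default, and an `OPENAI_API_KEY` is assumed to be set:

```python
from know_lang_bot.core.types import ModelProvider
from know_lang_bot.models.embeddings import EmbeddingConfig, generate_embedding

cfg = EmbeddingConfig(ModelProvider.OPENAI, "text-embedding-3-small")  # example model

# Single input -> one EmbeddingVector (first overload)
vec = generate_embedding("def add(a, b): return a + b", cfg)

# Batch input -> one vector per text (second overload)
vecs = generate_embedding(["chunk one", "chunk two"], cfg)
assert len(vecs) == 2

# Empty input raises ValueError; provider failures surface as RuntimeError.
```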
src/know_lang_bot/summarizer/summarizer.py
CHANGED
```diff
@@ -3,7 +3,6 @@ import chromadb
 from chromadb.errors import InvalidCollectionException
 from pydantic_ai import Agent
 from pydantic import BaseModel, Field
-import ollama
 from pprint import pformat
 from rich.progress import Progress

@@ -11,6 +10,7 @@ from know_lang_bot.config import AppConfig
 from know_lang_bot.core.types import CodeChunk, ModelProvider
 from know_lang_bot.utils.fancy_log import FancyLogger
 from know_lang_bot.utils.model_provider import create_pydantic_model
+from know_lang_bot.models.embeddings import generate_embedding

 LOG = FancyLogger(__name__)

@@ -76,17 +76,6 @@ Provide a clean, concise and focused summary. Don't include unnecessary nor gene…
             metadata={"hnsw:space": "cosine"}
         )

-    def _get_embedding(self, text: str) -> List[float]:
-        """Get embedding for text using configured provider"""
-        if self.config.embedding.provider == ModelProvider.OLLAMA:
-            response = ollama.embed(
-                model=self.config.embedding.model_name,
-                input=text
-            )
-            return response['embeddings']
-        else:
-            raise ValueError(f"Unsupported embedding provider: {self.config.embedding.provider}")
-
     async def summarize_chunk(self, chunk: CodeChunk) -> str:
         """Summarize a single code chunk using the LLM"""
         prompt = f"""
@@ -122,7 +111,7 @@ Provide a clean, concise and focused summary. Don't include unnecessary nor gene…
         )

         # Get embedding for the summary
-        embedding = self._get_embedding(summary)
+        embedding = generate_embedding(summary, self.config.embedding)

         # Store in ChromaDB
         self.collection.add(
```
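The effect of this change is that `process_and_store_chunk` no longer hard-codes Ollama; it embeds the summary through the shared helper and stores it in Chroma. A self-contained sketch of that step, with illustrative collection setup, id, and model name:

```python
import chromadb

from know_lang_bot.core.types import ModelProvider
from know_lang_bot.models.embeddings import EmbeddingConfig, generate_embedding

client = chromadb.EphemeralClient()  # in-memory client, for illustration
collection = client.get_or_create_collection(
    "code_summaries", metadata={"hnsw:space": "cosine"}
)

summary = "Adds two integers and returns the result."  # stands in for the LLM output
embedding = generate_embedding(
    summary, EmbeddingConfig(ModelProvider.OLLAMA, "mxbai-embed-large")
)

collection.add(
    documents=[summary],
    embeddings=[embedding],
    ids=["example.py:1-2"],  # mirrors the file:start-end id scheme used in the tests
)
```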
src/know_lang_bot/utils/model_provider.py
CHANGED
```diff
@@ -13,5 +13,8 @@ def create_pydantic_model(
         return model_str
     elif model_provider == ModelProvider.HUGGINGFACE:
         return HuggingFaceModel(model_name=model_name)
+    elif model_provider == ModelProvider.TESTING:
+        # should be used for testing purposes only
+        pass
     else:
         raise NotImplementedError(f"Model {model_provider}:{model_name} is not supported")
```
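Note that the `TESTING` branch ends in `pass`, so `create_pydantic_model` implicitly returns `None` for that provider instead of raising; the test fixture relies on the model never actually being invoked. A minimal illustration (keyword arguments assumed, since the full signature sits outside the hunk):

```python
from know_lang_bot.core.types import ModelProvider
from know_lang_bot.utils.model_provider import create_pydantic_model

model = create_pydantic_model(
    model_provider=ModelProvider.TESTING,
    model_name="testing",
)
assert model is None  # the `pass` branch falls through to an implicit None
```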
tests/test_summarizer.py
CHANGED
```diff
@@ -11,7 +11,7 @@ def config():
     """Create a test configuration"""
     with tempfile.TemporaryDirectory() as temp_dir:
         yield AppConfig(
-            llm={"model_name": "…
+            llm={"model_name": "testing", "model_provider": "testing"},
             db={"persist_directory": Path(temp_dir), "collection_name": "test_collection"}
         )

@@ -85,11 +85,11 @@ def test_chromadb_initialization(mock_agent_class, config: AppConfig):
     assert new_summarizer.collection is not None

 @pytest.mark.asyncio
-@patch('know_lang_bot.summarizer.summarizer.…
+@patch('know_lang_bot.summarizer.summarizer.generate_embedding')
 @patch('know_lang_bot.summarizer.summarizer.Agent')
 async def test_process_and_store_chunk_with_embedding(
     mock_agent_class,
-    …
+    mock_embedding_generator,
     config: AppConfig,
     sample_chunks: list[CodeChunk],
     mock_run_result: Mock
@@ -100,8 +100,8 @@ async def test_process_and_store_chunk_with_embedding(
     mock_agent.run = AsyncMock(return_value=mock_run_result)

     # Setup mock embedding response
-    mock_embedding = …
-    …
+    mock_embedding = [0.1, 0.2, 0.3]  # Sample embedding vector
+    mock_embedding_generator.return_value = mock_embedding

     summarizer = CodeSummarizer(config)

@@ -112,9 +112,9 @@ async def test_process_and_store_chunk_with_embedding(
     await summarizer.process_and_store_chunk(sample_chunks[0])

     # Verify ollama.embed was called with correct parameters
-    …
-    …
-    …
+    mock_embedding_generator.assert_called_once_with(
+        mock_run_result.data,
+        config.embedding,
     )

     # Verify collection.add was called with correct parameters
@@ -123,7 +123,7 @@ async def test_process_and_store_chunk_with_embedding(

     kwargs = add_call[1]
     assert len(kwargs['embeddings']) == 3
-    assert kwargs['embeddings'] == mock_embedding
+    assert kwargs['embeddings'] == mock_embedding
     assert kwargs['documents'][0] == mock_run_result.data
     assert kwargs['ids'][0] == f"{sample_chunks[0].file_path}:{sample_chunks[0].start_line}-{sample_chunks[0].end_line}"
```
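One detail worth knowing when reading the new test signature: stacked `@patch` decorators inject mocks bottom-up, so the innermost patch (`Agent`) binds to the first parameter (`mock_agent_class`) and the outermost (`generate_embedding`) to the second (`mock_embedding_generator`). A generic illustration of the ordering:

```python
import os
from unittest.mock import patch

@patch("os.path.exists")  # outermost patch -> second mock argument
@patch("os.getcwd")       # innermost patch -> first mock argument
def demo(mock_getcwd, mock_exists):
    mock_getcwd.return_value = "/tmp"
    mock_exists.return_value = True
    assert os.getcwd() == "/tmp"
    assert os.path.exists("anywhere")

demo()
```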