gabykim committed
Commit c9b82b3 · 1 Parent(s): e5bfc68

voyageai code embedding support

chromadb/transformers-voyage-voyage-code-3/8b1aa7a1-ab3a-481d-93f0-d3cfe1102024/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc2ad2936b97a745b3944f6e620b6b68e700f087b537d9d5fdf841e05289dbc0
+ size 21180000
chromadb/transformers-voyage-voyage-code-3/8b1aa7a1-ab3a-481d-93f0-d3cfe1102024/header.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9704261c6e5bfae182e06958ecff6199f03c9e1b13a6d73eb4c7034a7be4aaeb
+ size 100
chromadb/transformers-voyage-voyage-code-3/8b1aa7a1-ab3a-481d-93f0-d3cfe1102024/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0cc112e3dd2dac5a0e2bb79cbb76f769a6e976706be58582bbdf420a7dd3b29b
+ size 590939
chromadb/transformers-voyage-voyage-code-3/8b1aa7a1-ab3a-481d-93f0-d3cfe1102024/length.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:96b3e7659800f977eb5fb17017d037cfbcfde55284ef46f342b6b9973a770a55
+ size 20000
chromadb/transformers-voyage-voyage-code-3/8b1aa7a1-ab3a-481d-93f0-d3cfe1102024/link_lists.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:19fd5f1636c4e4ec69b581a68db57a148fcd2ae3339bbcacc16c58950718f212
+ size 42780
chromadb/transformers-voyage-voyage-code-3/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:952a740a945d42a643a84b55922ca78adaf773b3924a7fdb7f78dd2eb5e7f3c3
+ size 88666112
src/know_lang_bot/chat_bot/chat_graph.py CHANGED
@@ -14,7 +14,7 @@ from enum import Enum
  from rich.console import Console
  from know_lang_bot.utils.model_provider import create_pydantic_model
  from know_lang_bot.utils.chunking_util import truncate_chunk
- from know_lang_bot.models.embeddings import generate_embedding
+ from know_lang_bot.models.embeddings import EmbeddingInputType, generate_embedding
  import voyageai
  from voyageai.object.reranking import RerankingObject

@@ -162,7 +162,8 @@ class RetrieveContextNode(BaseNode[ChatGraphState, ChatGraphDeps, ChatResult]):
      """Get initial chunks using embedding search"""
      question_embedding = generate_embedding(
          input=query,
-         config=embedding_config
+         config=embedding_config,
+         input_type=EmbeddingInputType.QUERY
      )

      results = collection.query(
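
For reference, a minimal sketch of how the updated retrieval call fits together. The wrapper name `retrieve_top_chunks` and the `n_results` default are illustrative assumptions; `embedding_config` and the chromadb `collection` come from the surrounding graph state in chat_graph.py.

```python
# Illustrative sketch only: `retrieve_top_chunks` is a hypothetical wrapper;
# `embedding_config` and `collection` are assumed to be set up elsewhere.
from know_lang_bot.models.embeddings import EmbeddingInputType, generate_embedding

def retrieve_top_chunks(query: str, embedding_config, collection, n_results: int = 10):
    # Questions are now embedded with the QUERY input type so VoyageAI can
    # apply query-side encoding; indexed chunks keep the DOCUMENT type.
    question_embedding = generate_embedding(
        input=query,
        config=embedding_config,
        input_type=EmbeddingInputType.QUERY,
    )
    # Standard chromadb nearest-neighbor lookup against the stored chunks.
    return collection.query(
        query_embeddings=[question_embedding],
        n_results=n_results,
    )
```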
src/know_lang_bot/evaluation/embedding_evaluation.py CHANGED
@@ -10,7 +10,7 @@ from know_lang_bot.chat_bot.chat_graph import ChatResult
  from know_lang_bot.config import AppConfig, EmbeddingConfig
  import json
  from know_lang_bot.evaluation.chatbot_evaluation import EvalCase, TRANSFORMER_TEST_CASES
- from know_lang_bot.models.embeddings import generate_embedding, EmbeddingVector
+ from know_lang_bot.models.embeddings import EmbeddingInputType, generate_embedding, EmbeddingVector

  @dataclass
  class ConfigEvalResult:

@@ -65,7 +65,7 @@ async def analyze_embedding_distributions(
      # Generate embeddings for all test cases
      questions = [case.question for case in test_cases]
      try:
-         embeddings = generate_embedding(questions, config.embedding)
+         embeddings = generate_embedding(questions, config.embedding, input_type=EmbeddingInputType.QUERY)

          # Cache the embeddings
          cached_embeddings = {
src/know_lang_bot/models/embeddings.py CHANGED
@@ -1,11 +1,20 @@
  import ollama
  import openai
+ import voyageai
+ import voyageai.client
  from know_lang_bot.config import EmbeddingConfig, ModelProvider
- from typing import Union, List, overload
+ from typing import Union, List, overload, Optional
+ from enum import Enum

  # Type definitions
  EmbeddingVector = List[float]

+
+ class EmbeddingInputType(Enum):
+     DOCUMENT = "document"
+     QUERY = "query"
+
+
  def _process_ollama_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
      """Helper function to process Ollama embeddings in batch."""
      return ollama.embed(model=model_name, input=inputs)['embeddings']

@@ -19,6 +28,12 @@ def _process_openai_batch(inputs: List[str], model_name: str) -> List[EmbeddingVector]:
      )
      return [item.embedding for item in response.data]

+ def _process_voyage_batch(inputs: List[str], model_name: str, input_type: EmbeddingInputType) -> List[EmbeddingVector]:
+     """Helper function to process VoyageAI embeddings in batch."""
+     vo = voyageai.Client()
+     embeddings_obj = vo.embed(model=model_name, texts=inputs, input_type=input_type.value)
+     return embeddings_obj.embeddings
+
  @overload
  def generate_embedding(input: str, config: EmbeddingConfig) -> EmbeddingVector: ...

@@ -27,7 +42,8 @@ def generate_embedding(input: List[str], config: EmbeddingConfig) -> List[EmbeddingVector]: ...

  def generate_embedding(
      input: Union[str, List[str]],
-     config: EmbeddingConfig
+     config: EmbeddingConfig,
+     input_type: Optional[EmbeddingInputType] = EmbeddingInputType.DOCUMENT
  ) -> Union[EmbeddingVector, List[EmbeddingVector]]:
      """
      Generate embeddings for single text input or batch of texts.

@@ -54,8 +70,9 @@ def generate_embedding(
      if config.model_provider == ModelProvider.OLLAMA:
          embeddings = _process_ollama_batch(inputs, config.model_name)
      elif config.model_provider == ModelProvider.OPENAI:
-         openai.api_key = config.api_key
          embeddings = _process_openai_batch(inputs, config.model_name)
+     elif config.model_provider == ModelProvider.VOYAGE:
+         embeddings = _process_voyage_batch(inputs, config.model_name, input_type)
      else:
          raise ValueError(f"Unsupported provider: {config.model_provider}")
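A hedged usage sketch of the new VoyageAI branch in `generate_embedding`. The `EmbeddingConfig(...)` keyword arguments are assumptions inferred from the fields the function reads (`model_provider`, `model_name`); `voyageai.Client()` picks up `VOYAGE_API_KEY` from the environment.

```python
# Sketch under assumptions: EmbeddingConfig accepts model_provider/model_name
# keywords (the fields read above), and VOYAGE_API_KEY is exported.
from know_lang_bot.config import EmbeddingConfig, ModelProvider
from know_lang_bot.models.embeddings import EmbeddingInputType, generate_embedding

config = EmbeddingConfig(
    model_provider=ModelProvider.VOYAGE,
    model_name="voyage-code-3",
)

# Index time: code chunks use the default DOCUMENT input type.
chunk_vectors = generate_embedding(
    ["def add(a, b): return a + b", "class Stack: ..."],
    config,
)

# Query time: questions are embedded as QUERY for asymmetric retrieval.
question_vector = generate_embedding(
    "where is addition implemented?",
    config,
    input_type=EmbeddingInputType.QUERY,
)
```

Embedding documents and queries with different input types is what voyage-code-3 expects for retrieval, which is why the signature grew an `input_type` parameter rather than a second function.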
src/know_lang_bot/utils/migration/{embedding_migrations.py → openai_embedding_migrations.py} RENAMED
File without changes
src/know_lang_bot/utils/migration/voyage_embedding_migraions.py ADDED
@@ -0,0 +1,122 @@
+ from pathlib import Path
+ import asyncio
+ import chromadb
+ from chromadb.errors import InvalidCollectionException
+ from rich.progress import Progress
+ from rich.console import Console
+ from typing import List
+ from know_lang_bot.config import AppConfig, EmbeddingConfig
+ from know_lang_bot.models.embeddings import generate_embedding, EmbeddingInputType
+ from know_lang_bot.utils.fancy_log import FancyLogger
+
+ LOG = FancyLogger(__name__)
+ console = Console()
+
+ BATCH_SIZE = 64  # conservative; VoyageAI's maximum batch size is 128
+
+ async def process_batch(
+     documents: List[str],
+     config: EmbeddingConfig,
+ ) -> List[List[float]]:
+     """Process a batch of documents to generate embeddings"""
+     try:
+         embeddings = generate_embedding(
+             input=documents,
+             config=config,
+             input_type=EmbeddingInputType.DOCUMENT
+         )
+         return embeddings
+     except Exception as e:
+         LOG.error(f"Error processing batch: {e}")
+         raise
+
+ async def migrate_embeddings(config: AppConfig):
+     """Migrate embeddings using VoyageAI's API"""
+     # Initialize source DB client (existing)
+     source_client = chromadb.PersistentClient(
+         path=str(config.db.persist_directory)
+     )
+     source_collection = source_client.get_collection(
+         name=config.db.collection_name
+     )
+
+     # Initialize target DB client (new)
+     target_path = Path(config.db.persist_directory).parent / f"transformers-{config.embedding.model_provider.value}-{config.embedding.model_name}"
+     target_path.mkdir(exist_ok=True)
+     target_client = chromadb.PersistentClient(path=str(target_path))
+
+     # Create new collection
+     new_collection_name = f"{config.db.collection_name}_voyage"
+     try:
+         target_collection = target_client.get_collection(name=new_collection_name)
+         console.print(f"[yellow]Collection {new_collection_name} already exists. Deleting...")
+         target_client.delete_collection(name=new_collection_name)
+     except InvalidCollectionException:
+         pass
+
+     target_collection = target_client.create_collection(
+         name=new_collection_name,
+         metadata={"hnsw:space": "cosine"}
+     )
+
+     # Get all documents from source
+     results = source_collection.get(
+         include=['documents', 'metadatas']
+     )
+
+     total_documents = len(results['ids'])
+     console.print(f"[green]Found {total_documents} documents to process")
+
+     with Progress() as progress:
+         batch_task = progress.add_task(
+             "Processing batches...",
+             total=total_documents
+         )
+
+         # Process in batches
+         for i in range(0, total_documents, BATCH_SIZE):
+             batch_end = min(i + BATCH_SIZE, total_documents)
+             batch_docs = results['documents'][i:batch_end]
+             batch_ids = results['ids'][i:batch_end]
+             batch_metadatas = results['metadatas'][i:batch_end]
+
+             try:
+                 # Generate embeddings for batch
+                 embeddings = await process_batch(
+                     documents=batch_docs,
+                     config=config.embedding
+                 )
+
+                 # Add to new collection
+                 target_collection.add(
+                     embeddings=embeddings,
+                     documents=batch_docs,
+                     metadatas=batch_metadatas,
+                     ids=batch_ids
+                 )
+
+                 await asyncio.sleep(2)
+
+             except Exception as e:
+                 LOG.error(f"Failed to process batch {i//BATCH_SIZE}: {e}")
+                 # Log failed IDs for retry
+                 failed_ids = batch_ids
+                 console.print(f"[red]Failed IDs: {failed_ids}")
+                 continue
+
+             finally:
+                 progress.advance(batch_task, len(batch_docs))
+
+     # Print statistics
+     final_count = len(target_collection.get()['ids'])
+     console.print("\n[green]Migration complete!")
+     console.print(f"Source documents: {total_documents}")
+     console.print(f"Target documents: {final_count}")
+     console.print(f"\nNew database location: {target_path}")
+
+     if final_count < total_documents:
+         console.print(f"[yellow]Warning: {total_documents - final_count} documents failed to process")
+
+ if __name__ == "__main__":
+     config = AppConfig()
+     asyncio.run(migrate_embeddings(config))
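
The script can also be invoked from other code, as its `__main__` block suggests. A sketch of that invocation, assuming `VOYAGE_API_KEY` is exported and `AppConfig`'s embedding section selects the Voyage provider:

```python
# Sketch: run the migration programmatically. Assumes VOYAGE_API_KEY is set
# and config.embedding selects ModelProvider.VOYAGE with voyage-code-3.
import asyncio

from know_lang_bot.config import AppConfig
from know_lang_bot.utils.migration.voyage_embedding_migraions import migrate_embeddings

asyncio.run(migrate_embeddings(AppConfig()))
```

Batches of 64 with a 2-second pause between calls keep the script inside VoyageAI's batch-size limit of 128 and its rate limits; failed batches are logged with their IDs so they can be retried.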