KingZack commited on
Commit
fddb754
·
1 Parent(s): 23f9974

adding Vectorization service

Browse files
src/ctp_slack_bot/services/VectorizationService.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, validator
2
+ from typing import List, Optional
3
+ from ctp_slack_bot.core.config import settings
4
+ import numpy as np
5
+ from openai import AsyncOpenAI # Updated import
6
+
7
+
8
+ class VectorizationService(BaseModel):
9
+ """
10
+ Service for vectorizing chunks of text data.
11
+ """
12
+ def __init__(self):
13
+ self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
14
+
15
+
16
+ async def get_embeddings(self, texts: List[str]) -> np.ndarray:
17
+ """Get embeddings for a list of texts using OpenAI's API.
18
+
19
+ Args:
20
+ texts (List[str]): List of text chunks to embed
21
+
22
+ Returns:
23
+ np.ndarray: Array of embeddings with shape (n_texts, VECTOR_DIMENSION)
24
+
25
+ Raises:
26
+ ValueError: If the embedding dimensions don't match expected size
27
+ """
28
+ try:
29
+ # Use the initialized client instead of the global openai module
30
+ response = await self.client.embeddings.create(
31
+ model=settings.EMBEDDING_MODEL,
32
+ input=texts,
33
+ encoding_format="float" # Ensure we get raw float values
34
+ )
35
+
36
+ # Extract embeddings and verify dimensions
37
+ embeddings = np.array([data.embedding for data in response.data])
38
+
39
+ if embeddings.shape[1] != settings.VECTOR_DIMENSION:
40
+ raise ValueError(
41
+ f"Embedding dimension mismatch. Expected {settings.VECTOR_DIMENSION}, "
42
+ f"but got {embeddings.shape[1]}. Please update VECTOR_DIMENSION "
43
+ f"in config.py to match the model's output."
44
+ )
45
+
46
+ return embeddings
47
+
48
+ except Exception as e:
49
+ print(f"Error getting embeddings: {str(e)}")
50
+ raise
51
+
52
+ def _test(self):
53
+ """
54
+ Test the vectorization service.
55
+ """
56
+
57
+ pass
58
+
59
+
60
+
61
+ vs = VectorizationService()
62
+ vs._test()
63
+
64
+