File size: 1,769 Bytes
9fd6e20
f7e11c1
bb7c9a3
 
9fd6e20
bb7c9a3
9fd6e20
bb7c9a3
 
9fd6e20
 
 
 
6853a4c
9fd6e20
bb7c9a3
6853a4c
9fd6e20
f7e11c1
9fd6e20
 
 
 
 
 
bb7c9a3
9fd6e20
 
 
 
 
6853a4c
bb7c9a3
9fd6e20
 
 
 
 
bb7c9a3
 
9fd6e20
 
bb7c9a3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from loguru import logger
from openai import AsyncOpenAI
from pydantic import ConfigDict
from typing import Any, Sequence, Self

from ctp_slack_bot.core import ApplicationComponentBase, Settings


class EmbeddingsModelService(ApplicationComponentBase):
    """
    Service for embeddings model operations.

    Wraps an ``AsyncOpenAI`` client: requests embeddings for a batch of texts
    using ``settings.embedding_model`` and checks that every returned vector
    has ``settings.vector_dimension`` components.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)

    settings: Settings  # Provides `embedding_model` and `vector_dimension`.
    open_ai_client: AsyncOpenAI  # Client used to call the embeddings endpoint.

    async def get_embeddings(self: Self, texts: Sequence[str]) -> Sequence[Sequence[float]]:
        """Get embeddings for a sequence of texts using OpenAI’s API.

        Args:
            texts (Sequence[str]): Text chunks to embed.

        Returns:
            Sequence[Sequence[float]]: A tuple of embedding vectors (each a tuple
            of floats), one per item of the API response's ``data`` list. An
            empty tuple is returned for empty input without calling the API.

        Raises:
            ValueError: If any returned embedding's length differs from the
                configured ``settings.vector_dimension``.
        """
        if not texts:
            # Nothing to embed; avoid a pointless network round trip.
            return ()
        logger.debug("Creating embeddings for {} text string(s)…", len(texts))
        response = await self.open_ai_client.embeddings.create(
            model=self.settings.embedding_model,
            input=texts,
            encoding_format="float" # Ensure we get raw float values.
        )
        embeddings = tuple(tuple(data.embedding) for data in response.data)
        expected_dimension = self.settings.vector_dimension
        # Validate every vector. (The previous `case (first, _)` pattern only
        # matched batches of exactly two embeddings, so most batches were
        # never checked.)
        for embedding in embeddings:
            if len(embedding) != expected_dimension:
                logger.error("Embedding dimension mismatch and/or misconfiguration: expected configured dimension {}, but got {}.", expected_dimension, len(embedding))
                raise ValueError(  # TODO: raise a more specific type.
                    f"Embedding dimension mismatch: expected {expected_dimension}, got {len(embedding)}."
                )
        return embeddings

    @property
    def name(self: Self) -> str:
        """Identifier of this application component."""
        return "embeddings_model_service"