from abc import ABC, abstractmethod
from typing import Any, List, Optional

from langchain.schema import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import SecretStr

from src.models.generator import ImgGenerator, QuestionGenerator
from src.models.lc_base_model import ChainGenerator, ContentGenerator, EvaluationChatModel
from src.models.lc_img_desc_model import EvaluationChatModelImg
from src.models.lc_qa_model import EvaluationChatModelQA


class ModelFactory(ABC):

    @abstractmethod
    def create_model(self, *args: Any, **kwargs: Any) -> ChainGenerator:
        """
        Create and return a model.

        Concrete factories narrow the return type: EvaluationChatModelFactory
        returns an EvaluationChatModel and GeneratorModelFactory returns a
        ContentGenerator.
        """


class EvaluationChatModelFactory(ModelFactory):

    def create_model(self, model_class: str, openai_api_key: SecretStr, **kwargs: Any) -> EvaluationChatModel:
        """
        Create a model based on the provided model class and OpenAI API key.

        Args:
            model_class (str): The type of model to create, either "qa" or "img_desc".
            openai_api_key (SecretStr): The API key for OpenAI.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            EvaluationChatModel: The created evaluation chat model.

        Raises:
            ValueError: If an invalid model class is provided.
        """
        match model_class:
            case "qa":
                return EvaluationChatModelQA(openai_api_key=openai_api_key, **kwargs)
            case "img_desc":
                return EvaluationChatModelImg(openai_api_key=openai_api_key, **kwargs)
            case _:
                raise ValueError(f'Invalid model class provided: {model_class!r}. Expected "qa" or "img_desc".')


class GeneratorModelFactory(ModelFactory):

    def create_model(
        self,
        model_class: str,
        openai_api_key: SecretStr,
        history_chat: Optional[List[HumanMessage | AIMessage]] = None,
        img_size: str = "256x256",
        **kwargs: Any,
    ) -> ContentGenerator:
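        """
        Create a content generator based on the specified model class and parameters.

        Args:
            model_class (str): The type of generator to create, either "qa" or "img_desc".
            openai_api_key (SecretStr): The API key for OpenAI.
            history_chat (Optional[List[HumanMessage | AIMessage]], optional): Chat history,
                used only when model_class is "qa". Defaults to None, which is passed on
                as an empty list.
            img_size (str, optional): The size of the image, used only when model_class
                is "img_desc". Defaults to "256x256".
            **kwargs (Any): Additional keyword arguments forwarded to the generator.

        Returns:
            ContentGenerator: A generator for the specified model class.

        Raises:
            ValueError: If an invalid model class is provided.
        """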
        match model_class:
            case "qa":
                return QuestionGenerator(openai_api_key=openai_api_key, history_chat=history_chat or [], **kwargs)
            case "img_desc":
                return ImgGenerator(openai_api_key=openai_api_key, img_size=img_size, **kwargs)
            case _:
                raise ValueError(f'Invalid model class provided: {model_class!r}. Expected "qa" or "img_desc".')