# File size: 3,861 Bytes
# e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypedDict

from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


class PromptManagementClient(TypedDict):
    """Result of compiling a managed prompt, passed between the prompt-management hooks.

    Produced by ``PromptManagementBase._compile_prompt_helper`` and then
    completed by ``PromptManagementBase.compile_prompt``.
    """

    # Id of the prompt that was compiled (the value callers pass as ``prompt_id``).
    prompt_id: str
    # Messages from the prompt template; combined with the caller's messages
    # when the prompt is compiled.
    prompt_template: List[AllMessageValues]
    # When not None, this model is used instead of the model the caller asked for.
    prompt_template_model: Optional[str]
    # Extra request params from the template; None is treated as "no params".
    # On key collision these win over the caller's non-default params.
    prompt_template_optional_params: Optional[Dict[str, Any]]
    # Filled in by ``compile_prompt`` with the merged template + client messages.
    completed_messages: Optional[List[AllMessageValues]]


class PromptManagementBase(ABC):
    """Abstract base for integrations that turn a prompt id into a chat request.

    Subclasses provide the integration name, the decision whether prompt
    management applies to a call, and the lookup/compilation of the prompt
    template. This base class merges the template with the caller's messages
    and resolves the final model and request params.
    """

    @property
    @abstractmethod
    def integration_name(self) -> str:
        """Return the integration's name.

        ``"<integration_name>/"`` is the prefix stripped from the model string
        in ``_get_model_from_prompt``.
        """

    @abstractmethod
    def should_run_prompt_management(
        self,
        prompt_id: str,
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Return True if this integration should handle the given prompt id.

        When this returns False, ``get_chat_completion_prompt`` hands the
        caller's model, messages and params back untouched.
        """

    @abstractmethod
    def _compile_prompt_helper(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Build the ``PromptManagementClient`` for ``prompt_id``.

        ``compile_prompt`` overwrites the ``completed_messages`` entry of the
        returned dict, so implementations do not need to populate it.
        """

    def merge_messages(
        self,
        prompt_template: List[AllMessageValues],
        client_messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """Combine template messages with the caller's messages.

        The default places the template messages first, followed by the
        client messages. Subclasses may override to change the merge order.
        """
        return prompt_template + client_messages

    def compile_prompt(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        client_messages: List[AllMessageValues],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Compile the prompt and attach the merged message list.

        Returns:
            The dict produced by ``_compile_prompt_helper`` (mutated in place)
            with ``completed_messages`` set to the merged messages.

        Raises:
            ValueError: if merging the template with ``client_messages``
                fails; the underlying exception is chained as ``__cause__``.
        """
        compiled_prompt_client = self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
        )

        try:
            # Go through the merge_messages hook so subclass overrides apply.
            messages = self.merge_messages(
                prompt_template=compiled_prompt_client["prompt_template"],
                client_messages=client_messages,
            )
        except Exception as e:
            raise ValueError(
                f"Error compiling prompt: {e}. Prompt id={prompt_id}, prompt_variables={prompt_variables}, client_messages={client_messages}, dynamic_callback_params={dynamic_callback_params}"
            ) from e

        compiled_prompt_client["completed_messages"] = messages
        return compiled_prompt_client

    def _get_model_from_prompt(
        self, prompt_management_client: PromptManagementClient, model: str
    ) -> str:
        """Pick the model to call.

        The template's model wins when it is not None. Otherwise the caller's
        ``model`` is returned with a leading ``"<integration_name>/"`` removed.
        """
        template_model = prompt_management_client["prompt_template_model"]
        if template_model is not None:
            return template_model

        # Strip only a leading prefix; later occurrences of the same text
        # (e.g. inside a deployment name) are part of the real model id.
        prefix = "{}/".format(self.integration_name)
        if model.startswith(prefix):
            return model[len(prefix):]
        return model

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[
        str,
        List[AllMessageValues],
        dict,
    ]:
        """Resolve the model, messages and params for a chat completion call.

        Returns:
            ``(model, messages, params)``. If ``should_run_prompt_management``
            is False the inputs are returned as given. Otherwise the messages
            are the compiled messages (falling back to ``messages`` when the
            compiled list is empty or None), the params are
            ``non_default_params`` overlaid with the template's optional
            params, and the model comes from ``_get_model_from_prompt``.
        """
        if not self.should_run_prompt_management(
            prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
        ):
            return model, messages, non_default_params

        prompt_template = self.compile_prompt(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            client_messages=messages,
            dynamic_callback_params=dynamic_callback_params,
        )

        completed_messages = prompt_template["completed_messages"] or messages

        prompt_template_optional_params = (
            prompt_template["prompt_template_optional_params"] or {}
        )

        # Template params are spread last so they override the caller's values.
        updated_non_default_params = {
            **non_default_params,
            **prompt_template_optional_params,
        }

        model = self._get_model_from_prompt(
            prompt_management_client=prompt_template, model=model
        )

        return model, completed_messages, updated_non_default_params