import json
import os
import re
from random import random
from pprint import pprint
import time
from typing import Any, List, Optional, Tuple, Union

from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

from toolformers.base import Tool, StringParameter
from toolformers.sambanova.api_gateway import APIGateway

from toolformers.sambanova.utils import get_total_usage, usage_tracker


FUNCTION_CALLING_SYSTEM_PROMPT = """You have access to the following tools:

{tools}

You can call one or more tools by adding a <ToolCalls> section to your message. For example:
<ToolCalls>
```json
[{{
  "tool": <name of the selected tool>,
  "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}]
```
</ToolCalls>

Note that you can select multiple tools at once by adding more objects to the list. Do not add \
multiple <ToolCalls> sections to the same message.
You will see the outputs of the tool invocations in the next message.


Think step by step.
Do not call a tool if its input depends on the output of another tool that you do not have yet.
Do not try to answer until you have all the tool outputs; if you do not have an answer yet, keep calling tools until you do.
Your answer should be in the same language as the initial query.

"""  # noqa E501


conversational_response = Tool(
    name='ConversationalResponse',
    description='Respond conversationally only if no other tools should be called for a given query, or if you have a final answer. Response must be in the same language as the user query.',
    parameters=[StringParameter(name='response', description='Conversational response to the user. Must be in the same language as the user query.', required=True)],
    function=None
)


class FunctionCallingLlm:
    """
    function calling llm class
    """

    def __init__(
        self,
        tools: Optional[Union[Tool, List[Tool]]] = None,
        default_tool: Optional[Tool] = None,
        system_prompt: Optional[str] = None,
        prod_mode: bool = False,
        api: str = 'sncloud',
        coe: bool = False,
        do_sample: bool = False,
        max_tokens_to_generate: Optional[int] = None,
        temperature: float = 0.2,
        select_expert: Optional[str] = None,
    ) -> None:
        """
        Args:
            tools (Optional[Union[Tool, List[Tool]]]): The tools to use.
            default_tool (Optional[Tool]): The default tool to use.
                defaults to ConversationalResponse
            system_prompt (Optional[str]): The system prompt to use. defaults to FUNCTION_CALLING_SYSTEM_PROMPT
            prod_mode (bool): Whether to use production mode. Defaults to False.
            api (str): The API to use. Defaults to 'sncloud'.
            coe (bool): Whether to use CoE (Composition of Experts). Defaults to False.
            do_sample (bool): Whether to sample during generation. Defaults to False.
            max_tokens_to_generate (Optional[int]): The max tokens to generate. If None, the model will attempt to use the maximum available tokens.
            temperature (float): The model temperature. Defaults to 0.2.
            select_expert (Optional[str]): The expert to use. Defaults to None.
        """
        self.prod_mode = prod_mode
        sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
        self.api = api
        self.llm = APIGateway.load_llm(
            type=api,
            streaming=True,
            coe=coe,
            do_sample=do_sample,
            max_tokens_to_generate=max_tokens_to_generate,
            temperature=temperature,
            select_expert=select_expert,
            process_prompt=False,
            sambanova_api_key=sambanova_api_key,
        )

        if isinstance(tools, Tool):
            tools = [tools]
        self.tools = tools if tools is not None else []
        if system_prompt is None:
            system_prompt = ''

        # Escape braces so the prompt can be used in a ChatPromptTemplate.
        system_prompt = system_prompt.replace('{', '{{').replace('}', '}}')

        if len(self.tools) > 0:
            system_prompt += '\n\n'
            system_prompt += FUNCTION_CALLING_SYSTEM_PROMPT
        self.system_prompt = system_prompt

        if default_tool is None:
            default_tool = conversational_response
        self.default_tool = default_tool

    def execute(self, invoked_tools: List[dict]) -> Tuple[bool, List[str]]:
        """
        Execute the tool calls that the LLM requested, looking each tool up by
        name in the tools map and passing it the given input arguments.
        If there is only one tool call and it is the default conversational one,
        its response is marked as the final response.

        Args:
            invoked_tools (List[dict]): The list of tool calls generated by the LLM.
        """
        # self.tools is always a list (see __init__), so the map is safe to build.
        tools_map = {tool.name.lower(): tool for tool in self.tools}
        tool_msg = "Tool '{name}' response: {response}"
        tools_msgs = []
        if len(invoked_tools) == 1 and invoked_tools[0]['tool'].lower() == 'conversationalresponse':
            final_answer = True
            return final_answer, [invoked_tools[0]['tool_input']['response']]

        final_answer = False

        for tool in invoked_tools:
            if tool['tool'].lower() == 'invocationerror':
                tools_msgs.append(f'Tool invocation error: {tool["tool_input"]}')
            elif tool['tool'].lower() != 'conversationalresponse':
                print(f"\n\n---\nTool {tool['tool'].lower()} invoked with input {tool['tool_input']}\n")
                
                if tool['tool'].lower() not in tools_map:
                    tools_msgs.append(f'Tool {tool["tool"]} not found')
                else:
                    response = tools_map[tool['tool'].lower()].call_tool_for_toolformer(**tool['tool_input'])
                    # print(f'Tool response: {str(response)}\n---\n\n')
                    tools_msgs.append(tool_msg.format(name=tool['tool'], response=str(response)))
        return final_answer, tools_msgs
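
    # A minimal sketch of the payload `execute` expects, inferred from the
    # parsing above (the 'WeatherLookup' tool name is hypothetical):
    #
    #   final, msgs = fc_llm.execute(
    #       [{'tool': 'WeatherLookup', 'tool_input': {'city': 'Paris'}}]
    #   )
    #   # final is False; msgs holds "Tool 'WeatherLookup' response: ..."
    #   # A single ConversationalResponse call instead returns (True, [text]).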

    def json_finder(self, input_string: str) -> List[dict]:
        """
        Find the JSON tool-call structure in an LLM string response. If the
        section is present but badly formatted, an InvocationError entry is
        returned instead; if no tool-call section is found at all, the whole
        response is wrapped as a ConversationalResponse.

        Args:
            input_string (str): The string to find the JSON structure in.
        """

        # 1. Ideal pattern: correctly surrounded by <ToolCalls> tags
        json_pattern_1 = re.compile(r'<ToolCalls>(.*)</ToolCalls>', re.DOTALL | re.IGNORECASE)
        # 2. Sometimes the closing tag is missing
        json_pattern_2 = re.compile(r'<ToolCalls>(.*)', re.DOTALL | re.IGNORECASE)
        # 3. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls>
        json_pattern_3 = re.compile(r'<ToolCall>(.*)</ToolCall>', re.DOTALL | re.IGNORECASE)
        # 4. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls> and the closing tag is missing
        json_pattern_4 = re.compile(r'<ToolCall>(.*)', re.DOTALL | re.IGNORECASE)

        # Find the first JSON structure in the string
        json_match = (
            json_pattern_1.search(input_string)
            or json_pattern_2.search(input_string)
            or json_pattern_3.search(input_string)
            or json_pattern_4.search(input_string)
        )
        if json_match:
            json_str = json_match.group(1)

            # 1. Outermost list of JSON object
            call_pattern_1 = re.compile(r'\[.*\]', re.DOTALL)
            # 2. Outermost JSON object
            call_pattern_2 = re.compile(r'\{.*\}', re.DOTALL)

            call_match_1 = call_pattern_1.search(json_str)
            call_match_2 = call_pattern_2.search(json_str)

            if call_match_1:
                json_str = call_match_1.group(0)
                try:
                    return json.loads(json_str)
                except Exception as e:
                    return [{'tool': 'InvocationError', 'tool_input' : str(e)}]
            elif call_match_2:
                json_str = call_match_2.group(0)
                try:
                    return [json.loads(json_str)]
                except Exception as e:
                    return [{'tool': 'InvocationError', 'tool_input' : str(e)}]
            else:
                return [{'tool': 'InvocationError', 'tool_input' : 'Could not find JSON object in the <ToolCalls> section'}]
        else:
            # No tool-call section: treat the whole response as conversational.
            return [{'tool': 'ConversationalResponse', 'tool_input': {'response': input_string}}]
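
    # Sketch of json_finder on typical inputs (the tool name is illustrative):
    #
    #   json_finder('<ToolCalls>[{"tool": "X", "tool_input": {}}]</ToolCalls>')
    #   -> [{'tool': 'X', 'tool_input': {}}]
    #
    #   json_finder('Just a plain answer.')
    #   -> [{'tool': 'ConversationalResponse',
    #        'tool_input': {'response': 'Just a plain answer.'}}]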

    def msgs_to_llama3_str(self, msgs: list) -> str:
        """
        Convert a list of langchain messages with roles to the expected Llama 3
        prompt format.

        Args:
            msgs (list): The list of langchain messages.
        """
        formatted_msgs = []
        for msg in msgs:
            if msg.type == 'system':
                sys_placeholder = (
                    '<|begin_of_text|><|start_header_id|>system<|end_header_id|> {msg}'
                )
                formatted_msgs.append(sys_placeholder.format(msg=msg.content))
            elif msg.type == 'human':
                human_placeholder = '<|eot_id|><|start_header_id|>user<|end_header_id|>\nUser: {msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:'  # noqa E501
                formatted_msgs.append(human_placeholder.format(msg=msg.content))
            elif msg.type == 'ai':
                assistant_placeholder = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant: {msg}'
                formatted_msgs.append(assistant_placeholder.format(msg=msg.content))
            elif msg.type == 'tool':
                tool_placeholder = '<|eot_id|><|start_header_id|>tools<|end_header_id|>\n{msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:'  # noqa E501
                formatted_msgs.append(tool_placeholder.format(msg=msg.content))
            else:
                raise ValueError(f'Invalid message type: {msg.type}')
        return '\n'.join(formatted_msgs)
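
    # Example: a human message "Hi" formats, per the placeholder above, as
    #   <|eot_id|><|start_header_id|>user<|end_header_id|>
    #   User: Hi <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #   Assistant: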

    def msgs_to_sncloud(self, msgs: list) -> str:
        """
        Convert a list of langchain messages with roles to the expected FastCoE
        input, serialized as a JSON string.

        Args:
            msgs (list): The list of langchain messages.
        """
        formatted_msgs = []
        for msg in msgs:
            if msg.type == 'system':
                formatted_msgs.append({'role': 'system', 'content': msg.content})
            elif msg.type == 'human':
                formatted_msgs.append({'role': 'user', 'content': msg.content})
            elif msg.type == 'ai':
                formatted_msgs.append({'role': 'assistant', 'content': msg.content})
            elif msg.type == 'tool':
                formatted_msgs.append({'role': 'tools', 'content': msg.content})
            else:
                raise ValueError(f'Invalid message type: {msg.type}')
        return json.dumps(formatted_msgs)
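
    # Example: a [system, human] history serializes, per the mapping above, to
    #   '[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]'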

    def function_call_llm(self, query: str, max_it: int = 10, debug: bool = False) -> Tuple[str, Any]:
        """
        Invocation method for the function calling workflow. Returns the final
        response together with the total usage recorded by the usage tracker.

        Args:
            query (str): The query to execute.
            max_it (int, optional): The maximum number of iterations. Defaults to 10.
            debug (bool, optional): Whether to print debug information. Defaults to False.
        """
        function_calling_chat_template = ChatPromptTemplate.from_messages([('system', self.system_prompt)])
        tools_schemas = [tool.as_llama_schema() for tool in self.tools]

        history = function_calling_chat_template.format_prompt(tools=tools_schemas).to_messages()

        history.append(HumanMessage(query))
        tool_call_id = 0  # identifier for each tool call, required to create ToolMessages
        json_parsing_chain = RunnableLambda(self.json_finder)

        with usage_tracker():
            for _ in range(max_it):

                if self.api == 'sncloud':
                    prompt = self.msgs_to_sncloud(history)
                else:
                    prompt = self.msgs_to_llama3_str(history)
                # print(f'\n\n---\nCalling function calling LLM with prompt: \n{prompt}\n')
                
                # Retry with randomized exponential backoff on rate-limit errors.
                exponential_backoff_lower = 30
                exponential_backoff_higher = 60
                llm_response = None
                for _ in range(5):
                    try:
                        llm_response = self.llm.invoke(prompt, stream_options={'include_usage': True})
                        break
                    except Exception as e:
                        if '429' in str(e):
                            print('Rate limit exceeded. Waiting with random exponential backoff.')
                            time.sleep(
                                random() * (exponential_backoff_higher - exponential_backoff_lower)
                                + exponential_backoff_lower
                            )
                            exponential_backoff_lower *= 2
                            exponential_backoff_higher *= 2
                        else:
                            raise
                if llm_response is None:
                    raise Exception('LLM invocation failed after 5 attempts due to rate limiting')

                print('LLM response:', llm_response)

                # print(f'\nFunction calling LLM response: \n{llm_response}\n---\n')
                parsed_tools_llm_response = json_parsing_chain.invoke(llm_response)

                history.append(AIMessage(llm_response))
                final_answer, tools_msgs = self.execute(parsed_tools_llm_response)
                if final_answer:  # if response was marked as final response in execution
                    final_response = tools_msgs[0]
                    if debug:
                        print('\n\n---\nFinal function calling LLM history: \n')
                        pprint(f'{history}')
                    return final_response, get_total_usage()
                else:
                    history.append(ToolMessage('\n'.join(tools_msgs), tool_call_id=tool_call_id))
                    tool_call_id += 1


        raise Exception('No final response after the maximum number of iterations', history)
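

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the library. It assumes SAMBANOVA_API_KEY
# is set in the environment and that a Tool's `function` is what
# call_tool_for_toolformer invokes (see `execute` above); EchoTool is purely
# hypothetical.
if __name__ == '__main__':
    echo_tool = Tool(
        name='EchoTool',
        description='Returns the provided text unchanged.',
        parameters=[StringParameter(name='text', description='Text to echo back.', required=True)],
        function=lambda text: text,
    )
    fc_llm = FunctionCallingLlm(tools=[echo_tool])
    final_response, usage = fc_llm.function_call_llm('Use EchoTool to echo "hello".', max_it=5, debug=True)
    print(final_response)
    print(usage)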