File size: 1,943 Bytes
3cad23b
 
 
 
 
 
 
c07f594
 
 
 
 
 
 
 
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c07f594
 
 
 
 
3cad23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import datetime
import os
from typing import List

from toolformers.base import Conversation, Toolformer, Tool
from toolformers.sambanova.function_calling import FunctionCallingLlm


# Pricing table: model name -> usage field -> multiplier applied to that
# field's token count when a call's cost is computed.
# NOTE(review): the values look like a price per single token; the currency
# and unit are not stated in this file -- confirm before relying on them.
COSTS = {
    'llama3-405b': dict(
        prompt_tokens=5e-6,
        completion_tokens=10e-6,
    ),
}

class SambanovaConversation(Conversation):
    """Conversation that forwards each user message to a FunctionCallingLlm.

    Attributes:
        model_name: Name of the model; also the key used to look up pricing
            in ``COSTS``.
        function_calling_llm: Object whose ``function_call_llm(message)``
            returns a ``(response, usage_data)`` pair.
        category: Optional value stored as given; not used by this class.
        last_cost: Cost computed for the most recent ``chat`` call, or
            ``None`` when no call has been made yet or the cost could not be
            computed.
    """

    # Usage fields that contribute to the cost of a call.
    _COST_FIELDS = ('prompt_tokens', 'completion_tokens')

    def __init__(self, model_name, function_calling_llm: FunctionCallingLlm, category=None):
        self.model_name = model_name
        self.function_calling_llm = function_calling_llm
        self.category = category
        self.last_cost = None

    def _estimate_cost(self, usage_data):
        """Return the cost of one call, or None when it cannot be computed.

        The cost is the sum, over the prompt and completion token fields, of
        the model's ``COSTS`` entry multiplied by the matching count in
        ``usage_data``. ``None`` is returned when the model has no entry in
        ``COSTS``, when ``usage_data`` is empty/None, or when a required
        field is missing or not numeric.
        """
        pricing = COSTS.get(self.model_name)
        if pricing is None or not usage_data:
            return None
        try:
            return sum(pricing[field] * usage_data[field] for field in self._COST_FIELDS)
        except (KeyError, TypeError):
            # Missing field or a non-numeric / non-subscriptable usage value:
            # cost accounting is best-effort and must not break the chat.
            return None

    def chat(self, message, role='user', print_output=True):
        """Send ``message`` to the LLM and return its response.

        Args:
            message: The message passed to ``function_call_llm``.
            role: Must be ``'user'``; any other value is rejected.
            print_output: When true, the response is printed to stdout.

        Returns:
            The response produced by ``function_call_llm``.

        Raises:
            ValueError: If ``role`` is not ``'user'``.
        """
        if role != 'user':
            raise ValueError('Role must be "user"')

        response, usage_data = self.function_calling_llm.function_call_llm(message)

        # Usage data is always printed, independent of print_output.
        print('Usage data:', usage_data)
        if print_output:
            print(response)

        # Previously the cost was computed with unguarded lookups and then
        # discarded, so an unpriced model raised KeyError after the LLM call
        # had already succeeded. Keep the value, and never fail on it.
        self.last_cost = self._estimate_cost(usage_data)

        return response

class SambanovaToolformer(Toolformer):
    """Toolformer that creates conversations bound to one named model."""

    def __init__(self, model_name: str):
        # The name is reused for every conversation this instance creates.
        self.model_name = model_name

    def new_conversation(self, system_prompt: str, tools: List[Tool], category=None) -> SambanovaConversation:
        """Build a conversation backed by a fresh FunctionCallingLlm.

        Args:
            system_prompt: Passed to FunctionCallingLlm as ``system_prompt``.
            tools: Passed to FunctionCallingLlm as ``tools``.
            category: Optional value forwarded to the conversation.

        Returns:
            A new SambanovaConversation for this instance's model.
        """
        # The model name is handed to the LLM wrapper as ``select_expert``.
        llm = FunctionCallingLlm(
            system_prompt=system_prompt,
            tools=tools,
            select_expert=self.model_name,
        )
        conversation = SambanovaConversation(self.model_name, llm, category)
        return conversation

#def make_llama_toolformer(model_name, system_prompt: str, tools: List[Tool]):
#    if model_name not in ['llama3-8b', 'llama3-70b', 'llama3-405b']:
#        raise ValueError(f"Unknown model name: {model_name}")
#
#    return SambanovaToolformer(model_name, system_prompt, tools)