File size: 4,073 Bytes
f26e192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os

################################################################
# Format LLM messages
################################################################

def _format_messages(history, message=None, system=None, format='plain', 
        user_name='user', bot_name='assistant'):
    _history = []
    if format == 'openai_chat':
        if system:
            _history.append({'role': 'system', 'content': system})
        for human, ai in history:
            if human:
                _history.append({'role': 'user', 'content': human})
            if ai:
                _history.append({'role': 'assistant', 'content': ai})
        if message:
            _history.append({'role': 'user', 'content': message})
        return _history
    
    elif format == 'plain':
        if system:
            _history.append(system)
        for human, ai in history:
            if human:
                _history.append(f'{user_name}: {human}')
            if ai:
                _history.append(f'{bot_name}: {ai}')
        if message:
            _history.append(f'{user_name}: {message}')
            _history.append(f'{bot_name}: ')
        return '\n'.join(_history)
    
    else:
        raise ValueError(f"Invalid messages to format: {format}")

class bcolors:
    """ANSI SGR escape sequences for colorizing terminal output.

    Every attribute is a plain escape string to be embedded in printed text;
    ``ENDC`` resets all attributes back to the terminal default.
    """

    # Bright foreground colors (SGR 91-96).
    HEADER = '\033[95m'    # bright magenta
    OKBLUE = '\033[94m'    # bright blue
    OKCYAN = '\033[96m'    # bright cyan
    OKGREEN = '\033[92m'   # bright green
    WARNING = '\033[93m'   # bright yellow
    FAIL = '\033[91m'      # bright red

    # Reset and text attributes.
    ENDC = '\033[0m'       # reset everything
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

def _print_messages(history, message, bot_message, system=None,
                    user_name='user', bot_name='assistant', format='plain',
                    variant='primary', tag=None):
    """Print a rendered prompt followed by the bot's reply, colorized with ANSI codes.

    The prompt is built with ``_format_messages`` from ``history`` (a list of
    ``(user_msg, bot_msg)`` tuples), ``message`` and ``system``, and is printed
    in ``bcolors.OKCYAN``. ``bot_message`` follows immediately in a color
    chosen by ``variant``:

    * ``'primary'`` uses OKGREEN.
    * ``'secondary'`` uses HEADER.
    * ``'warning'`` uses WARNING.
    * ``'error'`` uses FAIL.
    * Any other value uses BOLD.

    If ``tag`` is not None, it is appended on its own ``:: <tag>`` line in
    WARNING color. Output ends with ``bcolors.ENDC`` to reset the terminal.
    """
    prompt = _format_messages(history, message, system=system, user_name=user_name,
                              bot_name=bot_name, format=format)
    variant_colors = {
        'primary': bcolors.OKGREEN,
        'secondary': bcolors.HEADER,
        'warning': bcolors.WARNING,
        'error': bcolors.FAIL,
    }
    reply_color = variant_colors.get(variant, bcolors.BOLD)
    footer = '' if tag is None else f'\n:: {tag}'
    print(f'{bcolors.OKCYAN}{prompt}{reply_color}{bot_message}{bcolors.WARNING}{footer}{bcolors.ENDC}')


################################################################
# LLM bot fn
################################################################

def _openai_bot_fn(message, history, **kwargs):
    """Send one chat turn to OpenAI's ChatCompletion endpoint and return the reply text.

    Args:
        message: the new user message.
        history: prior ``(user_msg, bot_msg)`` pairs.
        **kwargs: optional settings read here:

            * ``temperature`` (default 0).
            * ``system_prompt`` (ignored when falsy).
            * ``chat_engine`` (model name, default ``'gpt-3.5-turbo'``).
            * ``verbose`` (truthy echoes the exchange to stdout via
              ``_print_messages``).

    Returns:
        ``choices[0].message.content`` of the API response.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is missing from the environment.
    """
    temperature = kwargs.get('temperature', 0)
    system = kwargs.get('system_prompt') or None
    model = kwargs.get('chat_engine', 'gpt-3.5-turbo')

    # NOTE(review): deferred import, presumably so this module can be loaded
    # without the openai package. ``openai.ChatCompletion`` is the pre-1.0
    # SDK interface and is no longer supported in openai>=1.0.
    import openai
    # Assigns the key on the openai module itself, i.e. process-wide.
    openai.api_key = os.environ["OPENAI_API_KEY"]

    completion = openai.ChatCompletion.create(
        model=model,
        messages=_format_messages(history, message, system=system, format='openai_chat'),
        temperature=temperature,
    )
    reply = completion.choices[0].message.content
    if kwargs.get('verbose'):
        _print_messages(history, message, reply, system=system, tag=f'openai ({model})')
    return reply

def _openai_stream_bot_fn(message, history, **kwargs):
    """Stream a chat reply from OpenAI's ChatCompletion endpoint.

    This is a generator. After every streamed chunk it yields the reply
    accumulated so far, passed through ``str.strip()``, so a caller can simply
    re-render the latest value. When exhausted, the generator's return value
    (``StopIteration.value``, or the result of ``yield from``) is the full,
    unstripped reply.

    Args:
        message: the new user message.
        history: prior ``(user_msg, bot_msg)`` pairs.
        **kwargs: optional settings read here:

            * ``temperature`` (default 0).
            * ``system_prompt`` (ignored when falsy).
            * ``chat_engine`` (model name, default ``'gpt-3.5-turbo'``).
            * ``verbose`` (truthy echoes the finished exchange to stdout
              via ``_print_messages``).

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is missing from the environment.
    """
    _kwargs = dict(temperature=kwargs.get('temperature', 0))
    system = kwargs.get('system_prompt') or None
    chat_engine = kwargs.get('chat_engine', 'gpt-3.5-turbo')
    # NOTE(review): ``openai.ChatCompletion`` is the pre-1.0 SDK interface and
    # is no longer supported in openai>=1.0, so this requires openai<1.0.
    import openai
    # Assigns the key on the openai module itself, i.e. process-wide.
    openai.api_key = os.environ["OPENAI_API_KEY"]

    resp = openai.ChatCompletion.create(
        model=chat_engine,
        messages=_format_messages(history, message, system=system, format='openai_chat'),
        stream=True,
        **_kwargs,
    )

    bot_message = ""
    for chunk in resp:
        # Not every chunk carries text, and none of the following cases may
        # be appended with ``+=``:
        # - the final delta is empty;
        # - ``content`` can be present but null (role-only or function-call
        #   deltas), which an ``'content' in delta`` check lets through;
        # - some endpoints (reportedly Azure OpenAI) send chunks whose
        #   ``choices`` list is empty.
        choices = chunk.choices
        piece = choices[0].delta.get('content') if choices else None
        if piece:
            bot_message += piece  # each yield is the whole reply so far
        yield bot_message.strip()
    if kwargs.get('verbose'):
        _print_messages(history, message, bot_message, system=system, tag=f'openai_stream ({chat_engine})')
    return bot_message