# Copyright (c) OpenMMLab. All rights reserved.
"""LLM client."""
import argparse
import json
import pytoml
import requests
from loguru import logger
class ChatClient:
"""A class to handle client-side interactions with a chat service.
This class is responsible for loading configurations from a given path,
building prompts, and generating responses by interacting with the chat
service.
"""
def __init__(self, config_path: str) -> None:
"""Initialize the ChatClient with the path of the configuration
file."""
self.config_path = config_path
def load_config(self):
"""Load the 'llm' section of the configuration from the provided
path."""
with open(self.config_path, encoding='utf8') as f:
config = pytoml.load(f)
return config['llm']
def load_llm_config(self):
"""Load the 'server' section of the 'llm' configuration from the
provided path."""
with open(self.config_path, encoding='utf8') as f:
config = pytoml.load(f)
return config['llm']['server']
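    # A hypothetical sketch of the `[llm]` section these two loaders expect in
    # config.ini (TOML syntax). The key names follow how `llm_config` is read
    # in `generate_response` below; the values are illustrative only:
    #
    #   [llm]
    #   client_url = "http://127.0.0.1:8888/inference"
    #   enable_local = 1
    #   enable_remote = 0
    #
    #   [llm.server]
    #   local_llm_max_text_length = 3000
    #   remote_type = ""
    #   remote_api_key = ""
    #   remote_llm_model = ""
    #   remote_llm_max_text_length = 32000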
def build_prompt(self,
history_pair,
instruction: str,
template: str,
context: str = '',
reject: str = '<reject>'):
"""Build a prompt for interaction.
Args:
history_pair (list): List of previous interactions.
instruction (str): Instruction for the current interaction.
template (str): Template for constructing the interaction.
context (str, optional): Context of the interaction. Defaults to ''. # noqa E501
reject (str, optional): Text that indicates a rejected interaction. Defaults to '<reject>'. # noqa E501
Returns:
tuple: A tuple containing the constructed instruction and real history.
"""
if context is not None and len(context) > 0:
instruction = template.format(context, instruction)
real_history = []
for pair in history_pair:
if pair[1] == reject:
continue
if pair[0] is None or pair[1] is None:
continue
if len(pair[0]) < 1 or len(pair[1]) < 1:
continue
real_history.append(pair)
return instruction, real_history
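    # Minimal usage sketch (values are made up): the template is expected to
    # contain two `{}` placeholders, and pairs whose answer equals `reject`
    # or is empty are dropped from the returned history.
    #
    #   instruction, history = client.build_prompt(
    #       history_pair=[('q1', '<reject>'), ('q2', 'a2')],
    #       instruction='how to install mmdeploy?',
    #       template='Background: {}\nQuestion: {}',
    #       context='mmdeploy is an SDK...')
    #   # -> history == [('q2', 'a2')]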
    def generate_response(self, prompt, backend, history=[]):
        """Generate a response from the chat service.

        Args:
            prompt (str): The prompt to send to the chat service.
            backend (str): Which LLM backend to call, either 'local' or 'remote'.  # noqa E501
            history (list, optional): List of previous interactions. Defaults to [].  # noqa E501

        Returns:
            str: Generated response from the chat service.
        """
llm_config = self.load_config()
url, enable_local, enable_remote = (llm_config['client_url'],
llm_config['enable_local'],
llm_config['enable_remote'])
        type_given = llm_config['server']['remote_type'] != ''
        api_given = llm_config['server']['remote_api_key'] != ''
        llm_given = llm_config['server']['remote_llm_model'] != ''

        if backend == 'local' and enable_local:
            max_length = llm_config['server']['local_llm_max_text_length']
        elif backend == 'remote' and enable_remote and type_given and api_given and llm_given:  # noqa E501
            max_length = llm_config['server']['remote_llm_max_text_length']
            backend = llm_config['server']['remote_type']
else:
raise ValueError('Invalid backend or backend is not enabled')
if len(prompt) > max_length:
logger.warning(
f'prompt length {len(prompt)} > max_length {max_length}, truncated' # noqa E501
)
prompt = prompt[0:max_length]
try:
header = {'Content-Type': 'application/json'}
data_history = []
for item in history:
data_history.append([item[0], item[1]])
data = {
'prompt': prompt,
'history': data_history,
'backend': backend
}
resp = requests.post(url,
headers=header,
data=json.dumps(data),
timeout=300)
if resp.status_code != 200:
raise Exception(str((resp.status_code, resp.reason)))
return resp.json()['text']
except Exception as e:
logger.error(str(e))
            logger.error(
                'Did you forget `--standalone` when running `python3 -m huixiangdou.main`?'  # noqa E501
            )
return ''
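    # Wire format used above, shown as a hedged sketch based only on this
    # client: it POSTs JSON like
    #   {"prompt": "...", "history": [["q", "a"], ...], "backend": "local"}
    # to `client_url` and expects the hybrid llm server to reply with JSON
    # containing a `text` field, e.g. {"text": "..."}. No other response
    # keys are assumed here.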
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description='Client for hybrid llm service.')
parser.add_argument(
'--config_path',
default='config.ini',
help='Configuration path. Default value is config.ini' # noqa E501
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
client = ChatClient(config_path=args.config_path)
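    # The demo prompts below are in Chinese. Roughly: the first asks the LLM
    # to extract the entity words from the question 'how do I get triviaqa
    # 5-shot results printed in the summarizer?' and return them as a list
    # without explanation; the second asks 'what is the full name of ncnn',
    # with a history turn answering 'the n in ncnn stands for nihui, cnn
    # stands for convolutional neural network'.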
question = '“{}”\n请仔细阅读以上问题,提取其中的实体词,结果直接用 list 表示,不要解释。'.format(
'请问triviaqa 5shot结果怎么在summarizer里输出呢')
print(client.generate_response(prompt=question, backend='local'))
print(
client.generate_response(prompt='请问 ncnn 的全称是什么',
history=[('ncnn 是什么',
'ncnn中的n代表nihui,cnn代表卷积神经网络。')],
backend='remote'))