File size: 6,349 Bytes
1ef9436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : llm.py
@Time : 2024/09/02
@Project : https://github.com/PeterH0323/Streamer-Sales
@Author : HinGwenWong
@Version : 1.0
@Desc : 大模型接口
"""
from typing import Dict, List
from fastapi import APIRouter, Depends
from loguru import logger
from ..database.llm_db import get_llm_product_prompt_base_info
from ..database.product_db import get_db_product_info
from ..database.streamer_info_db import get_db_streamer_info
from ..models.product_model import ProductInfo
from ..models.streamer_info_model import StreamerInfo
from ..modules.agent.agent_worker import get_agent_result
from ..server_info import SERVER_PLUGINS_INFO
from ..utils import LLM_MODEL_HANDLER, ResultCode, make_return_data
from .users import get_current_user_info
# FastAPI router for all LLM-related endpoints; mounted under the /llm prefix.
router = APIRouter(
    prefix="/llm",
    tags=["llm"],
    responses={404: {"description": "Not found"}},
)
def combine_history(prompt: list, history_msg: list) -> list:
    """Append chat history onto an existing prompt message list.

    Args:
        prompt (list): Existing OpenAI-style message list
            (``{"role": ..., "content": ...}`` dicts); extended in place.
        history_msg (list): History entries, each a dict with keys
            ``"role"`` (``"streamer"`` or ``"user"``) and ``"message"``.

    Returns:
        list: The same ``prompt`` list, with one message appended per
        history entry, in order.

    Raises:
        KeyError: If an entry's role is not ``"streamer"``/``"user"`` or
            an expected key is missing.
    """
    # Map DB-side role names onto OpenAI chat-completion roles.
    role_map = {"streamer": "assistant", "user": "user"}
    # Append the translated history messages in order.
    prompt.extend(
        {"role": role_map[message["role"]], "content": message["message"]}
        for message in history_msg
    )
    return prompt
async def gen_poduct_base_prompt(
    user_id: int,
    streamer_id: int = -1,
    product_id: int = -1,
    streamer_info: StreamerInfo | None = None,
    product_info: ProductInfo | None = None,
) -> List[Dict[str, str]]:
    """Build the base prompt for generating a product sales pitch.

    For both the streamer and the product, the caller must supply exactly
    one of the ID (triggering a DB lookup) or the pre-fetched object.

    Args:
        user_id (int): User ID, scopes the DB lookups.
        streamer_id (int): Streamer ID; leave as -1 when ``streamer_info``
            is provided.
        product_id (int): Product ID; leave as -1 when ``product_info``
            is provided.
        streamer_info (StreamerInfo, optional): Pre-fetched streamer info;
            looked up via ``streamer_id`` when None.
        product_info (ProductInfo, optional): Pre-fetched product info;
            looked up via ``product_id`` when None.

    Returns:
        List[Dict[str, str]]: The generated prompt — a persona-customised
        system message followed by one user message embedding product info.

    Raises:
        ValueError: If an ID and its corresponding object are both given or
            both omitted. (Previously ``assert``, which ``python -O`` strips.)
    """
    # Exactly one of (ID, object) per entity: invalid when id==-1 matches
    # object is None (both missing) or both present.
    if (streamer_id == -1) == (streamer_info is None):
        raise ValueError("provide exactly one of streamer_id / streamer_info")
    if (product_id == -1) == (product_info is None):
        raise ValueError("provide exactly one of product_id / product_info")

    # Load the conversation config:
    #   system: system prompt template, customised per sales persona
    #   first_input: template for the opening user turn
    #   product_info_struct: per-field templates for the product description
    dataset_yaml = await get_llm_product_prompt_base_info()
    system = dataset_yaml["conversation_setting"]["system"]
    first_input_template = dataset_yaml["conversation_setting"]["first_input"]
    product_info_struct_template = dataset_yaml["product_info_struct"]

    # Fetch streamer info by ID if not supplied (the DB helper returns a list).
    if streamer_info is None:
        streamer_info = (await get_db_streamer_info(user_id, streamer_id))[0]

    # Inject the persona name and character traits into the system prompt.
    character_str = streamer_info.character.replace(";", "、")
    system_str = system.replace("{role_type}", streamer_info.name).replace("{character}", character_str)

    # Fetch product info by ID if not supplied.
    if product_info is None:
        product_list, _ = await get_db_product_info(user_id, product_id=product_id)
        product_info = product_list[0]

    # Fill the product description templates (name, then highlights).
    heighlights_str = product_info.heighlights.replace(";", "、")
    product_info_str = product_info_struct_template[0].replace("{name}", product_info.product_name)
    product_info_str += product_info_struct_template[1].replace("{highlights}", heighlights_str)

    # Opening user turn: ask for the sales copy with the product info embedded.
    sales_doc_prompt = first_input_template.replace("{product_info}", product_info_str)

    prompt = [{"role": "system", "content": system_str}, {"role": "user", "content": sales_doc_prompt}]
    logger.info(prompt)
    return prompt
async def get_agent_res(prompt, departure_place, delivery_company) -> str:
    """Run the Agent plugin on the customer's latest question.

    Args:
        prompt (list): OpenAI-style message list; the last entry's
            ``content`` is taken as the customer question.
        departure_place: Shipping origin, forwarded to the agent worker.
        delivery_company: Courier name, forwarded to the agent worker.

    Returns:
        str: A follow-up prompt wrapping the agent's findings together with
        the question, or ``""`` when the agent plugin is disabled or the
        agent returned nothing.
    """
    # Agent plugin disabled -> no extra context to add.
    if not SERVER_PLUGINS_INFO.agent_enabled:
        return ""

    # Agent prompt template: web-sourced info + customer question.
    GENERATE_AGENT_TEMPLATE = (
        "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # Agent prompt 模板
    )

    # The customer's question is the last message in the conversation.
    input_prompt = prompt[-1]["content"]
    agent_response = get_agent_result(LLM_MODEL_HANDLER, input_prompt, departure_place, delivery_company)
    # Only wrap in the template when the agent actually found something.
    if agent_response != "":
        agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, input_prompt)

    logger.info(f"Agent response: {agent_response}")
    return agent_response
async def get_llm_res(prompt: list) -> str:
    """Run a chat completion against the handler's first available model.

    Args:
        prompt (list): OpenAI-style message list
            (``[{"role": ..., "content": ...}, ...]``).

    Returns:
        str: The message content of the LAST item yielded by
        ``chat_completions_v1``. NOTE(review): each iteration overwrites
        ``res_data`` rather than accumulating — this assumes every yielded
        item carries the full message so far; confirm against the
        ``chat_completions_v1`` streaming semantics.
    """
    logger.info(prompt)
    # Always use the first model the handler exposes.
    model_name = LLM_MODEL_HANDLER.available_models[0]
    res_data = ""
    # Overwrite on every chunk; only the final item's content is returned.
    for item in LLM_MODEL_HANDLER.chat_completions_v1(model=model_name, messages=prompt):
        res_data = item["choices"][0]["message"]["content"]
    return res_data
@router.get("/gen_sales_doc", summary="生成主播文案接口")
async def get_product_info_api(streamer_id: int, product_id: int, user_id: int = Depends(get_current_user_info)):
    """Generate the spoken sales script for one product.

    Args:
        streamer_id (int): Streamer ID, used to pull persona/character info.
        product_id (int): Product ID.
    """
    # Build the persona + product prompt, then let the LLM write the copy.
    base_prompt = await gen_poduct_base_prompt(user_id, streamer_id, product_id)
    sales_doc = await get_llm_res(base_prompt)
    return make_return_data(True, ResultCode.SUCCESS, "成功", sales_doc)
@router.get("/gen_product_info")
async def get_product_info_api(product_id: int, user_id: int = Depends(get_current_user_info)):
    """TODO: generate product info from the product's instruction manual.

    NOTE(review): this handler shares its function name with the
    ``/gen_sales_doc`` handler above; FastAPI registers both routes at
    decoration time so both work, but the duplicate name shadows the first
    at module level and should eventually be renamed.

    Args:
        product_id (int): Product ID whose manual would be summarised.
        user_id (int): Resolved from the auth dependency.

    Raises:
        NotImplementedError: Always — this endpoint is not implemented yet.
    """
    # BUG FIX: the original `raise NotImplemented()` raised a TypeError,
    # because NotImplemented is a constant (for rich comparisons), not an
    # exception and not callable. NotImplementedError is the correct type.
    # The unreachable draft code that followed the raise (building a
    # document-assistant prompt and streaming a summary) has been removed.
    raise NotImplementedError("gen_product_info is not implemented yet")
|