|
|
|
|
|
""" |
|
@File : llm.py |
|
@Time : 2024/09/02 |
|
@Project : https://github.com/PeterH0323/Streamer-Sales |
|
@Author : HinGwenWong |
|
@Version : 1.0 |
|
@Desc : 大模型接口 |
|
""" |
|
|
|
|
|
from typing import Dict, List |
|
|
|
from fastapi import APIRouter, Depends |
|
from loguru import logger |
|
|
|
from ..database.llm_db import get_llm_product_prompt_base_info |
|
from ..database.product_db import get_db_product_info |
|
from ..database.streamer_info_db import get_db_streamer_info |
|
from ..models.product_model import ProductInfo |
|
from ..models.streamer_info_model import StreamerInfo |
|
from ..modules.agent.agent_worker import get_agent_result |
|
from ..server_info import SERVER_PLUGINS_INFO |
|
from ..utils import LLM_MODEL_HANDLER, ResultCode, make_return_data |
|
from .users import get_current_user_info |
|
|
|
# Shared router for the LLM endpoints declared in this module; it is created
# with the "/llm" URL prefix, the "llm" tag and a default 404 response entry.
router = APIRouter(
    responses={404: {"description": "Not found"}},
    tags=["llm"],
    prefix="/llm",
)
|
|
|
|
|
def combine_history(prompt: list, history_msg: list):
    """Append the conversation history to ``prompt`` as chat messages.

    Every history entry is turned into a ``{"role": ..., "content": ...}``
    dict. The ``"streamer"`` role is mapped to ``"assistant"`` and ``"user"``
    stays ``"user"``.

    Args:
        prompt (list): Message list to extend. It is modified in place.
        history_msg (list): History entries; each one is indexed with the
            ``"role"`` key (``"streamer"`` or ``"user"``) and the
            ``"message"`` key (the text of that turn).

    Returns:
        list: The very same ``prompt`` object, now including the history.

    Raises:
        KeyError: If an entry's role is neither ``"streamer"`` nor ``"user"``.
    """
    speaker_to_chat_role = {"streamer": "assistant", "user": "user"}

    # The generator is consumed lazily, so entries are appended one by one in
    # their original order.
    prompt.extend(
        {"role": speaker_to_chat_role[entry["role"]], "content": entry["message"]}
        for entry in history_msg
    )
    return prompt
|
|
|
|
|
async def gen_poduct_base_prompt(
    user_id: int,
    streamer_id: int = -1,
    product_id: int = -1,
    streamer_info: StreamerInfo | None = None,
    product_info: ProductInfo | None = None,
) -> List[Dict[str, str]]:
    """Build the chat prompt used to generate a product introduction.

    The streamer and the product can each be supplied either as an ID (the
    record is then fetched from the database) or as an already loaded object,
    but not both and not neither.

    Args:
        user_id (int): User ID, passed to the database lookups.
        streamer_id (int): Streamer ID; ``-1`` means "not given".
        product_id (int): Product ID; ``-1`` means "not given".
        streamer_info (StreamerInfo, optional): Streamer record. When ``None``
            the record is looked up with ``streamer_id``.
        product_info (ProductInfo, optional): Product record. When ``None``
            the record is looked up with ``product_id``.

    Returns:
        List[Dict[str, str]]: Two messages, a ``system`` message describing
        the streamer persona followed by a ``user`` message carrying the
        product description.

    Raises:
        ValueError: If both or neither of ``streamer_id`` / ``streamer_info``
            are given, or both or neither of ``product_id`` / ``product_info``.
    """
    # Explicit checks rather than ``assert``: assertions are removed when
    # Python runs with -O, which would let an invalid combination reach the
    # database lookups with the -1 placeholder ID.
    if (streamer_id == -1) == (streamer_info is None):
        raise ValueError("Exactly one of `streamer_id` or `streamer_info` must be provided")
    if (product_id == -1) == (product_info is None):
        raise ValueError("Exactly one of `product_id` or `product_info` must be provided")

    # Load the prompt templates.
    dataset_yaml = await get_llm_product_prompt_base_info()
    conversation_setting = dataset_yaml["conversation_setting"]
    system = conversation_setting["system"]
    first_input_template = conversation_setting["first_input"]
    product_info_struct_template = dataset_yaml["product_info_struct"]

    if streamer_info is None:
        # The lookup result is indexed, so the first match is used.
        streamer_rows = await get_db_streamer_info(user_id, streamer_id)
        streamer_info = streamer_rows[0]

    # Fill the system template with the streamer's name and character traits.
    character_str = streamer_info.character.replace(";", "、")
    system_str = system.replace("{role_type}", streamer_info.name).replace("{character}", character_str)

    if product_info is None:
        product_list, _ = await get_db_product_info(user_id, product_id=product_id)
        product_info = product_list[0]

    # `heighlights` is the attribute name read from the product record.
    highlights_str = product_info.heighlights.replace(";", "、")
    product_info_str = product_info_struct_template[0].replace("{name}", product_info.product_name)
    product_info_str += product_info_struct_template[1].replace("{highlights}", highlights_str)

    # Embed the product description into the first user input.
    sales_doc_prompt = first_input_template.replace("{product_info}", product_info_str)

    prompt = [{"role": "system", "content": system_str}, {"role": "user", "content": sales_doc_prompt}]
    logger.info(prompt)

    return prompt
|
|
|
|
|
async def get_agent_res(prompt, departure_place, delivery_company):
    """Run the agent on the latest message and wrap its answer as a prompt.

    Args:
        prompt: Chat message list; the ``"content"`` of its last element is
            used as the question handed to the agent.
        departure_place: Forwarded unchanged to ``get_agent_result``.
        delivery_company: Forwarded unchanged to ``get_agent_result``.

    Returns:
        ``""`` when the agent plugin is disabled; the agent's own result when
        that result equals ``""``; otherwise the agent result and the
        question formatted into the answer template.
    """
    if not SERVER_PLUGINS_INFO.agent_enabled:
        # Agent plugin disabled: nothing to add.
        return ""

    latest_question = prompt[-1]["content"]
    agent_answer = get_agent_result(LLM_MODEL_HANDLER, latest_question, departure_place, delivery_company)
    if agent_answer == "":
        return agent_answer

    # Template that combines the agent's findings with the original question.
    answer_template = (
        "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"
    )
    agent_answer = answer_template.format(agent_answer, latest_question)
    logger.info(f"Agent response: {agent_answer}")
    return agent_answer
|
|
|
|
|
async def get_llm_res(prompt):
    """Send ``prompt`` to the LLM and return the generated text.

    Args:
        prompt: Value passed as ``messages`` to
            ``LLM_MODEL_HANDLER.chat_completions_v1``.

    Returns:
        The ``choices[0].message.content`` value of the last item produced by
        the completion call, or ``""`` if it produced no items.
    """
    logger.info(prompt)

    # Use the first model the handler lists.
    chosen_model = LLM_MODEL_HANDLER.available_models[0]
    completion_items = LLM_MODEL_HANDLER.chat_completions_v1(model=chosen_model, messages=prompt)

    reply_text = ""
    for chunk in completion_items:
        # Each item overwrites the previous one, so only the last is kept.
        reply_text = chunk["choices"][0]["message"]["content"]
    return reply_text
|
|
|
|
|
@router.get("/gen_sales_doc", summary="生成主播文案接口")
async def get_product_info_api(streamer_id: int, product_id: int, user_id: int = Depends(get_current_user_info)):
    """Generate the streamer's spoken sales script for a product.

    Args:
        streamer_id (int): Streamer ID, used to build the persona part of the
            prompt.
        product_id (int): Product ID.
        user_id (int): Value resolved by the ``get_current_user_info``
            dependency.

    Returns:
        The ``make_return_data`` payload carrying the generated text.
    """
    # NOTE(review): this function name is reused by the "/gen_product_info"
    # handler later in the module, so the module-level name ends up bound to
    # that later function.
    generated_doc = await get_llm_res(await gen_poduct_base_prompt(user_id, streamer_id, product_id))
    return make_return_data(True, ResultCode.SUCCESS, "成功", generated_doc)
|
|
|
|
|
@router.get("/gen_product_info")
async def get_product_info_api(product_id: int, user_id: int = Depends(get_current_user_info)):
    """TODO: generate product information from the product manual.

    The endpoint is a placeholder and always raises.

    Args:
        product_id (int): Product ID.
        user_id (int): Value resolved by the ``get_current_user_info``
            dependency.

    Raises:
        NotImplementedError: Always, until the feature is implemented.
    """
    # `NotImplemented` is the binary-operator sentinel and is not callable, so
    # `raise NotImplemented()` fails with a TypeError; the exception class is
    # `NotImplementedError`.
    #
    # TODO: build a chat prompt whose system message makes the model act as a
    # document assistant that summarises the required information from the
    # manual, send it through LLM_MODEL_HANDLER.chat_completions_v1 with the
    # first available model, and return the collected text.
    raise NotImplementedError("gen_product_info is not implemented yet")
|
|