File size: 12,429 Bytes
1ef9436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : streaming_room.py
@Time : 2024/08/31
@Project : https://github.com/PeterH0323/Streamer-Sales
@Author : HinGwenWong
@Version : 1.0
@Desc : 主播间信息交互接口
"""
import uuid
from copy import deepcopy
from pathlib import Path
import requests
from fastapi import APIRouter, Depends
from loguru import logger
from ...web_configs import API_CONFIG, WEB_CONFIGS
from ..database.product_db import get_db_product_info
from ..database.streamer_room_db import (
create_or_update_db_room_by_id,
get_live_room_info,
get_message_list,
update_db_room_status,
delete_room_id,
get_db_streaming_room_info,
update_message_info,
update_room_video_path,
)
from ..models.product_model import ProductInfo
from ..models.streamer_room_model import OnAirRoomStatusItem, RoomChatItem, SalesDocAndVideoInfo, StreamRoomInfo
from ..modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
from ..routers.users import get_current_user_info
from ..server_info import SERVER_PLUGINS_INFO
from ..utils import ResultCode, make_return_data
from .digital_human import gen_tts_and_digital_human_video_app
from .llm import combine_history, gen_poduct_base_prompt, get_agent_res, get_llm_res
# Router that collects every streaming-room endpoint declared in this module;
# all routes registered on it are served under the "/streaming-room" prefix.
router = APIRouter(
    prefix="/streaming-room",
    tags=["streaming-room"],
    responses={404: {"description": "Not found"}},
)
@router.get("/list", summary="获取所有直播间信息接口")
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
    """Return every streaming room stored for the current user.

    Args:
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The ``make_return_data`` payload whose data is the list of rooms, each
        one converted to a plain ``dict``.
    """
    # Load the streaming-room records from the database.
    streaming_room_list = await get_db_streaming_room_info(user_id)

    # Returning the records as-is loses fields in the response, so every entry
    # is swapped in place for its plain-dict form.
    for idx, room in enumerate(streaming_room_list):
        streaming_room_list[idx] = dict(room)

    return make_return_data(True, ResultCode.SUCCESS, "成功", streaming_room_list)
@router.get("/info/{roomId}", summary="获取特定直播间信息接口")
async def get_streaming_room_id_api(
    roomId: int, currentPage: int = 1, pageSize: int = 5, user_id: int = Depends(get_current_user_info)
):
    """Return the details of a single streaming room.

    Args:
        roomId (int): ID of the room to load; 0 is rejected.
        currentPage (int): Accepted but currently unused (pagination is a TODO).
        pageSize (int): Accepted but currently unused (pagination is a TODO).
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The ``make_return_data`` payload. Its data is the room as a ``dict``
        whose ``product_list`` holds plain product dicts, or an empty list
        when the lookup does not yield exactly one room.

    Raises:
        HTTPException: With status 400 when ``roomId`` is 0.
    """
    # Imported locally so this handler does not depend on edits elsewhere in the file.
    from fastapi import HTTPException

    # Validate explicitly rather than with ``assert``: assert statements are
    # removed when Python runs with -O, which would silently disable the check.
    if roomId == 0:
        raise HTTPException(status_code=400, detail="roomId must not be 0")

    # TODO: add pagination
    # Load the streaming-room record from the database.
    rooms = await get_db_streaming_room_info(user_id, room_id=roomId)
    if len(rooms) != 1:
        return make_return_data(True, ResultCode.SUCCESS, "成功", [])

    room = rooms[0]

    # Returning the records as-is loses fields in the response, so products
    # are converted to plain dicts first.
    format_product_list = []
    for db_product in room.product_list:
        product_dict = dict(db_product)
        # Turn a non-empty start_video value into a URL on the file server.
        if product_dict["start_video"] != "":
            product_dict["start_video"] = API_CONFIG.REQUEST_FILES_URL + product_dict["start_video"]
        format_product_list.append(product_dict)

    room_info = dict(room)
    room_info["product_list"] = format_product_list
    return make_return_data(True, ResultCode.SUCCESS, "成功", room_info)
@router.get("/product-edit-list/{roomId}", summary="获取直播间商品编辑列表,含有已选中的标识")
async def get_streaming_room_product_list_api(
    roomId: int, currentPage: int = 1, pageSize: int = 0, user_id: int = Depends(get_current_user_info)
):
    """Return the product list used when editing a room, with a selected flag.

    The products already attached to the room come first, followed by the
    user's remaining products, which are marked ``selected=False``.

    Args:
        roomId (int): ID of the room being edited; 0 means a room that does
            not exist yet, so nothing is pre-selected.
        currentPage (int): Echoed back as ``current`` in the result.
        pageSize (int): Accepted but currently unused (pagination is a TODO).
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The ``make_return_data`` payload whose data is a dict with the keys
        ``product_list``, ``current``, ``pageSize`` and ``totalSize``. Both
        size fields carry the count returned by ``get_db_product_info``.

    Raises:
        HTTPException: With status 401 when no room is found for ``roomId``.
    """
    # Imported locally so this handler does not depend on edits elsewhere in the file.
    from fastapi import HTTPException

    # Collect the products currently attached to the room.
    if roomId == 0:
        # Brand-new room: nothing attached, nothing to exclude.
        merge_list = []
        exclude_list = []
    else:
        streaming_room_info = await get_db_streaming_room_info(user_id, roomId)
        if not streaming_room_info:
            # Raising a plain string is itself a TypeError in Python 3, so a
            # real exception carrying the intended 401 status is raised instead.
            raise HTTPException(status_code=401, detail="Streaming room not found for the current user")

        streaming_room_info = streaming_room_info[0]
        # IDs already attached to the room are excluded from the second query.
        exclude_list = [product.product_id for product in streaming_room_info.product_list]
        merge_list = deepcopy(streaming_room_info.product_list)

    # Fetch the products that are not attached to the room.
    not_select_product_list, db_product_size = await get_db_product_info(user_id=user_id, exclude_list=exclude_list)

    # Append them after the attached ones, flagged as not selected.
    for not_select_product in not_select_product_list:
        merge_list.append(
            SalesDocAndVideoInfo(
                product_id=not_select_product.product_id,
                product_info=ProductInfo(**dict(not_select_product)),
                selected=False,
            )
        )

    # TODO: lazy-loading pagination

    # Returning the records as-is loses fields in the response, so each entry
    # is converted to a plain dict; the "stream_room" key is dropped if present.
    format_merge_list = []
    for product in merge_list:
        dict_info = dict(product)
        dict_info.pop("stream_room", None)
        format_merge_list.append(dict_info)

    page_info = dict(
        product_list=format_merge_list,
        current=currentPage,
        pageSize=db_product_size,
        totalSize=db_product_size,
    )
    logger.info(page_info)
    return make_return_data(True, ResultCode.SUCCESS, "成功", page_info)
@router.post("/create", summary="新增直播间接口")
async def streaming_room_edit_api(edit_item: dict, user_id: int = Depends(get_current_user_info)):
    """Create a new streaming room from the submitted form data.

    Args:
        edit_item (dict): Request body. ``product_list``, ``status``,
            ``streamer_info`` and ``room_id`` must be present; ``status_id``
            is optional. Whatever keys remain are forwarded to
            ``StreamRoomInfo``. The dict and its product entries are modified
            in place.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The ``make_return_data`` payload whose data is the value returned by
        ``create_or_update_db_room_by_id``.

    Raises:
        KeyError: If a required key is missing from ``edit_item`` or from one
            of the product entries.
    """
    # NOTE(review): this handler has the same function name as the
    # PUT /edit/{room_id} handler; the name is left untouched here.
    raw_products = edit_item.pop("product_list")
    raw_status = edit_item.pop("status")

    # These fields are removed so they are not forwarded to StreamRoomInfo.
    edit_item.pop("streamer_info")
    edit_item.pop("room_id")
    edit_item.pop("status_id", None)

    # Keep only the selected products, without their embedded product_info.
    selected_products = []
    for product in raw_products:
        if product["selected"]:
            product.pop("product_info")
            selected_products.append(SalesDocAndVideoInfo(**product))

    edit_item["user_id"] = user_id
    room_info = StreamRoomInfo(
        **edit_item,
        product_list=selected_products,
        status=OnAirRoomStatusItem(**raw_status),
    )

    # A room ID of 0 is passed for creation.
    room_id = create_or_update_db_room_by_id(0, room_info, user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
@router.put("/edit/{room_id}", summary="编辑直播间接口")
async def streaming_room_edit_api(room_id: int, edit_item: dict, user_id: int = Depends(get_current_user_info)):
    """Update an existing streaming room with the submitted form data.

    Args:
        room_id (int): ID of the room to update, taken from the URL path.
        edit_item (dict): Request body. ``product_list``, ``status`` and
            ``streamer_info`` must be present; every remaining key is
            forwarded to ``StreamRoomInfo`` unchanged. The dict and its
            product entries are modified in place.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The ``make_return_data`` payload whose data is ``room_id``.

    Raises:
        KeyError: If a required key is missing from ``edit_item`` or from one
            of the product entries.
    """
    raw_products = edit_item.pop("product_list")
    raw_status = edit_item.pop("status")
    # streamer_info is removed so it is not forwarded to StreamRoomInfo.
    edit_item.pop("streamer_info")

    # Keep only the selected products, without their embedded product_info.
    selected_products = []
    for product in raw_products:
        if product["selected"]:
            product.pop("product_info")
            selected_products.append(SalesDocAndVideoInfo(**product))

    room_info = StreamRoomInfo(
        **edit_item,
        product_list=selected_products,
        status=OnAirRoomStatusItem(**raw_status),
    )
    create_or_update_db_room_by_id(room_id, room_info, user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
@router.delete("/delete/{roomId}", summary="删除直播间接口")
async def delete_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Delete a streaming room.

    Args:
        roomId (int): ID of the room to delete, taken from the URL path.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The success ``make_return_data`` payload when ``delete_room_id``
        reports a truthy result, otherwise the failure payload.
    """
    deleted = await delete_room_id(roomId, user_id)
    if deleted:
        return make_return_data(True, ResultCode.SUCCESS, "成功", "")
    return make_return_data(False, ResultCode.FAIL, "失败", "")
# ============================================================
# Go-live (on-air) endpoints
# ============================================================
@router.post("/online/{roomId}", summary="直播间开播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Apply the "online" status update to a streaming room.

    Args:
        roomId (int): ID of the room, taken from the URL path.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The success ``make_return_data`` payload with empty data.
    """
    # NOTE(review): despite its name, this handler serves POST /online/{roomId};
    # the function name is left untouched here.
    new_status = "online"
    update_db_room_status(roomId, user_id, new_status)

    result = make_return_data(True, ResultCode.SUCCESS, "成功", "")
    return result
@router.put("/offline/{roomId}", summary="直播间下播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Apply the "offline" status update to a streaming room.

    Args:
        roomId (int): ID of the room, taken from the URL path.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The success ``make_return_data`` payload with empty data.
    """
    new_status = "offline"
    update_db_room_status(roomId, user_id, new_status)

    result = make_return_data(True, ResultCode.SUCCESS, "成功", "")
    return result
@router.post("/next-product/{roomId}", summary="直播间进行下一个商品讲解接口")
async def on_air_live_room_next_product_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Apply the "next-product" status update to a live room.

    Args:
        roomId (int): ID of the streaming room, taken from the URL path.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The success ``make_return_data`` payload with empty data.
    """
    new_status = "next-product"
    update_db_room_status(roomId, user_id, new_status)

    result = make_return_data(True, ResultCode.SUCCESS, "成功", "")
    return result
@router.get("/live-info/{roomId}", summary="获取正在直播的直播间信息接口")
async def get_on_air_live_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Return the information of a room that is currently live.

    According to the original notes, the data covers:

    1. the streamer video address,
    2. the product information shown as a thumbnail in the lower-right corner,
    3. the conversation record (``conversation_list``).

    Args:
        roomId (int): ID of the streaming room, taken from the URL path.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The success ``make_return_data`` payload whose data is whatever
        ``get_live_room_info`` returns.
    """
    live_room_info = await get_live_room_info(user_id, roomId)
    return make_return_data(True, ResultCode.SUCCESS, "成功", live_room_info)
@router.put("/chat", summary="直播间对话接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
    """Handle one chat message in a live room and produce the streamer's reply.

    The handler stores the incoming message, builds a prompt from the product
    base prompt plus the conversation history, tries the agent first and falls
    back to RAG when the agent returns an empty string and RAG is enabled,
    asks the LLM for the reply, generates the TTS / digital-human video, and
    finally stores the video path and the reply.

    Args:
        room_chat (RoomChatItem): Request body; ``roomId`` and ``message`` are read.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The success ``make_return_data`` payload with empty data, or the
        failure payload when no room is found for ``room_chat.roomId``.
    """
    # NOTE(review): this function name is reused by the GET /live-info and
    # POST /asr handlers in this module; the name is left untouched here.

    # Look the room up by ID. An unknown room yields an empty result, which is
    # reported as a failure instead of crashing on the index access below.
    rooms = await get_db_streaming_room_info(user_id, room_chat.roomId)
    if not rooms:
        return make_return_data(False, ResultCode.FAIL, "失败", "")
    streaming_room_info = rooms[0]
    streamer_info = streaming_room_info.streamer_info

    # Entry of the product currently being presented, resolved once.
    current_product = streaming_room_info.product_list[streaming_room_info.status.current_product_index]
    product_detail = current_product.product_info
    sales_info_id = current_product.sales_info_id

    # Store the incoming message, then read back the latest conversation.
    update_message_info(sales_info_id, user_id, role="user", message=room_chat.message)
    conversation_list = get_message_list(sales_info_id)

    # Build the prompt: product base prompt first, then the conversation history.
    prompt = await gen_poduct_base_prompt(
        user_id,
        streamer_info=streamer_info,
        product_info=product_detail,
    )
    prompt = combine_history(prompt, conversation_list)

    # ====================== Agent ======================
    agent_response = await get_agent_res(prompt, product_detail.departure_place, product_detail.delivery_company)
    if agent_response != "":
        logger.info("Agent 执行成功,不执行 RAG")
        prompt[-1]["content"] = agent_response

    # ====================== RAG ======================
    elif SERVER_PLUGINS_INFO.rag_enabled:
        # The agent produced nothing; query RAG with the last prompt entry.
        logger.info("Agent 未执行 or 未开启,调用 RAG")
        rag_res = build_rag_prompt(RAG_RETRIEVER, product_detail.product_name, prompt[-1]["content"])
        if rag_res != "":
            prompt[-1]["content"] = rag_res

    # Ask the LLM for the streamer's reply.
    streamer_res = await get_llm_res(prompt)

    # Generate the digital-human video for the reply.
    server_video_path = await gen_tts_and_digital_human_video_app(streamer_info.streamer_id, streamer_res)

    # Record the new video path for the room.
    update_room_video_path(streaming_room_info.status_id, server_video_path)

    # Store the streamer's reply in the conversation record.
    update_message_info(sales_info_id, streamer_info.streamer_id, role="streamer", message=streamer_res)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
@router.post("/asr", summary="直播间调取 ASR 语音转文字 接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
    """Send an uploaded audio file to the ASR service and return the text.

    Args:
        room_chat (RoomChatItem): Request body; only ``asrFileUrl`` is read.
        user_id (int): Value resolved by the ``get_current_user_info`` dependency.

    Returns:
        The success ``make_return_data`` payload whose data is the ``result``
        field of the ASR service's JSON response.

    Raises:
        KeyError: If the ASR response has no ``result`` field.
        FileNotFoundError: If the local audio file is missing when it is removed.
    """
    # Imported locally so this handler does not depend on edits elsewhere in the file.
    from fastapi.concurrency import run_in_threadpool

    # asrFileUrl is a server URL; only its file name is kept and resolved
    # inside the local ASR directory.
    asr_local_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(WEB_CONFIGS.ASR_FILE_DIR, Path(room_chat.asrFileUrl).name)

    req_data = {
        "user_id": user_id,
        "request_id": str(uuid.uuid1()),
        "wav_path": str(asr_local_path),
    }
    logger.info(req_data)

    # requests is synchronous; calling it directly inside this coroutine would
    # block the event loop for the whole ASR round-trip, so it runs in the
    # thread pool instead.
    # NOTE(review): no timeout is set on this request — consider adding one.
    response = await run_in_threadpool(requests.post, API_CONFIG.ASR_URL, json=req_data)
    asr_str = response.json()["result"]
    logger.info(f"ASR res = {asr_str}")

    # Remove the intermediate audio file.
    asr_local_path.unlink()

    return make_return_data(True, ResultCode.SUCCESS, "成功", asr_str)
|