from openai import OpenAI
from app.config import get_settings
from loguru import logger
from typing import List, Dict, Any, Optional
import json
from http import HTTPStatus
import requests
import random

from app.models.const import LANGUAGE_NAMES, Language
from app.exceptions import LLMResponseValidationError
import dashscope
from dashscope import ImageSynthesis
from app.schemas.llm import (
    StoryGenerationRequest,
)

settings = get_settings()
# Provider clients are created at import time only when the corresponding API key is
# configured; unconfigured providers stay None.
openai_client = None
if settings.openai_api_key:
    openai_client = OpenAI(api_key=settings.openai_api_key, base_url=settings.openai_base_url or "https://api.openai.com/v1")

aliyun_text_client = None
if settings.aliyun_api_key:
    dashscope.api_key = settings.aliyun_api_key
    aliyun_text_client = OpenAI(base_url=settings.aliyun_base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1", api_key=settings.aliyun_api_key)

deepseek_client = None
if settings.deepseek_api_key:
    deepseek_client = OpenAI(api_key=settings.deepseek_api_key, base_url=settings.deepseek_base_url or "https://api.deepseek.com/v1")

ollama_client = None
if settings.ollama_api_key:
    ollama_client = OpenAI(api_key=settings.ollama_api_key, base_url=settings.ollama_base_url or "http://localhost:11434/v1")

siliconflow_client = None
if settings.siliconflow_api_key:
    siliconflow_client = OpenAI(api_key=settings.siliconflow_api_key, base_url=settings.siliconflow_base_url or "https://api.siliconflow.cn/v1")

class LLMService:
    """Service wrapper around the configured text and image LLM providers."""

    def __init__(self):
        self.openai_client = openai_client
        self.aliyun_text_client = aliyun_text_client
        self.text_llm_model = settings.text_llm_model
        self.image_llm_model = settings.image_llm_model

    async def generate_story(self, request: StoryGenerationRequest) -> List[Dict[str, Any]]:
        """Generate the story scenes.

        Args:
            request: Story generation request (story prompt, language, number of
                segments, and optional text LLM provider/model overrides).

        Returns:
            List[Dict[str, Any]]: List of story scenes, each with a "text" and an
                "image_prompt" field.
        """
        messages = [
            {"role": "system", "content": "你是一个专业的故事创作者,善于创作引人入胜的故事。请只返回JSON格式的内容。"},
            {"role": "user", "content": await self._get_story_prompt(request.story_prompt, request.language, request.segments)},
        ]
        logger.info(f"prompt messages: {json.dumps(messages, indent=4, ensure_ascii=False)}")

        response = await self._generate_response(
            text_llm_provider=request.text_llm_provider or None,
            text_llm_model=request.text_llm_model or None,
            messages=messages,
            response_format="json_object",
        )
        scenes = self.normalize_keys(response["list"])

        logger.info(f"Generated story: {json.dumps(scenes, indent=4, ensure_ascii=False)}")

        self._validate_story_response(scenes)

        return scenes
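    # Illustrative shape of the list returned by generate_story (enforced by
    # _validate_story_response below):
    #     [{"text": "<scene text>", "image_prompt": "<English image prompt>"}, ...]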

    def normalize_keys(self, data):
        """Normalize scene keys across providers.

        The Aliyun and OpenAI models do not return identical JSON, so rename the
        single non-"text" key of each scene to "image_prompt":
        - For a dict, rename the one key other than "text" to "image_prompt".
        - For a list, apply the same normalization to every item recursively.
        """
        if isinstance(data, dict):
            if "text" in data:
                other_keys = [key for key in data.keys() if key != "text"]
                if len(other_keys) == 1:
                    data["image_prompt"] = data.pop(other_keys[0])
                elif len(other_keys) > 1:
                    raise ValueError(f"Unexpected extra keys: {other_keys}. Only one non-'text' key is allowed.")
            return data
        elif isinstance(data, list):
            return [self.normalize_keys(item) for item in data]
        else:
            raise TypeError("Input must be a dict or list of dicts")
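    # Example of the normalization above (illustrative; "prompt" stands in for
    # whatever non-"text" key the model happened to return):
    #     >>> LLMService().normalize_keys([{"text": "A fox.", "prompt": "a red fox, watercolor"}])
    #     [{'text': 'A fox.', 'image_prompt': 'a red fox, watercolor'}]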

    def generate_image(self, *, prompt: str, image_llm_provider: Optional[str] = None, image_llm_model: Optional[str] = None, resolution: str = "1024x1024") -> str:
        """Generate an image for a scene.

        Args:
            prompt (str): Image description.
            image_llm_provider (str, optional): Image provider override. Defaults to the configured provider.
            image_llm_model (str, optional): Image model override. Defaults to the configured model.
            resolution (str): Image resolution. Defaults to 1024x1024.

        Returns:
            str: Image URL, or an empty string if generation failed.
        """
        image_llm_provider = image_llm_provider or settings.image_provider
        image_llm_model = image_llm_model or settings.image_llm_model

        try:
            safe_prompt = f"Create a safe, family-friendly illustration. {prompt} The image should be appropriate for all ages, non-violent, and non-controversial."

            if image_llm_provider == "aliyun":
                # DashScope (Aliyun) image synthesis.
                rsp = ImageSynthesis.call(model=image_llm_model, prompt=prompt, size=resolution)
                if rsp.status_code == HTTPStatus.OK:
                    for result in rsp.output.results:
                        return result.url
                else:
                    error_message = f"Failed, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}"
                    logger.error(error_message)
                    raise Exception(error_message)
            elif image_llm_provider == "openai":
                # OpenAI Images API expects sizes such as "1024x1024".
                if resolution is not None:
                    resolution = resolution.replace("*", "x")
                response = self.openai_client.images.generate(
                    model=image_llm_model,
                    prompt=safe_prompt,
                    size=resolution,
                    quality="standard",
                    n=1,
                )
                logger.info(f"image generate res: {response.data[0].url}")
                return response.data[0].url
            elif image_llm_provider == "siliconflow":
                # SiliconFlow image generation HTTP API.
                if resolution is not None:
                    resolution = resolution.replace("*", "x")
                payload = {
                    "model": image_llm_model,
                    "prompt": safe_prompt,
                    "seed": random.randint(1000000, 4999999999),
                    "image_size": resolution,
                    "guidance_scale": 7.5,
                    "batch_size": 1,
                }
                headers = {
                    "Authorization": "Bearer " + settings.siliconflow_api_key,
                    "Content-Type": "application/json",
                }
                response = requests.post("https://api.siliconflow.cn/v1/images/generations", json=payload, headers=headers)
                if response.ok:
                    return response.json()["images"][0]["url"]
                raise Exception(response.text)
            else:
                raise ValueError(f"Unsupported image provider: {image_llm_provider}")
        except Exception as e:
            logger.error(f"Failed to generate image: {e}")
            return ""
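    # Illustrative call (assumes the matching image provider API key is configured):
    #     url = LLMService().generate_image(prompt="a cat astronaut, watercolor", resolution="1024x1024")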

    async def generate_story_with_images(self, request: StoryGenerationRequest) -> List[Dict[str, Any]]:
        """Generate a story and an illustration for each scene.

        Args:
            request: Story generation request (story prompt, language, number of
                segments, resolution, and optional provider/model overrides).

        Returns:
            List[Dict[str, Any]]: List of story scenes, each containing the scene
                text, the image prompt, and the generated image URL.
        """
        story_segments = await self.generate_story(request)

        for segment in story_segments:
            try:
                image_url = self.generate_image(
                    prompt=segment["image_prompt"],
                    resolution=request.resolution,
                    image_llm_provider=request.image_llm_provider,
                    image_llm_model=request.image_llm_model,
                )
                segment["url"] = image_url
            except Exception as e:
                logger.error(f"Failed to generate image for segment: {e}")
                segment["url"] = None

        return story_segments
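    # Illustrative shape of each item returned above: the scene dict from
    # generate_story plus the "url" key set in the loop:
    #     {"text": "<scene text>", "image_prompt": "<English prompt>", "url": "<image URL or None>"}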

    def get_llm_providers(self) -> Dict[str, List[str]]:
        imgLLMList = []
        textLLMList = []
        if settings.openai_api_key:
            textLLMList.append("openai")
            imgLLMList.append("openai")
        if settings.aliyun_api_key:
            textLLMList.append("aliyun")
            imgLLMList.append("aliyun")
        if settings.deepseek_api_key:
            textLLMList.append("deepseek")
        if settings.ollama_api_key:
            textLLMList.append("ollama")
        if settings.siliconflow_api_key:
            textLLMList.append("siliconflow")
            imgLLMList.append("siliconflow")
        return {"textLLMProviders": textLLMList, "imageLLMProviders": imgLLMList}
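    # Example return value when only the OpenAI and DeepSeek keys are configured
    # (illustrative):
    #     {"textLLMProviders": ["openai", "deepseek"], "imageLLMProviders": ["openai"]}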

    def _validate_story_response(self, response: Any) -> None:
        """Validate the story generation response.

        Args:
            response: Parsed LLM response.

        Raises:
            LLMResponseValidationError: If the response is malformed.
        """
        if not isinstance(response, list):
            raise LLMResponseValidationError("Response must be an array")

        for i, scene in enumerate(response):
            if not isinstance(scene, dict):
                raise LLMResponseValidationError(f"Scene {i} must be an object")
            if "text" not in scene:
                raise LLMResponseValidationError(f"Scene {i} missing 'text' field")
            if "image_prompt" not in scene:
                raise LLMResponseValidationError(f"Scene {i} missing 'image_prompt' field")
            if not isinstance(scene["text"], str):
                raise LLMResponseValidationError(f"Scene {i} 'text' must be a string")
            if not isinstance(scene["image_prompt"], str):
                raise LLMResponseValidationError(f"Scene {i} 'image_prompt' must be a string")

    async def _generate_response(self, *, text_llm_provider: Optional[str] = None, text_llm_model: Optional[str] = None, messages: List[Dict[str, str]], response_format: str = "json_object") -> Any:
        """Generate an LLM response.

        Args:
            text_llm_provider: Text provider override. Defaults to the configured provider.
            text_llm_model: Text model override. Defaults to the configured model.
            messages: Chat messages.
            response_format: Response format. Defaults to json_object.

        Returns:
            Dict[str, Any]: Parsed response.

        Raises:
            Exception: If the request fails or the response cannot be parsed.
        """
        if text_llm_provider is None:
            text_llm_provider = settings.text_llm_provider
        if text_llm_provider == "aliyun":
            text_client = self.aliyun_text_client
        elif text_llm_provider == "openai":
            text_client = self.openai_client
        elif text_llm_provider == "deepseek":
            text_client = deepseek_client
        elif text_llm_provider == "ollama":
            text_client = ollama_client
        elif text_llm_provider == "siliconflow":
            text_client = siliconflow_client
        else:
            raise ValueError(f"Unsupported text LLM provider: {text_llm_provider}")
        if text_client is None:
            raise ValueError(f"No API key configured for text LLM provider: {text_llm_provider}")

        if text_llm_model is None:
            text_llm_model = settings.text_llm_model

        response = text_client.chat.completions.create(
            model=text_llm_model,
            response_format={"type": response_format},
            messages=messages,
        )
        try:
            content = response.choices[0].message.content
            return json.loads(content)
        except Exception as e:
            logger.error(f"Failed to parse response: {e}")
            raise

    async def _get_story_prompt(self, story_prompt: Optional[str] = None, language: Language = Language.CHINESE_CN, segments: int = 3) -> str:
        """Build the full story generation prompt.

        Args:
            story_prompt (str, optional): Story prompt. Defaults to None.
            language (Language, optional): Language of the scene text. Defaults to Language.CHINESE_CN.
            segments (int, optional): Number of story segments. Defaults to 3.

        Returns:
            str: The full prompt.
        """
        languageValue = LANGUAGE_NAMES[language]
        if story_prompt:
            base_prompt = f"讲一个故事,主题是:{story_prompt}"
        else:
            # Fall back to a generic request when no story prompt is supplied,
            # so the f-string below never references an undefined name.
            base_prompt = "讲一个故事"

        return f"""
        {base_prompt}. The story needs to be divided into {segments} scenes, and each scene must include descriptive text and an image prompt.

        Please return the result in the following JSON format, where the key `list` contains an array of objects:

        **Expected JSON format**:
        {{
            "list": [
                {{
                    "text": "Descriptive text for the scene",
                    "image_prompt": "Detailed image generation prompt, described in English"
                }},
                {{
                    "text": "Another scene description text",
                    "image_prompt": "Another detailed image generation prompt in English"
                }}
            ]
        }}

        **Requirements**:
        1. The root object must contain a key named `list`, and its value must be an array of scene objects.
        2. Each object in the `list` array must include:
            - `text`: A descriptive text for the scene, written in {languageValue}.
            - `image_prompt`: A detailed prompt for generating an image, written in English.
        3. Ensure the JSON format matches the above example exactly. Avoid extra fields or incorrect key names like `cimage_prompt` or `inage_prompt`.

        **Important**:
        - If there is only one scene, the array under `list` should contain a single object.
        - The output must be a valid JSON object. Do not include explanations, comments, or additional content outside the JSON.

        Example output:
        {{
            "list": [
                {{
                    "text": "Scene description text",
                    "image_prompt": "Detailed image generation prompt in English"
                }}
            ]
        }}
        """
llm_service = LLMService()
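

# Minimal manual smoke test (illustrative sketch, not part of the service). It
# assumes the relevant API keys are configured via app.config settings and that
# StoryGenerationRequest accepts the field names used throughout this module
# (story_prompt, language, segments, resolution, ...); adjust to the actual
# schema before running.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        request = StoryGenerationRequest(
            story_prompt="a kitten who learns to fly",
            language=Language.CHINESE_CN,
            segments=3,
            resolution="1024x1024",
        )
        scenes = await llm_service.generate_story_with_images(request)
        print(json.dumps(scenes, ensure_ascii=False, indent=2))

    asyncio.run(_demo())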